From c7621ec0e4937151d8857e1041d49600221fd11c Mon Sep 17 00:00:00 2001
From: Max Kellermann <max.kellermann@gmail.com>
Date: Wed, 29 Jan 2025 17:58:36 +0100
Subject: [PATCH] net/PeerCredentials: wrapper for struct ucred

---
 src/event/ServerSocket.cxx   | 16 +++-----
 src/net/PeerCredentials.hxx  | 71 ++++++++++++++++++++++++++++++++++++
 src/net/SocketDescriptor.cxx | 17 +++++----
 src/net/SocketDescriptor.hxx | 11 ++----
 4 files changed, 90 insertions(+), 25 deletions(-)
 create mode 100644 src/net/PeerCredentials.hxx

diff --git a/src/event/ServerSocket.cxx b/src/event/ServerSocket.cxx
index 57fd6ae46..0f6f06cc5 100644
--- a/src/event/ServerSocket.cxx
+++ b/src/event/ServerSocket.cxx
@@ -9,6 +9,7 @@
 #include "net/IPv6Address.hxx"
 #include "net/StaticSocketAddress.hxx"
 #include "net/AllocatedSocketAddress.hxx"
+#include "net/PeerCredentials.hxx"
 #include "net/SocketUtil.hxx"
 #include "net/SocketError.hxx"
 #include "net/UniqueSocketDescriptor.hxx"
@@ -104,13 +105,6 @@ static constexpr Domain server_socket_domain("server_socket");
 static int
 get_remote_uid(SocketDescriptor s) noexcept
 {
-#ifdef HAVE_STRUCT_UCRED
-	const auto cred = s.GetPeerCredentials();
-	if (cred.pid < 0)
-		return -1;
-
-	return cred.uid;
-#else
 #ifdef HAVE_GETPEEREID
 	uid_t euid;
 	gid_t egid;
@@ -118,9 +112,11 @@ get_remote_uid(SocketDescriptor s) noexcept
 	if (getpeereid(s.Get(), &euid, &egid) == 0)
 		return euid;
 #else
-	(void)s;
-#endif
-	return -1;
+	const auto cred = s.GetPeerCredentials();
+	if (!cred.IsDefined())
+		return -1;
+
+	return cred.GetUid();
 #endif
 }
 
diff --git a/src/net/PeerCredentials.hxx b/src/net/PeerCredentials.hxx
new file mode 100644
index 000000000..510f60f3b
--- /dev/null
+++ b/src/net/PeerCredentials.hxx
@@ -0,0 +1,71 @@
+// SPDX-License-Identifier: BSD-2-Clause
+// author: Max Kellermann <max.kellermann@gmail.com>
+
+#pragma once
+
+#include "Features.hxx" // for HAVE_STRUCT_UCRED
+
+#include <type_traits> // for std::is_trivial_v
+
+#ifdef HAVE_STRUCT_UCRED
+#include <sys/socket.h> // for struct ucred
+#endif
+
+/**
+ * Portable wrapper for credentials of the process on the other side
+ * of a (local) socket.
+ */
+class SocketPeerCredentials {
+	friend class SocketDescriptor;
+
+#ifdef HAVE_STRUCT_UCRED
+	struct ucred cred;
+#endif
+
+public:
+	constexpr SocketPeerCredentials() noexcept = default;
+
+	static constexpr SocketPeerCredentials Undefined() noexcept {
+		SocketPeerCredentials c;
+#ifdef HAVE_STRUCT_UCRED
+		c.cred.pid = 0;
+		c.cred.uid = -1;
+		c.cred.gid = -1;
+#endif
+		return c;
+	}
+
+	constexpr bool IsDefined() const noexcept {
+#ifdef HAVE_STRUCT_UCRED
+		return cred.pid > 0;
+#else
+		return false;
+#endif
+	}
+
+	constexpr auto GetPid() const noexcept {
+#ifdef HAVE_STRUCT_UCRED
+		return cred.pid;
+#else
+		return 0;
+#endif
+	}
+
+	constexpr auto GetUid() const noexcept {
+#ifdef HAVE_STRUCT_UCRED
+		return cred.uid;
+#else
+		return -1;
+#endif
+	}
+
+	constexpr auto GetGid() const noexcept {
+#ifdef HAVE_STRUCT_UCRED
+		return cred.gid;
+#else
+		return -1;
+#endif
+	}
+};
+
+static_assert(std::is_trivial_v<SocketPeerCredentials>);
diff --git a/src/net/SocketDescriptor.cxx b/src/net/SocketDescriptor.cxx
index 64664e133..162ccea0e 100644
--- a/src/net/SocketDescriptor.cxx
+++ b/src/net/SocketDescriptor.cxx
@@ -7,6 +7,7 @@
 #include "IPv4Address.hxx"
 #include "IPv6Address.hxx"
 #include "UniqueSocketDescriptor.hxx"
+#include "PeerCredentials.hxx"
 
 #ifdef __linux__
 #include "io/UniqueFileDescriptor.hxx"
@@ -225,19 +226,19 @@ SocketDescriptor::GetIntOption(int level, int name, int fallback) const noexcept
 	return value;
 }
 
-#ifdef HAVE_STRUCT_UCRED
-
-struct ucred
+SocketPeerCredentials
 SocketDescriptor::GetPeerCredentials() const noexcept
 {
-	struct ucred cred;
+#ifdef HAVE_STRUCT_UCRED
+	SocketPeerCredentials cred;
 	if (GetOption(SOL_SOCKET, SO_PEERCRED,
-		      &cred, sizeof(cred)) < sizeof(cred))
-		cred.pid = -1;
+		      &cred.cred, sizeof(cred.cred)) < sizeof(cred.cred))
+		return SocketPeerCredentials::Undefined();
 	return cred;
-}
-
+#else
+	return SocketPeerCredentials::Undefined();
 #endif
+}
 
 #ifdef __linux__
 
diff --git a/src/net/SocketDescriptor.hxx b/src/net/SocketDescriptor.hxx
index 1f0621c68..2c8123fc9 100644
--- a/src/net/SocketDescriptor.hxx
+++ b/src/net/SocketDescriptor.hxx
@@ -3,8 +3,6 @@
 
 #pragma once
 
-#include "Features.hxx"
-
 #ifndef _WIN32
 #include "io/FileDescriptor.hxx"
 #endif
@@ -20,6 +18,7 @@
 
 struct msghdr;
 struct iovec;
+class SocketPeerCredentials;
 class SocketAddress;
 class StaticSocketAddress;
 class IPv4Address;
@@ -221,14 +220,12 @@ public:
 	[[gnu::pure]]
 	int GetIntOption(int level, int name, int fallback) const noexcept;
 
-#ifdef HAVE_STRUCT_UCRED
 	/**
-	 * Receive peer credentials (SO_PEERCRED).  On error, the pid
-	 * is -1.
+	 * Receive peer credentials (SO_PEERCRED).  On error, an
+	 * "undefined" object is returned.
 	 */
 	[[gnu::pure]]
-	struct ucred GetPeerCredentials() const noexcept;
-#endif
+	SocketPeerCredentials GetPeerCredentials() const noexcept;
 
 #ifdef __linux__
 	/**