diff --git a/src/net/LocalSocketAddress.cxx b/src/net/LocalSocketAddress.cxx new file mode 100644 index 000000000..80b2939ea --- /dev/null +++ b/src/net/LocalSocketAddress.cxx @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: BSD-2-Clause +// author: Max Kellermann + +#include "LocalSocketAddress.hxx" + +const char * +LocalSocketAddress::GetLocalPath() const noexcept +{ + const auto raw = GetLocalRaw(); + return !raw.empty() && + /* must be an absolute path */ + raw.front() == '/' && + /* must be null-terminated and there must not be any + other null byte */ + raw.find('\0') == raw.size() - 1 + ? raw.data() + : nullptr; +} diff --git a/src/net/LocalSocketAddress.hxx b/src/net/LocalSocketAddress.hxx new file mode 100644 index 000000000..e6f5ca86d --- /dev/null +++ b/src/net/LocalSocketAddress.hxx @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: BSD-2-Clause +// author: Max Kellermann + +#pragma once + +#include "SocketAddress.hxx" // IWYU pragma: export + +#include // for std::copy() +#include // for std::length_error +#include + +#include + +/** + * An OO wrapper for struct sockaddr_un. + */ +class LocalSocketAddress { + friend class SocketDescriptor; + +public: + typedef SocketAddress::size_type size_type; + +private: + size_type size; + struct sockaddr_un address; + +public: + constexpr LocalSocketAddress() noexcept = default; + + constexpr explicit LocalSocketAddress(std::string_view path) noexcept + :address{} { + SetLocal(path); + } + + constexpr operator SocketAddress() const noexcept { + return SocketAddress{*this, size}; + } + + constexpr operator struct sockaddr *() noexcept { + return (struct sockaddr *)(void *)&address; + } + + constexpr operator const struct sockaddr *() const noexcept { + return (const struct sockaddr *)(const void *)&address; + } + + constexpr size_type GetCapacity() const noexcept { + return sizeof(address); + } + + constexpr size_type GetSize() const noexcept { + return size; + } + + constexpr int GetFamily() const noexcept { + return address.sun_family; + } + + constexpr bool IsDefined() const noexcept { + return GetFamily() != AF_UNSPEC; + } + + constexpr void Clear() noexcept { + address.sun_family = AF_UNSPEC; + } + + /** + * @see SocketAddress::GetLocalRaw() + */ + constexpr std::string_view GetLocalRaw() const noexcept { + if (GetFamily() != AF_LOCAL) + return {}; + + const auto start = (const char *)&address; + const auto path = address.sun_path; + const size_t header_size = path - start; + if (size < size_type(header_size)) + /* malformed address */ + return {}; + + return {path, size - header_size}; + } + + /** + * @see SocketAddress::GetLocalPath() + */ + [[nodiscard]] [[gnu::pure]] + const char *GetLocalPath() const noexcept; + + /** + * Make this a "local" address (UNIX domain socket). If the path + * begins with a '@', then the rest specifies an "abstract" local + * address. + */ + constexpr LocalSocketAddress &SetLocal(std::string_view path) { + const bool is_abstract = path.starts_with('@'); + + /* sun_path must be null-terminated unless it's an abstract + socket */ + const size_t path_length = path.size() + !is_abstract; + + if (path_length > sizeof(address.sun_path)) + throw std::length_error{"Path is too long"}; + + size = sizeof(address) - sizeof(address.sun_path) + path_length; + + address.sun_family = AF_LOCAL; + auto out = std::copy(path.begin(), path.end(), address.sun_path); + if (is_abstract) + address.sun_path[0] = 0; + else + *out = 0; + + return *this; + } + + [[nodiscard]] [[gnu::pure]] + std::span GetSteadyPart() const noexcept; + + [[nodiscard]] [[gnu::pure]] + bool operator==(SocketAddress other) const noexcept { + return static_cast(*this) == other; + } + + [[nodiscard]] [[gnu::pure]] + bool operator!=(SocketAddress other) const noexcept { + return !(*this == other); + } +}; diff --git a/src/net/meson.build b/src/net/meson.build index 03af2d307..7e52c438c 100644 --- a/src/net/meson.build +++ b/src/net/meson.build @@ -27,6 +27,7 @@ conf.set('HAVE_UN', have_local_socket) if have_local_socket conf.set('HAVE_STRUCT_UCRED', compiler.has_header_symbol('sys/socket.h', 'struct ucred') and compiler.has_header_symbol('sys/socket.h', 'SO_PEERCRED')) conf.set('HAVE_GETPEEREID', compiler.has_function('getpeereid')) + net_sources += 'LocalSocketAddress.cxx' endif if not have_tcp and not have_local_socket diff --git a/test/net/TestLocalSocketAddress.cxx b/test/net/TestLocalSocketAddress.cxx index f7d87c748..94aaaa1ef 100644 --- a/test/net/TestLocalSocketAddress.cxx +++ b/test/net/TestLocalSocketAddress.cxx @@ -2,12 +2,44 @@ // author: Max Kellermann #include "net/AllocatedSocketAddress.hxx" +#include "net/LocalSocketAddress.hxx" #include "net/ToString.hxx" #include #include +TEST(LocalSocketAddress, Path1) +{ + const char *path = "/run/foo/bar.socket"; + LocalSocketAddress a; + a.SetLocal(path); + EXPECT_TRUE(a.IsDefined()); + EXPECT_EQ(a.GetFamily(), AF_LOCAL); + EXPECT_EQ(ToString(a), path); + EXPECT_STREQ(a.GetLocalPath(), path); + + const struct sockaddr *const sa = a; + const auto &sun = *(const struct sockaddr_un *)sa; + EXPECT_STREQ(sun.sun_path, path); + EXPECT_EQ(sun.sun_path + strlen(path) + 1, (const char *)sa + a.GetSize()); +} + +TEST(LocalSocketAddress, Path2) +{ + static constexpr const char *path = "/run/foo/bar.socket"; + static constexpr LocalSocketAddress a{path}; + EXPECT_TRUE(a.IsDefined()); + EXPECT_EQ(a.GetFamily(), AF_LOCAL); + EXPECT_EQ(ToString(a), path); + EXPECT_STREQ(a.GetLocalPath(), path); + + const struct sockaddr *const sa = a; + const auto &sun = *(const struct sockaddr_un *)sa; + EXPECT_STREQ(sun.sun_path, path); + EXPECT_EQ(sun.sun_path + strlen(path) + 1, (const char *)sa + a.GetSize()); +} + TEST(LocalSocketAddress, Path) { const char *path = "/run/foo/bar.socket"; @@ -17,6 +49,7 @@ TEST(LocalSocketAddress, Path) EXPECT_TRUE(a.IsDefined()); EXPECT_EQ(a.GetFamily(), AF_LOCAL); EXPECT_EQ(ToString(a), path); + EXPECT_STREQ(a.GetLocalPath(), path); const auto &sun = *(const struct sockaddr_un *)a.GetAddress(); EXPECT_STREQ(sun.sun_path, path); @@ -25,6 +58,27 @@ TEST(LocalSocketAddress, Path) #ifdef __linux__ +TEST(LocalSocketAddress, Abstract1) +{ + const char *path = "@foo.bar"; + LocalSocketAddress a; + a.SetLocal(path); + EXPECT_TRUE(a.IsDefined()); + EXPECT_EQ(a.GetFamily(), AF_LOCAL); + EXPECT_EQ(ToString(a), path); + EXPECT_EQ(a.GetLocalPath(), nullptr); + + const struct sockaddr *const sa = a; + const auto &sun = *(const struct sockaddr_un *)sa; + + /* Linux's abstract sockets start with a null byte, ... */ + EXPECT_EQ(sun.sun_path[0], 0); + + /* ... but are not null-terminated */ + EXPECT_EQ(memcmp(sun.sun_path + 1, path + 1, strlen(path) - 1), 0); + EXPECT_EQ(sun.sun_path + strlen(path), (const char *)sa + a.GetSize()); +} + TEST(LocalSocketAddress, Abstract) { const char *path = "@foo.bar"; @@ -34,6 +88,7 @@ TEST(LocalSocketAddress, Abstract) EXPECT_TRUE(a.IsDefined()); EXPECT_EQ(a.GetFamily(), AF_LOCAL); EXPECT_EQ(ToString(a), path); + EXPECT_EQ(a.GetLocalPath(), nullptr); const auto &sun = *(const struct sockaddr_un *)a.GetAddress();