event/SocketMonitor: use class SocketDescriptor

This commit is contained in:
Max Kellermann 2017-08-10 18:25:22 +02:00
parent fcfc8bacc0
commit 492b20a89d
21 changed files with 80 additions and 83 deletions

View File

@ -52,7 +52,7 @@ public:
private: private:
void OnAccept(int fd, SocketAddress address, int uid) override { void OnAccept(int fd, SocketAddress address, int uid) override {
client_new(GetEventLoop(), partition, client_new(GetEventLoop(), partition,
fd, address, uid); SocketDescriptor(fd), address, uid);
} }
}; };

View File

@ -94,7 +94,7 @@ public:
std::list<ClientMessage> messages; std::list<ClientMessage> messages;
Client(EventLoop &loop, Partition &partition, Client(EventLoop &loop, Partition &partition,
int fd, int uid, int num); SocketDescriptor fd, int uid, int num);
~Client() { ~Client() {
if (FullyBufferedSocket::IsDefined()) if (FullyBufferedSocket::IsDefined())
@ -236,7 +236,7 @@ client_manager_init();
void void
client_new(EventLoop &loop, Partition &partition, client_new(EventLoop &loop, Partition &partition,
int fd, SocketAddress address, int uid); SocketDescriptor fd, SocketAddress address, int uid);
/** /**
* Write a printf-like formatted string to the client. * Write a printf-like formatted string to the client.

View File

@ -42,7 +42,7 @@
static const char GREETING[] = "OK MPD " PROTOCOL_VERSION "\n"; static const char GREETING[] = "OK MPD " PROTOCOL_VERSION "\n";
Client::Client(EventLoop &_loop, Partition &_partition, Client::Client(EventLoop &_loop, Partition &_partition,
int _fd, int _uid, int _num) SocketDescriptor _fd, int _uid, int _num)
:FullyBufferedSocket(_fd, _loop, 16384, client_max_output_buffer_size), :FullyBufferedSocket(_fd, _loop, 16384, client_max_output_buffer_size),
TimeoutMonitor(_loop), TimeoutMonitor(_loop),
partition(&_partition), partition(&_partition),
@ -57,12 +57,12 @@ Client::Client(EventLoop &_loop, Partition &_partition,
void void
client_new(EventLoop &loop, Partition &partition, client_new(EventLoop &loop, Partition &partition,
int fd, SocketAddress address, int uid) SocketDescriptor fd, SocketAddress address, int uid)
{ {
static unsigned int next_client_num; static unsigned int next_client_num;
const auto remote = ToString(address); const auto remote = ToString(address);
assert(fd >= 0); assert(fd.IsDefined());
#ifdef HAVE_LIBWRAP #ifdef HAVE_LIBWRAP
if (address.GetFamily() != AF_LOCAL) { if (address.GetFamily() != AF_LOCAL) {
@ -70,7 +70,7 @@ client_new(EventLoop &loop, Partition &partition,
const char *progname = "mpd"; const char *progname = "mpd";
struct request_info req; struct request_info req;
request_init(&req, RQ_FILE, fd, RQ_DAEMON, progname, 0); request_init(&req, RQ_FILE, fd.Get(), RQ_DAEMON, progname, 0);
fromhost(&req); fromhost(&req);
@ -80,7 +80,7 @@ client_new(EventLoop &loop, Partition &partition,
"libwrap refused connection (libwrap=%s) from %s", "libwrap refused connection (libwrap=%s) from %s",
progname, remote.c_str()); progname, remote.c_str());
close_socket(fd); fd.Close();
return; return;
} }
} }
@ -89,14 +89,14 @@ client_new(EventLoop &loop, Partition &partition,
ClientList &client_list = *partition.instance.client_list; ClientList &client_list = *partition.instance.client_list;
if (client_list.IsFull()) { if (client_list.IsFull()) {
LogWarning(client_domain, "Max connections reached"); LogWarning(client_domain, "Max connections reached");
close_socket(fd); fd.Close();
return; return;
} }
Client *client = new Client(loop, partition, fd, uid, Client *client = new Client(loop, partition, fd, uid,
next_client_num++); next_client_num++);
(void)send(fd, GREETING, sizeof(GREETING) - 1, 0); (void)fd.Write(GREETING, sizeof(GREETING) - 1);
client_list.Add(*client); client_list.Add(*client);

View File

@ -398,7 +398,7 @@ ProxyDatabase::Connect()
idle_received = unsigned(-1); idle_received = unsigned(-1);
is_idle = false; is_idle = false;
SocketMonitor::Open(mpd_async_get_fd(mpd_connection_get_async(connection))); SocketMonitor::Open(SocketDescriptor(mpd_async_get_fd(mpd_connection_get_async(connection))));
IdleMonitor::Schedule(); IdleMonitor::Schedule();
} }

View File

@ -38,7 +38,8 @@ InotifySource::OnSocketReady(gcc_unused unsigned flags)
static_assert(sizeof(buffer) >= sizeof(struct inotify_event) + NAME_MAX + 1, static_assert(sizeof(buffer) >= sizeof(struct inotify_event) + NAME_MAX + 1,
"inotify buffer too small"); "inotify buffer too small");
ssize_t nbytes = read(Get(), buffer, sizeof(buffer)); auto ifd = Get().ToFileDescriptor();
ssize_t nbytes = ifd.Read(buffer, sizeof(buffer));
if (nbytes < 0) if (nbytes < 0)
FatalSystemError("Failed to read from inotify"); FatalSystemError("Failed to read from inotify");
if (nbytes == 0) if (nbytes == 0)
@ -67,19 +68,20 @@ InotifySource::OnSocketReady(gcc_unused unsigned flags)
return true; return true;
} }
static int static FileDescriptor
InotifyInit() InotifyInit()
{ {
FileDescriptor fd; FileDescriptor fd;
if (!fd.CreateInotify()) if (!fd.CreateInotify())
throw MakeErrno("inotify_init() has failed"); throw MakeErrno("inotify_init() has failed");
return fd.Get(); return fd;
} }
InotifySource::InotifySource(EventLoop &_loop, InotifySource::InotifySource(EventLoop &_loop,
mpd_inotify_callback_t _callback, void *_ctx) mpd_inotify_callback_t _callback, void *_ctx)
:SocketMonitor(InotifyInit(), _loop), :SocketMonitor(SocketDescriptor::FromFileDescriptor(InotifyInit()),
_loop),
callback(_callback), callback_ctx(_ctx) callback(_callback), callback_ctx(_ctx)
{ {
ScheduleRead(); ScheduleRead();
@ -88,7 +90,8 @@ InotifySource::InotifySource(EventLoop &_loop,
int int
InotifySource::Add(const char *path_fs, unsigned mask) InotifySource::Add(const char *path_fs, unsigned mask)
{ {
int wd = inotify_add_watch(Get(), path_fs, mask); auto ifd = Get().ToFileDescriptor();
int wd = inotify_add_watch(ifd.Get(), path_fs, mask);
if (wd < 0) if (wd < 0)
throw MakeErrno("inotify_add_watch() has failed"); throw MakeErrno("inotify_add_watch() has failed");
@ -98,7 +101,8 @@ InotifySource::Add(const char *path_fs, unsigned mask)
void void
InotifySource::Remove(unsigned wd) InotifySource::Remove(unsigned wd)
{ {
int ret = inotify_rm_watch(Get(), wd); auto ifd = Get().ToFileDescriptor();
int ret = inotify_rm_watch(ifd.Get(), wd);
if (ret < 0 && errno != EINVAL) if (ret < 0 && errno != EINVAL)
LogErrno(inotify_domain, "inotify_rm_watch() has failed"); LogErrno(inotify_domain, "inotify_rm_watch() has failed");

View File

@ -38,7 +38,7 @@ class BufferedSocket : protected SocketMonitor {
StaticFifoBuffer<uint8_t, 8192> input; StaticFifoBuffer<uint8_t, 8192> input;
public: public:
BufferedSocket(int _fd, EventLoop &_loop) BufferedSocket(SocketDescriptor _fd, EventLoop &_loop)
:SocketMonitor(_fd, _loop) { :SocketMonitor(_fd, _loop) {
ScheduleRead(); ScheduleRead();
} }

View File

@ -32,7 +32,7 @@ class FullyBufferedSocket : protected BufferedSocket, private IdleMonitor {
PeakBuffer output; PeakBuffer output;
public: public:
FullyBufferedSocket(int _fd, EventLoop &_loop, FullyBufferedSocket(SocketDescriptor _fd, EventLoop &_loop,
size_t normal_size, size_t peak_size=0) size_t normal_size, size_t peak_size=0)
:BufferedSocket(_fd, _loop), IdleMonitor(_loop), :BufferedSocket(_fd, _loop), IdleMonitor(_loop),
output(normal_size, peak_size) { output(normal_size, peak_size) {

View File

@ -29,7 +29,7 @@
EventLoop::EventLoop() EventLoop::EventLoop()
:SocketMonitor(*this) :SocketMonitor(*this)
{ {
SocketMonitor::Open(wake_fd.Get()); SocketMonitor::Open(SocketDescriptor(wake_fd.Get()));
SocketMonitor::Schedule(SocketMonitor::READ); SocketMonitor::Schedule(SocketMonitor::READ);
} }

View File

@ -57,9 +57,9 @@ MultiSocketMonitor::ReplaceSocketList(pollfd *pfds, unsigned n)
{ {
pollfd *const end = pfds + n; pollfd *const end = pfds + n;
UpdateSocketList([pfds, end](int fd) -> unsigned { UpdateSocketList([pfds, end](SocketDescriptor fd) -> unsigned {
auto i = std::find_if(pfds, end, [fd](const struct pollfd &pfd){ auto i = std::find_if(pfds, end, [fd](const struct pollfd &pfd){
return pfd.fd == fd; return pfd.fd == fd.Get();
}); });
if (i == end) if (i == end)
return 0; return 0;
@ -71,7 +71,7 @@ MultiSocketMonitor::ReplaceSocketList(pollfd *pfds, unsigned n)
for (auto i = pfds; i != end; ++i) for (auto i = pfds; i != end; ++i)
if (i->events != 0) if (i->events != 0)
AddSocket(i->fd, i->events); AddSocket(SocketDescriptor(i->fd), i->events);
} }
#endif #endif

View File

@ -59,13 +59,14 @@ class MultiSocketMonitor : IdleMonitor, TimeoutMonitor
unsigned revents; unsigned revents;
public: public:
SingleFD(MultiSocketMonitor &_multi, int _fd, unsigned events) SingleFD(MultiSocketMonitor &_multi, SocketDescriptor _fd,
unsigned events)
:SocketMonitor(_fd, _multi.GetEventLoop()), :SocketMonitor(_fd, _multi.GetEventLoop()),
multi(_multi), revents(0) { multi(_multi), revents(0) {
Schedule(events); Schedule(events);
} }
int GetFD() const { SocketDescriptor GetFD() const {
return SocketMonitor::Get(); return SocketMonitor::Get();
} }
@ -153,7 +154,7 @@ public:
* *
* May only be called from PrepareSockets(). * May only be called from PrepareSockets().
*/ */
void AddSocket(int fd, unsigned events) { void AddSocket(SocketDescriptor fd, unsigned events) {
fds.emplace_front(*this, fd, events); fds.emplace_front(*this, fd, events);
} }

View File

@ -108,7 +108,7 @@ public:
return ::ToString(address); return ::ToString(address);
} }
void SetFD(int _fd) noexcept { void SetFD(SocketDescriptor _fd) noexcept {
SocketMonitor::Open(_fd); SocketMonitor::Open(_fd);
SocketMonitor::ScheduleRead(); SocketMonitor::ScheduleRead();
} }
@ -150,28 +150,23 @@ inline void
OneServerSocket::Accept() noexcept OneServerSocket::Accept() noexcept
{ {
StaticSocketAddress peer_address; StaticSocketAddress peer_address;
size_t peer_address_length = sizeof(peer_address); auto peer_fd = Get().AcceptNonBlock(peer_address);
int peer_fd = if (!peer_fd.IsDefined()) {
accept_cloexec_nonblock(Get(), peer_address.GetAddress(),
&peer_address_length);
if (peer_fd < 0) {
const SocketErrorMessage msg; const SocketErrorMessage msg;
FormatError(server_socket_domain, FormatError(server_socket_domain,
"accept() failed: %s", (const char *)msg); "accept() failed: %s", (const char *)msg);
return; return;
} }
peer_address.SetSize(peer_address_length); if (socket_keepalive(peer_fd.Get())) {
if (socket_keepalive(peer_fd)) {
const SocketErrorMessage msg; const SocketErrorMessage msg;
FormatError(server_socket_domain, FormatError(server_socket_domain,
"Could not set TCP keepalive option: %s", "Could not set TCP keepalive option: %s",
(const char *)msg); (const char *)msg);
} }
parent.OnAccept(peer_fd, peer_address, parent.OnAccept(peer_fd.Get(), peer_address,
get_remote_uid(peer_fd)); get_remote_uid(peer_fd.Get()));
} }
bool bool
@ -199,7 +194,7 @@ OneServerSocket::Open()
/* register in the EventLoop */ /* register in the EventLoop */
SetFD(_fd.Steal()); SetFD(_fd.Release());
} }
ServerSocket::ServerSocket(EventLoop &_loop) ServerSocket::ServerSocket(EventLoop &_loop)
@ -296,18 +291,16 @@ ServerSocket::AddAddress(AllocatedSocketAddress &&address)
} }
void void
ServerSocket::AddFD(int fd) ServerSocket::AddFD(int _fd)
{ {
assert(fd >= 0); assert(_fd >= 0);
StaticSocketAddress address; SocketDescriptor fd(_fd);
socklen_t address_length = sizeof(address);
if (getsockname(fd, address.GetAddress(), StaticSocketAddress address = fd.GetLocalAddress();
&address_length) < 0) if (!address.IsDefined())
throw MakeSocketError("Failed to get socket address"); throw MakeSocketError("Failed to get socket address");
address.SetSize(address_length);
OneServerSocket &s = AddAddress(address); OneServerSocket &s = AddAddress(address);
s.SetFD(fd); s.SetFD(fd);
} }

View File

@ -56,7 +56,7 @@ public:
SignalMonitor(EventLoop &_loop) SignalMonitor(EventLoop &_loop)
:SocketMonitor(_loop) { :SocketMonitor(_loop) {
#ifndef USE_SIGNALFD #ifndef USE_SIGNALFD
SocketMonitor::Open(fd.Get()); SocketMonitor::Open(SocketDescriptor(fd.Get()));
SocketMonitor::ScheduleRead(); SocketMonitor::ScheduleRead();
#endif #endif
} }
@ -70,7 +70,7 @@ public:
fd.Create(mask); fd.Create(mask);
if (!was_open) { if (!was_open) {
SocketMonitor::Open(fd.Get()); SocketMonitor::Open(SocketDescriptor(fd.Get()));
SocketMonitor::ScheduleRead(); SocketMonitor::ScheduleRead();
} }
} }

View File

@ -46,25 +46,22 @@ SocketMonitor::~SocketMonitor()
} }
void void
SocketMonitor::Open(int _fd) SocketMonitor::Open(SocketDescriptor _fd)
{ {
assert(fd < 0); assert(!fd.IsDefined());
assert(_fd >= 0); assert(_fd.IsDefined());
fd = _fd; fd = _fd;
} }
int SocketDescriptor
SocketMonitor::Steal() SocketMonitor::Steal()
{ {
assert(IsDefined()); assert(IsDefined());
Cancel(); Cancel();
int result = fd; return std::exchange(fd, SocketDescriptor::Undefined());
fd = -1;
return result;
} }
void void
@ -72,15 +69,14 @@ SocketMonitor::Abandon()
{ {
assert(IsDefined()); assert(IsDefined());
int old_fd = fd; loop.Abandon(std::exchange(fd, SocketDescriptor::Undefined()).Get(),
fd = -1; *this);
loop.Abandon(old_fd, *this);
} }
void void
SocketMonitor::Close() SocketMonitor::Close()
{ {
close_socket(Steal()); Steal().Close();
} }
void void
@ -92,11 +88,11 @@ SocketMonitor::Schedule(unsigned flags)
return; return;
if (scheduled_flags == 0) if (scheduled_flags == 0)
loop.AddFD(fd, flags, *this); loop.AddFD(fd.Get(), flags, *this);
else if (flags == 0) else if (flags == 0)
loop.RemoveFD(fd, *this); loop.RemoveFD(fd.Get(), *this);
else else
loop.ModifyFD(fd, flags, *this); loop.ModifyFD(fd.Get(), flags, *this);
scheduled_flags = flags; scheduled_flags = flags;
} }
@ -111,7 +107,7 @@ SocketMonitor::Read(void *data, size_t length)
flags |= MSG_DONTWAIT; flags |= MSG_DONTWAIT;
#endif #endif
return recv(Get(), (char *)data, length, flags); return recv(Get().Get(), (char *)data, length, flags);
} }
SocketMonitor::ssize_t SocketMonitor::ssize_t
@ -127,5 +123,5 @@ SocketMonitor::Write(const void *data, size_t length)
flags |= MSG_DONTWAIT; flags |= MSG_DONTWAIT;
#endif #endif
return send(Get(), (const char *)data, length, flags); return send(Get().Get(), (const char *)data, length, flags);
} }

View File

@ -22,6 +22,7 @@
#include "check.h" #include "check.h"
#include "PollGroup.hxx" #include "PollGroup.hxx"
#include "net/SocketDescriptor.hxx"
#include <type_traits> #include <type_traits>
@ -52,7 +53,7 @@ class EventLoop;
* as thread-safe. * as thread-safe.
*/ */
class SocketMonitor { class SocketMonitor {
int fd; SocketDescriptor fd;
EventLoop &loop; EventLoop &loop;
/** /**
@ -71,7 +72,7 @@ public:
SocketMonitor(EventLoop &_loop) SocketMonitor(EventLoop &_loop)
:fd(-1), loop(_loop), scheduled_flags(0) {} :fd(-1), loop(_loop), scheduled_flags(0) {}
SocketMonitor(int _fd, EventLoop &_loop) SocketMonitor(SocketDescriptor _fd, EventLoop &_loop)
:fd(_fd), loop(_loop), scheduled_flags(0) {} :fd(_fd), loop(_loop), scheduled_flags(0) {}
~SocketMonitor(); ~SocketMonitor();
@ -81,22 +82,22 @@ public:
} }
bool IsDefined() const { bool IsDefined() const {
return fd >= 0; return fd.IsDefined();
} }
int Get() const { SocketDescriptor Get() const {
assert(IsDefined()); assert(IsDefined());
return fd; return fd;
} }
void Open(int _fd); void Open(SocketDescriptor _fd);
/** /**
* "Steal" the socket descriptor. This abandons the socket * "Steal" the socket descriptor. This abandons the socket
* and returns it. * and returns it.
*/ */
int Steal(); SocketDescriptor Steal();
/** /**
* Somebody has closed the socket. Unregister this object. * Somebody has closed the socket. Unregister this object.

View File

@ -45,7 +45,7 @@ class CurlSocket final : SocketMonitor {
CurlGlobal &global; CurlGlobal &global;
public: public:
CurlSocket(CurlGlobal &_global, EventLoop &_loop, int _fd) CurlSocket(CurlGlobal &_global, EventLoop &_loop, SocketDescriptor _fd)
:SocketMonitor(_fd, _loop), global(_global) {} :SocketMonitor(_fd, _loop), global(_global) {}
~CurlSocket() { ~CurlSocket() {
@ -120,7 +120,8 @@ CurlSocket::SocketFunction(gcc_unused CURL *easy,
} }
if (cs == nullptr) { if (cs == nullptr) {
cs = new CurlSocket(global, global.GetEventLoop(), s); cs = new CurlSocket(global, global.GetEventLoop(),
SocketDescriptor(s));
global.Assign(s, *cs); global.Assign(s, *cs);
} else { } else {
#ifdef USE_EPOLL #ifdef USE_EPOLL
@ -147,7 +148,7 @@ CurlSocket::OnSocketReady(unsigned flags)
{ {
assert(GetEventLoop().IsInside()); assert(GetEventLoop().IsInside());
global.SocketAction(Get(), FlagsToCurlCSelect(flags)); global.SocketAction(Get().Get(), FlagsToCurlCSelect(flags));
return true; return true;
} }

View File

@ -413,7 +413,7 @@ NfsConnection::ScheduleSocket()
return; return;
_fd.EnableCloseOnExec(); _fd.EnableCloseOnExec();
SocketMonitor::Open(_fd.Get()); SocketMonitor::Open(_fd);
} }
SocketMonitor::Schedule(libnfs_to_events(which_events) SocketMonitor::Schedule(libnfs_to_events(which_events)

View File

@ -185,7 +185,8 @@ HttpdClient::SendResponse()
return true; return true;
} }
HttpdClient::HttpdClient(HttpdOutput &_httpd, int _fd, EventLoop &_loop, HttpdClient::HttpdClient(HttpdOutput &_httpd, SocketDescriptor _fd,
EventLoop &_loop,
bool _metadata_supported) bool _metadata_supported)
:BufferedSocket(_fd, _loop), :BufferedSocket(_fd, _loop),
httpd(_httpd), httpd(_httpd),

View File

@ -131,7 +131,7 @@ public:
* @param httpd the HTTP output device * @param httpd the HTTP output device
* @param _fd the socket file descriptor * @param _fd the socket file descriptor
*/ */
HttpdClient(HttpdOutput &httpd, int _fd, EventLoop &_loop, HttpdClient(HttpdOutput &httpd, SocketDescriptor _fd, EventLoop &_loop,
bool _metadata_supported); bool _metadata_supported);
/** /**

View File

@ -206,7 +206,7 @@ public:
return HasClients(); return HasClients();
} }
void AddClient(int fd); void AddClient(SocketDescriptor fd);
/** /**
* Removes a client from the httpd_output.clients linked list. * Removes a client from the httpd_output.clients linked list.

View File

@ -118,7 +118,7 @@ HttpdOutput::Unbind()
* HttpdOutput.clients linked list. * HttpdOutput.clients linked list.
*/ */
inline void inline void
HttpdOutput::AddClient(int fd) HttpdOutput::AddClient(SocketDescriptor fd)
{ {
auto *client = new HttpdClient(*this, fd, GetEventLoop(), auto *client = new HttpdClient(*this, fd, GetEventLoop(),
!encoder->ImplementsTag()); !encoder->ImplementsTag());
@ -184,7 +184,7 @@ HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid)
/* can we allow additional client */ /* can we allow additional client */
if (open && (clients_max == 0 || clients.size() < clients_max)) if (open && (clients_max == 0 || clients.size() < clients_max))
AddClient(fd); AddClient(SocketDescriptor(fd));
else else
close_socket(fd); close_socket(fd);
} }

View File

@ -48,7 +48,7 @@ private:
AvahiWatchEvent received; AvahiWatchEvent received;
public: public:
AvahiWatch(int _fd, AvahiWatchEvent _event, AvahiWatch(SocketDescriptor _fd, AvahiWatchEvent _event,
AvahiWatchCallback _callback, void *_userdata, AvahiWatchCallback _callback, void *_userdata,
EventLoop &_loop) EventLoop &_loop)
:SocketMonitor(_fd, _loop), :SocketMonitor(_fd, _loop),
@ -72,7 +72,7 @@ public:
protected: protected:
virtual bool OnSocketReady(unsigned flags) { virtual bool OnSocketReady(unsigned flags) {
received = ToAvahiWatchEvent(flags); received = ToAvahiWatchEvent(flags);
callback(this, Get(), received, userdata); callback(this, Get().Get(), received, userdata);
received = AvahiWatchEvent(0); received = AvahiWatchEvent(0);
return true; return true;
} }
@ -132,7 +132,7 @@ MyAvahiPoll::WatchNew(const AvahiPoll *api, int fd, AvahiWatchEvent event,
AvahiWatchCallback callback, void *userdata) { AvahiWatchCallback callback, void *userdata) {
const MyAvahiPoll &poll = *(const MyAvahiPoll *)api; const MyAvahiPoll &poll = *(const MyAvahiPoll *)api;
return new AvahiWatch(fd, event, callback, userdata, return new AvahiWatch(SocketDescriptor(fd), event, callback, userdata,
poll.event_loop); poll.event_loop);
} }