event/SocketMonitor: use class SocketDescriptor

This commit is contained in:
Max Kellermann 2017-08-10 18:25:22 +02:00
parent fcfc8bacc0
commit 492b20a89d
21 changed files with 80 additions and 83 deletions

View File

@ -52,7 +52,7 @@ public:
private:
void OnAccept(int fd, SocketAddress address, int uid) override {
client_new(GetEventLoop(), partition,
fd, address, uid);
SocketDescriptor(fd), address, uid);
}
};

View File

@ -94,7 +94,7 @@ public:
std::list<ClientMessage> messages;
Client(EventLoop &loop, Partition &partition,
int fd, int uid, int num);
SocketDescriptor fd, int uid, int num);
~Client() {
if (FullyBufferedSocket::IsDefined())
@ -236,7 +236,7 @@ client_manager_init();
void
client_new(EventLoop &loop, Partition &partition,
int fd, SocketAddress address, int uid);
SocketDescriptor fd, SocketAddress address, int uid);
/**
* Write a printf-like formatted string to the client.

View File

@ -42,7 +42,7 @@
static const char GREETING[] = "OK MPD " PROTOCOL_VERSION "\n";
Client::Client(EventLoop &_loop, Partition &_partition,
int _fd, int _uid, int _num)
SocketDescriptor _fd, int _uid, int _num)
:FullyBufferedSocket(_fd, _loop, 16384, client_max_output_buffer_size),
TimeoutMonitor(_loop),
partition(&_partition),
@ -57,12 +57,12 @@ Client::Client(EventLoop &_loop, Partition &_partition,
void
client_new(EventLoop &loop, Partition &partition,
int fd, SocketAddress address, int uid)
SocketDescriptor fd, SocketAddress address, int uid)
{
static unsigned int next_client_num;
const auto remote = ToString(address);
assert(fd >= 0);
assert(fd.IsDefined());
#ifdef HAVE_LIBWRAP
if (address.GetFamily() != AF_LOCAL) {
@ -70,7 +70,7 @@ client_new(EventLoop &loop, Partition &partition,
const char *progname = "mpd";
struct request_info req;
request_init(&req, RQ_FILE, fd, RQ_DAEMON, progname, 0);
request_init(&req, RQ_FILE, fd.Get(), RQ_DAEMON, progname, 0);
fromhost(&req);
@ -80,7 +80,7 @@ client_new(EventLoop &loop, Partition &partition,
"libwrap refused connection (libwrap=%s) from %s",
progname, remote.c_str());
close_socket(fd);
fd.Close();
return;
}
}
@ -89,14 +89,14 @@ client_new(EventLoop &loop, Partition &partition,
ClientList &client_list = *partition.instance.client_list;
if (client_list.IsFull()) {
LogWarning(client_domain, "Max connections reached");
close_socket(fd);
fd.Close();
return;
}
Client *client = new Client(loop, partition, fd, uid,
next_client_num++);
(void)send(fd, GREETING, sizeof(GREETING) - 1, 0);
(void)fd.Write(GREETING, sizeof(GREETING) - 1);
client_list.Add(*client);

View File

@ -398,7 +398,7 @@ ProxyDatabase::Connect()
idle_received = unsigned(-1);
is_idle = false;
SocketMonitor::Open(mpd_async_get_fd(mpd_connection_get_async(connection)));
SocketMonitor::Open(SocketDescriptor(mpd_async_get_fd(mpd_connection_get_async(connection))));
IdleMonitor::Schedule();
}

View File

@ -38,7 +38,8 @@ InotifySource::OnSocketReady(gcc_unused unsigned flags)
static_assert(sizeof(buffer) >= sizeof(struct inotify_event) + NAME_MAX + 1,
"inotify buffer too small");
ssize_t nbytes = read(Get(), buffer, sizeof(buffer));
auto ifd = Get().ToFileDescriptor();
ssize_t nbytes = ifd.Read(buffer, sizeof(buffer));
if (nbytes < 0)
FatalSystemError("Failed to read from inotify");
if (nbytes == 0)
@ -67,19 +68,20 @@ InotifySource::OnSocketReady(gcc_unused unsigned flags)
return true;
}
static int
static FileDescriptor
InotifyInit()
{
FileDescriptor fd;
if (!fd.CreateInotify())
throw MakeErrno("inotify_init() has failed");
return fd.Get();
return fd;
}
InotifySource::InotifySource(EventLoop &_loop,
mpd_inotify_callback_t _callback, void *_ctx)
:SocketMonitor(InotifyInit(), _loop),
:SocketMonitor(SocketDescriptor::FromFileDescriptor(InotifyInit()),
_loop),
callback(_callback), callback_ctx(_ctx)
{
ScheduleRead();
@ -88,7 +90,8 @@ InotifySource::InotifySource(EventLoop &_loop,
int
InotifySource::Add(const char *path_fs, unsigned mask)
{
int wd = inotify_add_watch(Get(), path_fs, mask);
auto ifd = Get().ToFileDescriptor();
int wd = inotify_add_watch(ifd.Get(), path_fs, mask);
if (wd < 0)
throw MakeErrno("inotify_add_watch() has failed");
@ -98,7 +101,8 @@ InotifySource::Add(const char *path_fs, unsigned mask)
void
InotifySource::Remove(unsigned wd)
{
int ret = inotify_rm_watch(Get(), wd);
auto ifd = Get().ToFileDescriptor();
int ret = inotify_rm_watch(ifd.Get(), wd);
if (ret < 0 && errno != EINVAL)
LogErrno(inotify_domain, "inotify_rm_watch() has failed");

View File

@ -38,7 +38,7 @@ class BufferedSocket : protected SocketMonitor {
StaticFifoBuffer<uint8_t, 8192> input;
public:
BufferedSocket(int _fd, EventLoop &_loop)
BufferedSocket(SocketDescriptor _fd, EventLoop &_loop)
:SocketMonitor(_fd, _loop) {
ScheduleRead();
}

View File

@ -32,7 +32,7 @@ class FullyBufferedSocket : protected BufferedSocket, private IdleMonitor {
PeakBuffer output;
public:
FullyBufferedSocket(int _fd, EventLoop &_loop,
FullyBufferedSocket(SocketDescriptor _fd, EventLoop &_loop,
size_t normal_size, size_t peak_size=0)
:BufferedSocket(_fd, _loop), IdleMonitor(_loop),
output(normal_size, peak_size) {

View File

@ -29,7 +29,7 @@
EventLoop::EventLoop()
:SocketMonitor(*this)
{
SocketMonitor::Open(wake_fd.Get());
SocketMonitor::Open(SocketDescriptor(wake_fd.Get()));
SocketMonitor::Schedule(SocketMonitor::READ);
}

View File

@ -57,9 +57,9 @@ MultiSocketMonitor::ReplaceSocketList(pollfd *pfds, unsigned n)
{
pollfd *const end = pfds + n;
UpdateSocketList([pfds, end](int fd) -> unsigned {
UpdateSocketList([pfds, end](SocketDescriptor fd) -> unsigned {
auto i = std::find_if(pfds, end, [fd](const struct pollfd &pfd){
return pfd.fd == fd;
return pfd.fd == fd.Get();
});
if (i == end)
return 0;
@ -71,7 +71,7 @@ MultiSocketMonitor::ReplaceSocketList(pollfd *pfds, unsigned n)
for (auto i = pfds; i != end; ++i)
if (i->events != 0)
AddSocket(i->fd, i->events);
AddSocket(SocketDescriptor(i->fd), i->events);
}
#endif

View File

@ -59,13 +59,14 @@ class MultiSocketMonitor : IdleMonitor, TimeoutMonitor
unsigned revents;
public:
SingleFD(MultiSocketMonitor &_multi, int _fd, unsigned events)
SingleFD(MultiSocketMonitor &_multi, SocketDescriptor _fd,
unsigned events)
:SocketMonitor(_fd, _multi.GetEventLoop()),
multi(_multi), revents(0) {
Schedule(events);
}
int GetFD() const {
SocketDescriptor GetFD() const {
return SocketMonitor::Get();
}
@ -153,7 +154,7 @@ public:
*
* May only be called from PrepareSockets().
*/
void AddSocket(int fd, unsigned events) {
void AddSocket(SocketDescriptor fd, unsigned events) {
fds.emplace_front(*this, fd, events);
}

View File

@ -108,7 +108,7 @@ public:
return ::ToString(address);
}
void SetFD(int _fd) noexcept {
void SetFD(SocketDescriptor _fd) noexcept {
SocketMonitor::Open(_fd);
SocketMonitor::ScheduleRead();
}
@ -150,28 +150,23 @@ inline void
OneServerSocket::Accept() noexcept
{
StaticSocketAddress peer_address;
size_t peer_address_length = sizeof(peer_address);
int peer_fd =
accept_cloexec_nonblock(Get(), peer_address.GetAddress(),
&peer_address_length);
if (peer_fd < 0) {
auto peer_fd = Get().AcceptNonBlock(peer_address);
if (!peer_fd.IsDefined()) {
const SocketErrorMessage msg;
FormatError(server_socket_domain,
"accept() failed: %s", (const char *)msg);
return;
}
peer_address.SetSize(peer_address_length);
if (socket_keepalive(peer_fd)) {
if (socket_keepalive(peer_fd.Get())) {
const SocketErrorMessage msg;
FormatError(server_socket_domain,
"Could not set TCP keepalive option: %s",
(const char *)msg);
}
parent.OnAccept(peer_fd, peer_address,
get_remote_uid(peer_fd));
parent.OnAccept(peer_fd.Get(), peer_address,
get_remote_uid(peer_fd.Get()));
}
bool
@ -199,7 +194,7 @@ OneServerSocket::Open()
/* register in the EventLoop */
SetFD(_fd.Steal());
SetFD(_fd.Release());
}
ServerSocket::ServerSocket(EventLoop &_loop)
@ -296,18 +291,16 @@ ServerSocket::AddAddress(AllocatedSocketAddress &&address)
}
void
ServerSocket::AddFD(int fd)
ServerSocket::AddFD(int _fd)
{
assert(fd >= 0);
assert(_fd >= 0);
StaticSocketAddress address;
socklen_t address_length = sizeof(address);
if (getsockname(fd, address.GetAddress(),
&address_length) < 0)
SocketDescriptor fd(_fd);
StaticSocketAddress address = fd.GetLocalAddress();
if (!address.IsDefined())
throw MakeSocketError("Failed to get socket address");
address.SetSize(address_length);
OneServerSocket &s = AddAddress(address);
s.SetFD(fd);
}

View File

@ -56,7 +56,7 @@ public:
SignalMonitor(EventLoop &_loop)
:SocketMonitor(_loop) {
#ifndef USE_SIGNALFD
SocketMonitor::Open(fd.Get());
SocketMonitor::Open(SocketDescriptor(fd.Get()));
SocketMonitor::ScheduleRead();
#endif
}
@ -70,7 +70,7 @@ public:
fd.Create(mask);
if (!was_open) {
SocketMonitor::Open(fd.Get());
SocketMonitor::Open(SocketDescriptor(fd.Get()));
SocketMonitor::ScheduleRead();
}
}

View File

@ -46,25 +46,22 @@ SocketMonitor::~SocketMonitor()
}
void
SocketMonitor::Open(int _fd)
SocketMonitor::Open(SocketDescriptor _fd)
{
assert(fd < 0);
assert(_fd >= 0);
assert(!fd.IsDefined());
assert(_fd.IsDefined());
fd = _fd;
}
int
SocketDescriptor
SocketMonitor::Steal()
{
assert(IsDefined());
Cancel();
int result = fd;
fd = -1;
return result;
return std::exchange(fd, SocketDescriptor::Undefined());
}
void
@ -72,15 +69,14 @@ SocketMonitor::Abandon()
{
assert(IsDefined());
int old_fd = fd;
fd = -1;
loop.Abandon(old_fd, *this);
loop.Abandon(std::exchange(fd, SocketDescriptor::Undefined()).Get(),
*this);
}
void
SocketMonitor::Close()
{
close_socket(Steal());
Steal().Close();
}
void
@ -92,11 +88,11 @@ SocketMonitor::Schedule(unsigned flags)
return;
if (scheduled_flags == 0)
loop.AddFD(fd, flags, *this);
loop.AddFD(fd.Get(), flags, *this);
else if (flags == 0)
loop.RemoveFD(fd, *this);
loop.RemoveFD(fd.Get(), *this);
else
loop.ModifyFD(fd, flags, *this);
loop.ModifyFD(fd.Get(), flags, *this);
scheduled_flags = flags;
}
@ -111,7 +107,7 @@ SocketMonitor::Read(void *data, size_t length)
flags |= MSG_DONTWAIT;
#endif
return recv(Get(), (char *)data, length, flags);
return recv(Get().Get(), (char *)data, length, flags);
}
SocketMonitor::ssize_t
@ -127,5 +123,5 @@ SocketMonitor::Write(const void *data, size_t length)
flags |= MSG_DONTWAIT;
#endif
return send(Get(), (const char *)data, length, flags);
return send(Get().Get(), (const char *)data, length, flags);
}

View File

@ -22,6 +22,7 @@
#include "check.h"
#include "PollGroup.hxx"
#include "net/SocketDescriptor.hxx"
#include <type_traits>
@ -52,7 +53,7 @@ class EventLoop;
* as thread-safe.
*/
class SocketMonitor {
int fd;
SocketDescriptor fd;
EventLoop &loop;
/**
@ -71,7 +72,7 @@ public:
SocketMonitor(EventLoop &_loop)
:fd(-1), loop(_loop), scheduled_flags(0) {}
SocketMonitor(int _fd, EventLoop &_loop)
SocketMonitor(SocketDescriptor _fd, EventLoop &_loop)
:fd(_fd), loop(_loop), scheduled_flags(0) {}
~SocketMonitor();
@ -81,22 +82,22 @@ public:
}
bool IsDefined() const {
return fd >= 0;
return fd.IsDefined();
}
int Get() const {
SocketDescriptor Get() const {
assert(IsDefined());
return fd;
}
void Open(int _fd);
void Open(SocketDescriptor _fd);
/**
* "Steal" the socket descriptor. This abandons the socket
* and returns it.
*/
int Steal();
SocketDescriptor Steal();
/**
* Somebody has closed the socket. Unregister this object.

View File

@ -45,7 +45,7 @@ class CurlSocket final : SocketMonitor {
CurlGlobal &global;
public:
CurlSocket(CurlGlobal &_global, EventLoop &_loop, int _fd)
CurlSocket(CurlGlobal &_global, EventLoop &_loop, SocketDescriptor _fd)
:SocketMonitor(_fd, _loop), global(_global) {}
~CurlSocket() {
@ -120,7 +120,8 @@ CurlSocket::SocketFunction(gcc_unused CURL *easy,
}
if (cs == nullptr) {
cs = new CurlSocket(global, global.GetEventLoop(), s);
cs = new CurlSocket(global, global.GetEventLoop(),
SocketDescriptor(s));
global.Assign(s, *cs);
} else {
#ifdef USE_EPOLL
@ -147,7 +148,7 @@ CurlSocket::OnSocketReady(unsigned flags)
{
assert(GetEventLoop().IsInside());
global.SocketAction(Get(), FlagsToCurlCSelect(flags));
global.SocketAction(Get().Get(), FlagsToCurlCSelect(flags));
return true;
}

View File

@ -413,7 +413,7 @@ NfsConnection::ScheduleSocket()
return;
_fd.EnableCloseOnExec();
SocketMonitor::Open(_fd.Get());
SocketMonitor::Open(_fd);
}
SocketMonitor::Schedule(libnfs_to_events(which_events)

View File

@ -185,7 +185,8 @@ HttpdClient::SendResponse()
return true;
}
HttpdClient::HttpdClient(HttpdOutput &_httpd, int _fd, EventLoop &_loop,
HttpdClient::HttpdClient(HttpdOutput &_httpd, SocketDescriptor _fd,
EventLoop &_loop,
bool _metadata_supported)
:BufferedSocket(_fd, _loop),
httpd(_httpd),

View File

@ -131,7 +131,7 @@ public:
* @param httpd the HTTP output device
* @param _fd the socket file descriptor
*/
HttpdClient(HttpdOutput &httpd, int _fd, EventLoop &_loop,
HttpdClient(HttpdOutput &httpd, SocketDescriptor _fd, EventLoop &_loop,
bool _metadata_supported);
/**

View File

@ -206,7 +206,7 @@ public:
return HasClients();
}
void AddClient(int fd);
void AddClient(SocketDescriptor fd);
/**
* Removes a client from the httpd_output.clients linked list.

View File

@ -118,7 +118,7 @@ HttpdOutput::Unbind()
* HttpdOutput.clients linked list.
*/
inline void
HttpdOutput::AddClient(int fd)
HttpdOutput::AddClient(SocketDescriptor fd)
{
auto *client = new HttpdClient(*this, fd, GetEventLoop(),
!encoder->ImplementsTag());
@ -184,7 +184,7 @@ HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid)
/* can we allow additional client */
if (open && (clients_max == 0 || clients.size() < clients_max))
AddClient(fd);
AddClient(SocketDescriptor(fd));
else
close_socket(fd);
}

View File

@ -48,7 +48,7 @@ private:
AvahiWatchEvent received;
public:
AvahiWatch(int _fd, AvahiWatchEvent _event,
AvahiWatch(SocketDescriptor _fd, AvahiWatchEvent _event,
AvahiWatchCallback _callback, void *_userdata,
EventLoop &_loop)
:SocketMonitor(_fd, _loop),
@ -72,7 +72,7 @@ public:
protected:
virtual bool OnSocketReady(unsigned flags) {
received = ToAvahiWatchEvent(flags);
callback(this, Get(), received, userdata);
callback(this, Get().Get(), received, userdata);
received = AvahiWatchEvent(0);
return true;
}
@ -132,7 +132,7 @@ MyAvahiPoll::WatchNew(const AvahiPoll *api, int fd, AvahiWatchEvent event,
AvahiWatchCallback callback, void *userdata) {
const MyAvahiPoll &poll = *(const MyAvahiPoll *)api;
return new AvahiWatch(fd, event, callback, userdata,
return new AvahiWatch(SocketDescriptor(fd), event, callback, userdata,
poll.event_loop);
}