net/ServerSocket: pass UniqueSocketDescriptor&& to OnAccept()

This commit is contained in:
Max Kellermann 2017-08-10 19:13:18 +02:00
parent 492b20a89d
commit 9a5bcc6db0
9 changed files with 35 additions and 28 deletions

View File

@ -50,9 +50,10 @@ public:
:ServerSocket(_loop), partition(_partition) {} :ServerSocket(_loop), partition(_partition) {}
private: private:
void OnAccept(int fd, SocketAddress address, int uid) override { void OnAccept(UniqueSocketDescriptor &&fd,
SocketAddress address, int uid) override {
client_new(GetEventLoop(), partition, client_new(GetEventLoop(), partition,
SocketDescriptor(fd), address, uid); std::move(fd), address, uid);
} }
}; };

View File

@ -38,6 +38,7 @@
#include <stddef.h> #include <stddef.h>
class SocketAddress; class SocketAddress;
class UniqueSocketDescriptor;
class EventLoop; class EventLoop;
class Path; class Path;
struct Instance; struct Instance;
@ -94,7 +95,7 @@ public:
std::list<ClientMessage> messages; std::list<ClientMessage> messages;
Client(EventLoop &loop, Partition &partition, Client(EventLoop &loop, Partition &partition,
SocketDescriptor fd, int uid, int num); UniqueSocketDescriptor &&fd, int uid, int num);
~Client() { ~Client() {
if (FullyBufferedSocket::IsDefined()) if (FullyBufferedSocket::IsDefined())
@ -236,7 +237,7 @@ client_manager_init();
void void
client_new(EventLoop &loop, Partition &partition, client_new(EventLoop &loop, Partition &partition,
SocketDescriptor fd, SocketAddress address, int uid); UniqueSocketDescriptor &&fd, SocketAddress address, int uid);
/** /**
* Write a printf-like formatted string to the client. * Write a printf-like formatted string to the client.

View File

@ -23,6 +23,7 @@
#include "Partition.hxx" #include "Partition.hxx"
#include "Instance.hxx" #include "Instance.hxx"
#include "system/fd_util.h" #include "system/fd_util.h"
#include "net/UniqueSocketDescriptor.hxx"
#include "net/SocketAddress.hxx" #include "net/SocketAddress.hxx"
#include "net/ToString.hxx" #include "net/ToString.hxx"
#include "Permission.hxx" #include "Permission.hxx"
@ -42,8 +43,9 @@
static const char GREETING[] = "OK MPD " PROTOCOL_VERSION "\n"; static const char GREETING[] = "OK MPD " PROTOCOL_VERSION "\n";
Client::Client(EventLoop &_loop, Partition &_partition, Client::Client(EventLoop &_loop, Partition &_partition,
SocketDescriptor _fd, int _uid, int _num) UniqueSocketDescriptor &&_fd, int _uid, int _num)
:FullyBufferedSocket(_fd, _loop, 16384, client_max_output_buffer_size), :FullyBufferedSocket(_fd.Release(), _loop,
16384, client_max_output_buffer_size),
TimeoutMonitor(_loop), TimeoutMonitor(_loop),
partition(&_partition), partition(&_partition),
permission(getDefaultPermissions()), permission(getDefaultPermissions()),
@ -57,7 +59,7 @@ Client::Client(EventLoop &_loop, Partition &_partition,
void void
client_new(EventLoop &loop, Partition &partition, client_new(EventLoop &loop, Partition &partition,
SocketDescriptor fd, SocketAddress address, int uid) UniqueSocketDescriptor &&fd, SocketAddress address, int uid)
{ {
static unsigned int next_client_num; static unsigned int next_client_num;
const auto remote = ToString(address); const auto remote = ToString(address);
@ -80,7 +82,6 @@ client_new(EventLoop &loop, Partition &partition,
"libwrap refused connection (libwrap=%s) from %s", "libwrap refused connection (libwrap=%s) from %s",
progname, remote.c_str()); progname, remote.c_str());
fd.Close();
return; return;
} }
} }
@ -89,15 +90,14 @@ client_new(EventLoop &loop, Partition &partition,
ClientList &client_list = *partition.instance.client_list; ClientList &client_list = *partition.instance.client_list;
if (client_list.IsFull()) { if (client_list.IsFull()) {
LogWarning(client_domain, "Max connections reached"); LogWarning(client_domain, "Max connections reached");
fd.Close();
return; return;
} }
Client *client = new Client(loop, partition, fd, uid,
next_client_num++);
(void)fd.Write(GREETING, sizeof(GREETING) - 1); (void)fd.Write(GREETING, sizeof(GREETING) - 1);
Client *client = new Client(loop, partition, std::move(fd), uid,
next_client_num++);
client_list.Add(*client); client_list.Add(*client);
FormatInfo(client_domain, "[%u] opened from %s", FormatInfo(client_domain, "[%u] opened from %s",

View File

@ -150,7 +150,7 @@ inline void
OneServerSocket::Accept() noexcept OneServerSocket::Accept() noexcept
{ {
StaticSocketAddress peer_address; StaticSocketAddress peer_address;
auto peer_fd = Get().AcceptNonBlock(peer_address); UniqueSocketDescriptor peer_fd(Get().AcceptNonBlock(peer_address));
if (!peer_fd.IsDefined()) { if (!peer_fd.IsDefined()) {
const SocketErrorMessage msg; const SocketErrorMessage msg;
FormatError(server_socket_domain, FormatError(server_socket_domain,
@ -165,7 +165,7 @@ OneServerSocket::Accept() noexcept
(const char *)msg); (const char *)msg);
} }
parent.OnAccept(peer_fd.Get(), peer_address, parent.OnAccept(std::move(peer_fd), peer_address,
get_remote_uid(peer_fd.Get())); get_remote_uid(peer_fd.Get()));
} }

View File

@ -24,6 +24,7 @@
class SocketAddress; class SocketAddress;
class AllocatedSocketAddress; class AllocatedSocketAddress;
class UniqueSocketDescriptor;
class EventLoop; class EventLoop;
class AllocatedPath; class AllocatedPath;
class OneServerSocket; class OneServerSocket;
@ -116,7 +117,8 @@ public:
void Close(); void Close();
protected: protected:
virtual void OnAccept(int fd, SocketAddress address, int uid) = 0; virtual void OnAccept(UniqueSocketDescriptor &&fd,
SocketAddress address, int uid) = 0;
}; };
#endif #endif

View File

@ -25,6 +25,7 @@
#include "Page.hxx" #include "Page.hxx"
#include "IcyMetaDataServer.hxx" #include "IcyMetaDataServer.hxx"
#include "net/SocketError.hxx" #include "net/SocketError.hxx"
#include "net/UniqueSocketDescriptor.hxx"
#include "Log.hxx" #include "Log.hxx"
#include <assert.h> #include <assert.h>
@ -185,10 +186,10 @@ HttpdClient::SendResponse()
return true; return true;
} }
HttpdClient::HttpdClient(HttpdOutput &_httpd, SocketDescriptor _fd, HttpdClient::HttpdClient(HttpdOutput &_httpd, UniqueSocketDescriptor &&_fd,
EventLoop &_loop, EventLoop &_loop,
bool _metadata_supported) bool _metadata_supported)
:BufferedSocket(_fd, _loop), :BufferedSocket(_fd.Release(), _loop),
httpd(_httpd), httpd(_httpd),
state(REQUEST), state(REQUEST),
queue_size(0), queue_size(0),

View File

@ -32,6 +32,7 @@
#include <stddef.h> #include <stddef.h>
class UniqueSocketDescriptor;
class HttpdOutput; class HttpdOutput;
class HttpdClient final class HttpdClient final
@ -131,7 +132,8 @@ public:
* @param httpd the HTTP output device * @param httpd the HTTP output device
* @param _fd the socket file descriptor * @param _fd the socket file descriptor
*/ */
HttpdClient(HttpdOutput &httpd, SocketDescriptor _fd, EventLoop &_loop, HttpdClient(HttpdOutput &httpd, UniqueSocketDescriptor &&_fd,
EventLoop &_loop,
bool _metadata_supported); bool _metadata_supported);
/** /**

View File

@ -206,7 +206,7 @@ public:
return HasClients(); return HasClients();
} }
void AddClient(SocketDescriptor fd); void AddClient(UniqueSocketDescriptor &&fd);
/** /**
* Removes a client from the httpd_output.clients linked list. * Removes a client from the httpd_output.clients linked list.
@ -257,7 +257,8 @@ public:
private: private:
virtual void RunDeferred() override; virtual void RunDeferred() override;
void OnAccept(int fd, SocketAddress address, int uid) override; void OnAccept(UniqueSocketDescriptor &&fd,
SocketAddress address, int uid) override;
}; };
extern const class Domain httpd_output_domain; extern const class Domain httpd_output_domain;

View File

@ -25,6 +25,7 @@
#include "encoder/EncoderInterface.hxx" #include "encoder/EncoderInterface.hxx"
#include "encoder/EncoderPlugin.hxx" #include "encoder/EncoderPlugin.hxx"
#include "encoder/EncoderList.hxx" #include "encoder/EncoderList.hxx"
#include "net/UniqueSocketDescriptor.hxx"
#include "net/SocketAddress.hxx" #include "net/SocketAddress.hxx"
#include "net/ToString.hxx" #include "net/ToString.hxx"
#include "Page.hxx" #include "Page.hxx"
@ -118,9 +119,9 @@ HttpdOutput::Unbind()
* HttpdOutput.clients linked list. * HttpdOutput.clients linked list.
*/ */
inline void inline void
HttpdOutput::AddClient(SocketDescriptor fd) HttpdOutput::AddClient(UniqueSocketDescriptor &&fd)
{ {
auto *client = new HttpdClient(*this, fd, GetEventLoop(), auto *client = new HttpdClient(*this, std::move(fd), GetEventLoop(),
!encoder->ImplementsTag()); !encoder->ImplementsTag());
clients.push_front(*client); clients.push_front(*client);
@ -151,7 +152,8 @@ HttpdOutput::RunDeferred()
} }
void void
HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid) HttpdOutput::OnAccept(UniqueSocketDescriptor &&fd,
SocketAddress address, gcc_unused int uid)
{ {
/* the listener socket has become readable - a client has /* the listener socket has become readable - a client has
connected */ connected */
@ -163,7 +165,7 @@ HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid)
const char *progname = "mpd"; const char *progname = "mpd";
struct request_info req; struct request_info req;
request_init(&req, RQ_FILE, fd, RQ_DAEMON, progname, 0); request_init(&req, RQ_FILE, fd.Get(), RQ_DAEMON, progname, 0);
fromhost(&req); fromhost(&req);
@ -172,7 +174,6 @@ HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid)
FormatWarning(httpd_output_domain, FormatWarning(httpd_output_domain,
"libwrap refused connection (libwrap=%s) from %s", "libwrap refused connection (libwrap=%s) from %s",
progname, hostaddr.c_str()); progname, hostaddr.c_str());
close_socket(fd);
return; return;
} }
} }
@ -184,9 +185,7 @@ HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid)
/* can we allow additional client */ /* can we allow additional client */
if (open && (clients_max == 0 || clients.size() < clients_max)) if (open && (clients_max == 0 || clients.size() < clients_max))
AddClient(SocketDescriptor(fd)); AddClient(std::move(fd));
else
close_socket(fd);
} }
PagePtr PagePtr