From 491cc8f54de0d4558255745b9304a395aab90bd2 Mon Sep 17 00:00:00 2001
From: Max Kellermann <mk@cm4all.com>
Date: Wed, 27 Sep 2023 09:32:48 +0200
Subject: [PATCH] net/SocketDescriptor: add {Read,Write}NoWait()

It was surprising that Read() was non-blocking, but there was no
blocking version of it.  Let's make the non-blocking behavior explicit
and change Read() to be blocking.

In order to find existing callers easily with compiler errors, this
also refactors Read()/Write() to take a std::span parameter.
---
 src/client/New.cxx                       |  7 +++++--
 src/event/BufferedSocket.cxx             |  2 +-
 src/event/FullyBufferedSocket.cxx        |  2 +-
 src/net/SocketDescriptor.cxx             | 15 +++++++++++++--
 src/net/SocketDescriptor.hxx             | 22 ++++++++++++++++++----
 src/output/plugins/httpd/HttpdClient.cxx | 12 ++++++------
 src/system/EventPipe.cxx                 |  7 ++++---
 7 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/src/client/New.cxx b/src/client/New.cxx
index d3df5158b..818ae01c7 100644
--- a/src/client/New.cxx
+++ b/src/client/New.cxx
@@ -11,12 +11,15 @@
 #include "lib/fmt/SocketAddressFormatter.hxx"
 #include "net/UniqueSocketDescriptor.hxx"
 #include "net/SocketAddress.hxx"
+#include "util/SpanCast.hxx"
 #include "Log.hxx"
 #include "Version.h"
 
 #include <cassert>
 
-static constexpr char GREETING[] = "OK MPD " PROTOCOL_VERSION "\n";
+using std::string_view_literals::operator""sv;
+
+static constexpr auto GREETING = "OK MPD " PROTOCOL_VERSION "\n"sv;
 
 Client::Client(EventLoop &_loop, Partition &_partition,
 	       UniqueSocketDescriptor _fd,
@@ -49,7 +52,7 @@ client_new(EventLoop &loop, Partition &partition,
 		return;
 	}
 
-	(void)fd.Write(GREETING, sizeof(GREETING) - 1);
+	(void)fd.WriteNoWait(AsBytes(GREETING));
 
 	const unsigned num = next_client_num++;
 	auto *client = new Client(loop, partition, std::move(fd), uid,
diff --git a/src/event/BufferedSocket.cxx b/src/event/BufferedSocket.cxx
index 8a9a6e047..b9f68ad50 100644
--- a/src/event/BufferedSocket.cxx
+++ b/src/event/BufferedSocket.cxx
@@ -9,7 +9,7 @@
 inline BufferedSocket::ssize_t
 BufferedSocket::DirectRead(std::span<std::byte> dest) noexcept
 {
-	const auto nbytes = GetSocket().Read((char *)dest.data(), dest.size());
+	const auto nbytes = GetSocket().ReadNoWait(dest);
 	if (nbytes > 0) [[likely]]
 		return nbytes;
 
diff --git a/src/event/FullyBufferedSocket.cxx b/src/event/FullyBufferedSocket.cxx
index b20457627..8456d01ca 100644
--- a/src/event/FullyBufferedSocket.cxx
+++ b/src/event/FullyBufferedSocket.cxx
@@ -11,7 +11,7 @@
 inline FullyBufferedSocket::ssize_t
 FullyBufferedSocket::DirectWrite(std::span<const std::byte> src) noexcept
 {
-	const auto nbytes = GetSocket().Write((const char *)src.data(), src.size());
+	const auto nbytes = GetSocket().WriteNoWait(src);
 	if (nbytes < 0) [[unlikely]] {
 		const auto code = GetSocketError();
 		if (IsSocketErrorSendWouldBlock(code))
diff --git a/src/net/SocketDescriptor.cxx b/src/net/SocketDescriptor.cxx
index 35f23a7f9..aceb605ab 100644
--- a/src/net/SocketDescriptor.cxx
+++ b/src/net/SocketDescriptor.cxx
@@ -420,14 +420,25 @@ SocketDescriptor::Send(std::span<const std::byte> src, int flags) const noexcept
 }
 
 ssize_t
-SocketDescriptor::Read(void *buffer, std::size_t length) const noexcept
+SocketDescriptor::ReadNoWait(std::span<std::byte> dest) const noexcept
 {
 	int flags = 0;
 #ifndef _WIN32
 	flags |= MSG_DONTWAIT;
 #endif
 
-	return Receive({static_cast<std::byte *>(buffer), length}, flags);
+	return Receive(dest, flags);
+}
+
+ssize_t
+SocketDescriptor::WriteNoWait(std::span<const std::byte> src) const noexcept
+{
+	int flags = 0;
+#ifndef _WIN32
+	flags |= MSG_DONTWAIT;
+#endif
+
+	return Send(src, flags);
 }
 
 #ifdef _WIN32
diff --git a/src/net/SocketDescriptor.hxx b/src/net/SocketDescriptor.hxx
index 75f5983c8..1b6120309 100644
--- a/src/net/SocketDescriptor.hxx
+++ b/src/net/SocketDescriptor.hxx
@@ -274,12 +274,26 @@ public:
 	 */
 	ssize_t Send(std::span<const std::byte> src, int flags=0) const noexcept;
 
-	ssize_t Read(void *buffer, std::size_t length) const noexcept;
-
-	ssize_t Write(const void *buffer, std::size_t length) const noexcept {
-		return Send({static_cast<const std::byte *>(buffer), length});
+	ssize_t Read(std::span<std::byte> dest) const noexcept {
+		return Receive(dest);
 	}
 
+	ssize_t Write(std::span<const std::byte> src) const noexcept {
+		return Send(src);
+	}
+
+	/**
+	 * Wrapper for Receive() with MSG_DONTWAIT (not available on
+	 * Windows).
+	 */
+	ssize_t ReadNoWait(std::span<std::byte> dest) const noexcept;
+
+	/**
+	 * Wrapper for Receive() with MSG_DONTWAIT (not available on
+	 * Windows).
+	 */
+	ssize_t WriteNoWait(std::span<const std::byte> src) const noexcept;
+
 #ifdef _WIN32
 	int WaitReadable(int timeout_ms) const noexcept;
 	int WaitWritable(int timeout_ms) const noexcept;
diff --git a/src/output/plugins/httpd/HttpdClient.cxx b/src/output/plugins/httpd/HttpdClient.cxx
index fdaf4a5e7..ac3b3010e 100644
--- a/src/output/plugins/httpd/HttpdClient.cxx
+++ b/src/output/plugins/httpd/HttpdClient.cxx
@@ -9,6 +9,7 @@
 #include "IcyMetaDataServer.hxx"
 #include "net/SocketError.hxx"
 #include "net/UniqueSocketDescriptor.hxx"
+#include "util/SpanCast.hxx"
 #include "Log.hxx"
 
 #include <fmt/core.h>
@@ -150,7 +151,7 @@ HttpdClient::SendResponse() noexcept
 		response = allocated.c_str();
 	}
 
-	ssize_t nbytes = GetSocket().Write(response, strlen(response));
+	ssize_t nbytes = GetSocket().WriteNoWait(AsBytes(std::string_view{response}));
 	if (nbytes < 0) [[unlikely]] {
 		const SocketErrorMessage msg;
 		FmtWarning(httpd_output_domain,
@@ -207,8 +208,7 @@ HttpdClient::TryWritePage(const Page &page, size_t position) noexcept
 {
 	assert(position < page.size());
 
-	return GetSocket().Write(page.data() + position,
-				 page.size() - position);
+	return GetSocket().WriteNoWait(std::span<const std::byte>{page}.subspan(position));
 }
 
 ssize_t
@@ -216,7 +216,7 @@ HttpdClient::TryWritePageN(const Page &page,
 			   size_t position, ssize_t n) noexcept
 {
 	return n >= 0
-		? GetSocket().Write(page.data() + position, n)
+		? GetSocket().WriteNoWait({page.data() + position, (std::size_t)n})
 		: TryWritePage(page, position);
 }
 
@@ -283,9 +283,9 @@ HttpdClient::TryWrite() noexcept
 				metadata_sent = true;
 			}
 		} else {
-			char empty_data = 0;
+			static constexpr std::byte empty_data[1]{};
 
-			ssize_t nbytes = GetSocket().Write(&empty_data, 1);
+			ssize_t nbytes = GetSocket().Write(empty_data);
 			if (nbytes < 0) {
 				auto e = GetSocketError();
 				if (IsSocketErrorSendWouldBlock(e))
diff --git a/src/system/EventPipe.cxx b/src/system/EventPipe.cxx
index f03638239..e6d83ceb5 100644
--- a/src/system/EventPipe.cxx
+++ b/src/system/EventPipe.cxx
@@ -37,8 +37,8 @@ EventPipe::Read() noexcept
 	assert(r.IsDefined());
 	assert(w.IsDefined());
 
-	char buffer[256];
-	return r.Read(buffer, sizeof(buffer)) > 0;
+	std::byte buffer[256];
+	return r.Read(buffer) > 0;
 }
 
 void
@@ -47,7 +47,8 @@ EventPipe::Write() noexcept
 	assert(r.IsDefined());
 	assert(w.IsDefined());
 
-	w.Write("", 1);
+	static constexpr std::byte buffer[1]{};
+	w.Write(buffer);
 }
 
 #ifdef _WIN32