From 90d97053a867b78ecd5b52daa1935e04cf3beeb2 Mon Sep 17 00:00:00 2001
From: Max Kellermann <max@musicpd.org>
Date: Sat, 6 Mar 2021 20:33:37 +0100
Subject: [PATCH] win32/ComWorker: make COMWorker a real class, no static
 members

---
 src/mixer/plugins/WasapiMixerPlugin.cxx       | 12 +++++-
 src/output/plugins/wasapi/ForMixer.hxx        |  7 ++++
 .../plugins/wasapi/WasapiOutputPlugin.cxx     | 31 +++++++++++-----
 src/win32/ComWorker.cxx                       |  4 --
 src/win32/ComWorker.hxx                       | 37 +++++++------------
 5 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/src/mixer/plugins/WasapiMixerPlugin.cxx b/src/mixer/plugins/WasapiMixerPlugin.cxx
index 18c862f29..d87f1ca26 100644
--- a/src/mixer/plugins/WasapiMixerPlugin.cxx
+++ b/src/mixer/plugins/WasapiMixerPlugin.cxx
@@ -44,7 +44,11 @@ public:
 	void Close() noexcept override {}
 
 	int GetVolume() override {
-		auto future = COMWorker::Async([&]() -> int {
+		auto com_worker = wasapi_output_get_com_worker(output);
+		if (!com_worker)
+			return -1;
+
+		auto future = com_worker->Async([&]() -> int {
 			HRESULT result;
 			float volume_level;
 
@@ -76,7 +80,11 @@ public:
 	}
 
 	void SetVolume(unsigned volume) override {
-		COMWorker::Async([&]() {
+		auto com_worker = wasapi_output_get_com_worker(output);
+		if (!com_worker)
+			throw std::runtime_error("Cannot set WASAPI volume");
+
+		com_worker->Async([&]() {
 			HRESULT result;
 			const float volume_level = volume / 100.0f;
 
diff --git a/src/output/plugins/wasapi/ForMixer.hxx b/src/output/plugins/wasapi/ForMixer.hxx
index 2d815ce61..0d090a3ae 100644
--- a/src/output/plugins/wasapi/ForMixer.hxx
+++ b/src/output/plugins/wasapi/ForMixer.hxx
@@ -20,10 +20,13 @@
 #ifndef MPD_WASAPI_OUTPUT_FOR_MIXER_HXX
 #define MPD_WASAPI_OUTPUT_FOR_MIXER_HXX
 
+#include <memory>
+
 struct IMMDevice;
 struct IAudioClient;
 class AudioOutput;
 class WasapiOutput;
+class COMWorker;
 
 [[gnu::pure]]
 WasapiOutput &
@@ -33,6 +36,10 @@ wasapi_output_downcast(AudioOutput &output) noexcept;
 bool
 wasapi_is_exclusive(WasapiOutput &output) noexcept;
 
+[[gnu::pure]]
+std::shared_ptr<COMWorker>
+wasapi_output_get_com_worker(WasapiOutput &output) noexcept;
+
 [[gnu::pure]]
 IMMDevice *
 wasapi_output_get_device(WasapiOutput &output) noexcept;
diff --git a/src/output/plugins/wasapi/WasapiOutputPlugin.cxx b/src/output/plugins/wasapi/WasapiOutputPlugin.cxx
index 3882474f1..f9dfb5429 100644
--- a/src/output/plugins/wasapi/WasapiOutputPlugin.cxx
+++ b/src/output/plugins/wasapi/WasapiOutputPlugin.cxx
@@ -214,22 +214,28 @@ class WasapiOutput final : public AudioOutput {
 public:
 	static AudioOutput *Create(EventLoop &, const ConfigBlock &block);
 	WasapiOutput(const ConfigBlock &block);
+
+	auto GetComWorker() noexcept {
+		// TODO: protect access to the shard_ptr
+		return com_worker;
+	}
+
 	void Enable() override {
-		COMWorker::Aquire();
+		com_worker = std::make_shared<COMWorker>();
 
 		try {
-			COMWorker::Async([&]() { OpenDevice(); }).get();
+			com_worker->Async([&]() { OpenDevice(); }).get();
 		} catch (...) {
-			COMWorker::Release();
+			com_worker.reset();
 			throw;
 		}
 	}
 	void Disable() noexcept override {
-		COMWorker::Async([&]() { DoDisable(); }).get();
-		COMWorker::Release();
+		com_worker->Async([&]() { DoDisable(); }).get();
+		com_worker.reset();
 	}
 	void Open(AudioFormat &audio_format) override {
-		COMWorker::Async([&]() { DoOpen(audio_format); }).get();
+		com_worker->Async([&]() { DoOpen(audio_format); }).get();
 	}
 	void Close() noexcept override;
 	std::chrono::steady_clock::duration Delay() const noexcept override;
@@ -253,6 +259,7 @@ private:
 	bool dop_setting;
 #endif
 	std::string device_config;
+	std::shared_ptr<COMWorker> com_worker;
 	ComPtr<IMMDeviceEnumerator> enumerator;
 	ComPtr<IMMDevice> device;
 	ComPtr<IAudioClient> client;
@@ -283,6 +290,12 @@ WasapiOutput &wasapi_output_downcast(AudioOutput &output) noexcept {
 
 bool wasapi_is_exclusive(WasapiOutput &output) noexcept { return output.is_exclusive; }
 
+std::shared_ptr<COMWorker>
+wasapi_output_get_com_worker(WasapiOutput &output) noexcept
+{
+	return output.GetComWorker();
+}
+
 IMMDevice *wasapi_output_get_device(WasapiOutput &output) noexcept {
 	return output.device.get();
 }
@@ -524,7 +537,7 @@ void WasapiOutput::Close() noexcept {
 	assert(thread);
 
 	try {
-		COMWorker::Async([&]() {
+		com_worker->Async([&]() {
 			Stop(*client);
 		}).get();
 		thread->CheckException();
@@ -535,7 +548,7 @@ void WasapiOutput::Close() noexcept {
 	is_started = false;
 	thread->Finish();
 	thread->Join();
-	COMWorker::Async([&]() {
+	com_worker->Async([&]() {
 		thread.reset();
 		client.reset();
 	}).get();
@@ -586,7 +599,7 @@ size_t WasapiOutput::Play(const void *chunk, size_t size) {
 		if (!is_started) {
 			is_started = true;
 			thread->Play();
-			COMWorker::Async([&]() {
+			com_worker->Async([&]() {
 				Start(*client);
 			}).wait();
 		}
diff --git a/src/win32/ComWorker.cxx b/src/win32/ComWorker.cxx
index 702c2c427..01f67a5f9 100644
--- a/src/win32/ComWorker.cxx
+++ b/src/win32/ComWorker.cxx
@@ -21,10 +21,6 @@
 #include "Com.hxx"
 #include "thread/Name.hxx"
 
-Mutex COMWorker::mutex;
-unsigned int COMWorker::reference_count = 0;
-std::optional<COMWorker::COMWorkerThread> COMWorker::thread;
-
 void COMWorker::COMWorkerThread::Work() noexcept {
 	SetThreadName("COM Worker");
 	COM com{true};
diff --git a/src/win32/ComWorker.hxx b/src/win32/ComWorker.hxx
index 10f2123a4..f7e29a395 100644
--- a/src/win32/ComWorker.hxx
+++ b/src/win32/ComWorker.hxx
@@ -22,12 +22,9 @@
 
 #include "WinEvent.hxx"
 #include "thread/Future.hxx"
-#include "thread/Mutex.hxx"
 #include "thread/Thread.hxx"
 
 #include <boost/lockfree/spsc_queue.hpp>
-#include <mutex>
-#include <optional>
 
 #include <windows.h>
 
@@ -56,31 +53,25 @@ private:
 	};
 
 public:
-	static void Aquire() {
-		std::unique_lock locker(mutex);
-		if (reference_count == 0) {
-			thread.emplace();
-			thread->Start();
-		}
-		++reference_count;
-	}
-	static void Release() noexcept {
-		std::unique_lock locker(mutex);
-		--reference_count;
-		if (reference_count == 0) {
-			thread->Finish();
-			thread->Join();
-			thread.reset();
-		}
+	COMWorker() {
+		thread.Start();
 	}
 
+	~COMWorker() noexcept {
+		thread.Finish();
+		thread.Join();
+	}
+
+	COMWorker(const COMWorker &) = delete;
+	COMWorker &operator=(const COMWorker &) = delete;
+
 	template <typename Function, typename... Args>
-	static auto Async(Function &&function, Args &&...args) {
+	auto Async(Function &&function, Args &&...args) {
 		using R = std::invoke_result_t<std::decay_t<Function>,
 					       std::decay_t<Args>...>;
 		auto promise = std::make_shared<Promise<R>>();
 		auto future = promise->get_future();
-		thread->Push([function = std::forward<Function>(function),
+		thread.Push([function = std::forward<Function>(function),
 			      args = std::make_tuple(std::forward<Args>(args)...),
 			      promise = std::move(promise)]() mutable {
 			try {
@@ -101,9 +92,7 @@ public:
 	}
 
 private:
-	static Mutex mutex;
-	static unsigned int reference_count;
-	static std::optional<COMWorkerThread> thread;
+	COMWorkerThread thread;
 };
 
 #endif