From 29747377467168ac8928081d3d37aa5016acd778 Mon Sep 17 00:00:00 2001
From: Shen-Ta Hsieh <ibmibmibm.tw@gmail.com>
Date: Wed, 2 Dec 2020 07:14:51 +0800
Subject: [PATCH] src/win32: Add ComWorker to run all COM function on same
 thread

---
 src/mixer/plugins/WasapiMixerPlugin.cxx   | 149 ++++++++++++----------
 src/output/plugins/WasapiOutputPlugin.cxx | 128 ++++++++++---------
 2 files changed, 149 insertions(+), 128 deletions(-)

diff --git a/src/mixer/plugins/WasapiMixerPlugin.cxx b/src/mixer/plugins/WasapiMixerPlugin.cxx
index d45af3dfa..bc2604633 100644
--- a/src/mixer/plugins/WasapiMixerPlugin.cxx
+++ b/src/mixer/plugins/WasapiMixerPlugin.cxx
@@ -19,7 +19,7 @@
 
 #include "mixer/MixerInternal.hxx"
 #include "output/plugins/WasapiOutputPlugin.hxx"
-#include "win32/Com.hxx"
+#include "win32/ComWorker.hxx"
 #include "win32/HResult.hxx"
 
 #include <cmath>
@@ -28,92 +28,103 @@
 
 class WasapiMixer final : public Mixer {
 	WasapiOutput &output;
-	std::optional<COM> com;
 
 public:
 	WasapiMixer(WasapiOutput &_output, MixerListener &_listener)
 	: Mixer(wasapi_mixer_plugin, _listener), output(_output) {}
 
-	void Open() override { com.emplace(); }
+	void Open() override {}
 
-	void Close() noexcept override { com.reset(); }
+	void Close() noexcept override {}
 
 	int GetVolume() override {
-		HRESULT result;
-		float volume_level;
+		auto future = COMWorker::Async([&]() -> int {
+			HRESULT result;
+			float volume_level;
 
-		if (wasapi_is_exclusive(output)) {
-			ComPtr<IAudioEndpointVolume> endpoint_volume;
-			result = wasapi_output_get_device(output)->Activate(
-				__uuidof(IAudioEndpointVolume), CLSCTX_ALL, nullptr,
-				endpoint_volume.AddressCast());
-			if (FAILED(result)) {
-				throw FormatHResultError(
-					result, "Unable to get device endpoint volume");
+			if (wasapi_is_exclusive(output)) {
+				ComPtr<IAudioEndpointVolume> endpoint_volume;
+				result = wasapi_output_get_device(output)->Activate(
+					__uuidof(IAudioEndpointVolume), CLSCTX_ALL,
+					nullptr, endpoint_volume.AddressCast());
+				if (FAILED(result)) {
+					throw FormatHResultError(result,
+								 "Unable to get device "
+								 "endpoint volume");
+				}
+
+				result = endpoint_volume->GetMasterVolumeLevelScalar(
+					&volume_level);
+				if (FAILED(result)) {
+					throw FormatHResultError(result,
+								 "Unable to get master "
+								 "volume level");
+				}
+			} else {
+				ComPtr<ISimpleAudioVolume> session_volume;
+				result = wasapi_output_get_client(output)->GetService(
+					__uuidof(ISimpleAudioVolume),
+					session_volume.AddressCast<void>());
+				if (FAILED(result)) {
+					throw FormatHResultError(result,
+								 "Unable to get client "
+								 "session volume");
+				}
+
+				result = session_volume->GetMasterVolume(&volume_level);
+				if (FAILED(result)) {
+					throw FormatHResultError(
+						result, "Unable to get master volume");
+				}
 			}
 
-			result = endpoint_volume->GetMasterVolumeLevelScalar(
-				&volume_level);
-			if (FAILED(result)) {
-				throw FormatHResultError(
-					result, "Unable to get master volume level");
-			}
-		} else {
-			ComPtr<ISimpleAudioVolume> session_volume;
-			result = wasapi_output_get_client(output)->GetService(
-				__uuidof(ISimpleAudioVolume),
-				session_volume.AddressCast<void>());
-			if (FAILED(result)) {
-				throw FormatHResultError(
-					result, "Unable to get client session volume");
-			}
-
-			result = session_volume->GetMasterVolume(&volume_level);
-			if (FAILED(result)) {
-				throw FormatHResultError(result,
-							 "Unable to get master volume");
-			}
-		}
-
-		return std::lround(volume_level * 100.0f);
+			return std::lround(volume_level * 100.0f);
+		});
+		return future.get();
 	}
 
 	void SetVolume(unsigned volume) override {
-		HRESULT result;
-		const float volume_level = volume / 100.0f;
+		COMWorker::Async([&]() {
+			HRESULT result;
+			const float volume_level = volume / 100.0f;
 
-		if (wasapi_is_exclusive(output)) {
-			ComPtr<IAudioEndpointVolume> endpoint_volume;
-			result = wasapi_output_get_device(output)->Activate(
-				__uuidof(IAudioEndpointVolume), CLSCTX_ALL, nullptr,
-				endpoint_volume.AddressCast());
-			if (FAILED(result)) {
-				throw FormatHResultError(
-					result, "Unable to get device endpoint volume");
-			}
+			if (wasapi_is_exclusive(output)) {
+				ComPtr<IAudioEndpointVolume> endpoint_volume;
+				result = wasapi_output_get_device(output)->Activate(
+					__uuidof(IAudioEndpointVolume), CLSCTX_ALL,
+					nullptr, endpoint_volume.AddressCast());
+				if (FAILED(result)) {
+					throw FormatHResultError(
+						result,
+						"Unable to get device endpoint volume");
+				}
 
-			result = endpoint_volume->SetMasterVolumeLevelScalar(volume_level,
-									     nullptr);
-			if (FAILED(result)) {
-				throw FormatHResultError(
-					result, "Unable to set master volume level");
-			}
-		} else {
-			ComPtr<ISimpleAudioVolume> session_volume;
-			result = wasapi_output_get_client(output)->GetService(
-				__uuidof(ISimpleAudioVolume),
-				session_volume.AddressCast<void>());
-			if (FAILED(result)) {
-				throw FormatHResultError(
-					result, "Unable to get client session volume");
-			}
+				result = endpoint_volume->SetMasterVolumeLevelScalar(
+					volume_level, nullptr);
+				if (FAILED(result)) {
+					throw FormatHResultError(
+						result,
+						"Unable to set master volume level");
+				}
+			} else {
+				ComPtr<ISimpleAudioVolume> session_volume;
+				result = wasapi_output_get_client(output)->GetService(
+					__uuidof(ISimpleAudioVolume),
+					session_volume.AddressCast<void>());
+				if (FAILED(result)) {
+					throw FormatHResultError(
+						result,
+						"Unable to get client session volume");
+				}
 
-			result = session_volume->SetMasterVolume(volume_level, nullptr);
-			if (FAILED(result)) {
-				throw FormatHResultError(result,
-							 "Unable to set master volume");
+				result = session_volume->SetMasterVolume(volume_level,
+									 nullptr);
+				if (FAILED(result)) {
+					throw FormatHResultError(
+						result, "Unable to set master volume");
+				}
 			}
-		}
+		}).get();
 	}
 };
 
diff --git a/src/output/plugins/WasapiOutputPlugin.cxx b/src/output/plugins/WasapiOutputPlugin.cxx
index d9d16e97e..8c52bf15f 100644
--- a/src/output/plugins/WasapiOutputPlugin.cxx
+++ b/src/output/plugins/WasapiOutputPlugin.cxx
@@ -32,6 +32,7 @@
 #include "util/ScopeExit.hxx"
 #include "win32/Com.hxx"
 #include "win32/ComHeapPtr.hxx"
+#include "win32/ComWorker.hxx"
 #include "win32/HResult.hxx"
 #include "win32/WinEvent.hxx"
 
@@ -144,7 +145,6 @@ public:
 
 private:
 	std::shared_ptr<WinEvent> event;
-	std::optional<COM> com;
 	ComPtr<IAudioClient> client;
 	ComPtr<IAudioRenderClient> render_client;
 	const UINT32 frame_size;
@@ -174,9 +174,17 @@ class WasapiOutput final : public AudioOutput {
 public:
 	static AudioOutput *Create(EventLoop &, const ConfigBlock &block);
 	WasapiOutput(const ConfigBlock &block);
-	void Enable() override;
-	void Disable() noexcept override;
-	void Open(AudioFormat &audio_format) override;
+	void Enable() override {
+		COMWorker::Aquire();
+		COMWorker::Async([&]() { OpenDevice(); }).get();
+	}
+	void Disable() noexcept override {
+		COMWorker::Async([&]() { DoDisable(); }).get();
+		COMWorker::Release();
+	}
+	void Open(AudioFormat &audio_format) override {
+		COMWorker::Async([&]() { DoOpen(audio_format); }).get();
+	}
 	void Close() noexcept override;
 	std::chrono::steady_clock::duration Delay() const noexcept override;
 	size_t Play(const void *chunk, size_t size) override;
@@ -196,7 +204,6 @@ private:
 	std::string device_config;
 	std::vector<std::pair<unsigned int, AllocatedString>> device_desc;
 	std::shared_ptr<WinEvent> event;
-	std::optional<COM> com;
 	ComPtr<IMMDeviceEnumerator> enumerator;
 	ComPtr<IMMDevice> device;
 	ComPtr<IAudioClient> client;
@@ -209,6 +216,10 @@ private:
 	friend IMMDevice *wasapi_output_get_device(WasapiOutput &output) noexcept;
 	friend IAudioClient *wasapi_output_get_client(WasapiOutput &output) noexcept;
 
+	void DoDisable() noexcept;
+	void DoOpen(AudioFormat &audio_format);
+
+	void OpenDevice();
 	void FindExclusiveFormatSupported(AudioFormat &audio_format);
 	void FindSharedFormatSupported(AudioFormat &audio_format);
 	void EnumerateDevices();
@@ -234,15 +245,7 @@ IAudioClient *wasapi_output_get_client(WasapiOutput &output) noexcept {
 void WasapiOutputThread::Work() noexcept {
 	SetThreadName("Wasapi Output Worker");
 	FormatDebug(wasapi_output_domain, "Working thread started");
-	try {
-		com.emplace();
-	} catch (...) {
-		std::unique_lock<Mutex> lock(error.mutex);
-		error.error_ptr = std::current_exception();
-		error.cond.wait(lock);
-		assert(error.error_ptr == nullptr);
-		return;
-	}
+	COM com{true};
 	while (true) {
 		try {
 			event->Wait(INFINITE);
@@ -316,41 +319,8 @@ WasapiOutput::WasapiOutput(const ConfigBlock &block)
   enumerate_devices(block.GetBlockValue("enumerate", false)),
   device_config(block.GetBlockValue("device", "")) {}
 
-void WasapiOutput::Enable() {
-	com.emplace();
-	event = std::make_shared<WinEvent>();
-	enumerator.CoCreateInstance(__uuidof(MMDeviceEnumerator), nullptr,
-				    CLSCTX_INPROC_SERVER);
-
-	device_desc.clear();
-	device.reset();
-
-	if (enumerate_devices && SafeTry([this]() { EnumerateDevices(); })) {
-		for (const auto &desc : device_desc) {
-			FormatNotice(wasapi_output_domain, "Device \"%u\" \"%s\"",
-				     desc.first, desc.second.c_str());
-		}
-	}
-
-	unsigned int id = kErrorId;
-	if (!device_config.empty()) {
-		if (!SafeSilenceTry([this, &id]() { id = std::stoul(device_config); })) {
-			id = SearchDevice(device_config);
-		}
-	}
-
-	if (id != kErrorId) {
-		SafeTry([this, id]() { GetDevice(id); });
-	}
-
-	if (!device) {
-		GetDefaultDevice();
-	}
-
-	device_desc.clear();
-}
-
-void WasapiOutput::Disable() noexcept {
+/// run inside COMWorkerThread
+void WasapiOutput::DoDisable() noexcept {
 	if (thread) {
 		try {
 			thread->Finish();
@@ -369,7 +339,8 @@ void WasapiOutput::Disable() noexcept {
 	event.reset();
 }
 
-void WasapiOutput::Open(AudioFormat &audio_format) {
+/// run inside COMWorkerThread
+void WasapiOutput::DoOpen(AudioFormat &audio_format) {
 	if (audio_format.channels == 0) {
 		throw FormatInvalidArgument("channels should > 0");
 	}
@@ -497,9 +468,11 @@ void WasapiOutput::Close() noexcept {
 	Pause();
 	thread->Finish();
 	thread->Join();
-	thread.reset();
+	COMWorker::Async([&]() {
+		thread.reset();
+		client.reset();
+	}).get();
 	spsc_buffer.reset();
-	client.reset();
 }
 
 std::chrono::steady_clock::duration WasapiOutput::Delay() const noexcept {
@@ -534,13 +507,14 @@ size_t WasapiOutput::Play(const void *chunk, size_t size) {
 			is_started = true;
 
 			thread->Play();
-
-			HRESULT result;
-			result = client->Start();
-			if (FAILED(result)) {
-				throw FormatHResultError(result,
-							 "Failed to start client");
-			}
+			COMWorker::Async([&]() {
+				HRESULT result;
+				result = client->Start();
+				if (FAILED(result)) {
+					throw FormatHResultError(
+						result, "Failed to start client");
+				}
+			}).wait();
 		}
 
 		thread->CheckException();
@@ -575,6 +549,37 @@ bool WasapiOutput::Pause() {
 	return true;
 }
 
+/// run inside COMWorkerThread
+void WasapiOutput::OpenDevice() {
+	enumerator.CoCreateInstance(__uuidof(MMDeviceEnumerator), nullptr,
+				    CLSCTX_INPROC_SERVER);
+
+	if (enumerate_devices && SafeTry([this]() { EnumerateDevices(); })) {
+		for (const auto &desc : device_desc) {
+			FormatNotice(wasapi_output_domain, "Device \"%u\" \"%s\"",
+				     desc.first, desc.second.c_str());
+		}
+	}
+
+	unsigned int id = kErrorId;
+	if (!device_config.empty()) {
+		if (!SafeSilenceTry([this, &id]() { id = std::stoul(device_config); })) {
+			id = SearchDevice(device_config);
+		}
+	}
+
+	if (id != kErrorId) {
+		SafeTry([this, id]() { GetDevice(id); });
+	}
+
+	if (!device) {
+		GetDefaultDevice();
+	}
+
+	device_desc.clear();
+}
+
+/// run inside COMWorkerThread
 void WasapiOutput::FindExclusiveFormatSupported(AudioFormat &audio_format) {
 	SetFormat(device_format, audio_format);
 
@@ -641,6 +646,7 @@ void WasapiOutput::FindExclusiveFormatSupported(AudioFormat &audio_format) {
 	} while (true);
 }
 
+/// run inside COMWorkerThread
 void WasapiOutput::FindSharedFormatSupported(AudioFormat &audio_format) {
 	HRESULT result;
 	ComHeapPtr<WAVEFORMATEX> mixer_format;
@@ -724,6 +730,7 @@ void WasapiOutput::FindSharedFormatSupported(AudioFormat &audio_format) {
 	}
 }
 
+/// run inside COMWorkerThread
 void WasapiOutput::EnumerateDevices() {
 	if (!device_desc.empty()) {
 		return;
@@ -776,6 +783,7 @@ void WasapiOutput::EnumerateDevices() {
 	}
 }
 
+/// run inside COMWorkerThread
 void WasapiOutput::GetDevice(unsigned int index) {
 	HRESULT result;
 
@@ -792,6 +800,7 @@ void WasapiOutput::GetDevice(unsigned int index) {
 	}
 }
 
+/// run inside COMWorkerThread
 unsigned int WasapiOutput::SearchDevice(std::string_view name) {
 	if (!SafeTry([this]() { EnumerateDevices(); })) {
 		return kErrorId;
@@ -809,6 +818,7 @@ unsigned int WasapiOutput::SearchDevice(std::string_view name) {
 	return iter->first;
 }
 
+/// run inside COMWorkerThread
 void WasapiOutput::GetDefaultDevice() {
 	HRESULT result;
 	result = enumerator->GetDefaultAudioEndpoint(eRender, eMultimedia,