From 88e277428ebefbb1f20f8dc688a5c59002276c62 Mon Sep 17 00:00:00 2001 From: Takeshi Yoshino <4511440+tyoshino@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:45:43 +0000 Subject: [PATCH] fix(wish/python): implement `__aexit__` of `WishConnection` to clean up resources --- wish/cpp/src/plain_client.cc | 6 + wish/cpp/src/plain_client.h | 1 + wish/cpp/src/tls_client.cc | 6 + wish/cpp/src/tls_client.h | 1 + wish/python/src/wish_ext.cc | 212 +++++++++++++++++------------------ wish/python/wish/client.py | 16 ++- 6 files changed, 130 insertions(+), 112 deletions(-) diff --git a/wish/cpp/src/plain_client.cc b/wish/cpp/src/plain_client.cc index 28181c7..92bd6e2 100644 --- a/wish/cpp/src/plain_client.cc +++ b/wish/cpp/src/plain_client.cc @@ -78,3 +78,9 @@ void PlainClient::Run() { event_base_dispatch(base_); } + +void PlainClient::Stop() { + if (base_) { + event_base_loopexit(base_, nullptr); + } +} diff --git a/wish/cpp/src/plain_client.h b/wish/cpp/src/plain_client.h index 62e9ef7..7a85f8d 100644 --- a/wish/cpp/src/plain_client.h +++ b/wish/cpp/src/plain_client.h @@ -22,6 +22,7 @@ class PlainClient { void SetOnOpen(OpenCallback cb); void SetOnMessage(MessageCallback cb); void Run(); + void Stop(); private: std::string host_; diff --git a/wish/cpp/src/tls_client.cc b/wish/cpp/src/tls_client.cc index 3bc342e..679e214 100644 --- a/wish/cpp/src/tls_client.cc +++ b/wish/cpp/src/tls_client.cc @@ -104,3 +104,9 @@ void TlsClient::Run() { event_base_dispatch(base_); } + +void TlsClient::Stop() { + if (base_) { + event_base_loopexit(base_, nullptr); + } +} diff --git a/wish/cpp/src/tls_client.h b/wish/cpp/src/tls_client.h index eb2be90..862a4b7 100644 --- a/wish/cpp/src/tls_client.h +++ b/wish/cpp/src/tls_client.h @@ -23,6 +23,7 @@ class TlsClient { void SetOnOpen(OpenCallback cb); void SetOnMessage(MessageCallback cb); void Run(); + void Stop(); private: std::string ca_file_; diff --git a/wish/python/src/wish_ext.cc b/wish/python/src/wish_ext.cc index 2b92b54..dd3016a 100644 --- a/wish/python/src/wish_ext.cc +++ b/wish/python/src/wish_ext.cc @@ -1,11 +1,11 @@ +#include #include -#include #include -#include +#include -#include "wish_handler.h" -#include "tls_client.h" #include "plain_client.h" +#include "tls_client.h" +#include "wish_handler.h" namespace nb = nanobind; @@ -21,22 +21,22 @@ namespace nb = nanobind; // --------------------------------------------------------------------------- struct TlsClientPy { - TlsClient client; - nb::object on_open_cb; - nb::object on_message_cb; + TlsClient client; + nb::object on_open_cb; + nb::object on_message_cb; - TlsClientPy(const std::string& ca, const std::string& cert, - const std::string& key, const std::string& host, int port) - : client(ca, cert, key, host, port) {} + TlsClientPy(const std::string& ca, const std::string& cert, + const std::string& key, const std::string& host, int port) + : client(ca, cert, key, host, port) {} }; struct PlainClientPy { - PlainClient client; - nb::object on_open_cb; - nb::object on_message_cb; + PlainClient client; + nb::object on_open_cb; + nb::object on_message_cb; - PlainClientPy(const std::string& host, int port) - : client(host, port) {} + PlainClientPy(const std::string& host, int port) + : client(host, port) {} }; // --------------------------------------------------------------------------- @@ -44,21 +44,21 @@ struct PlainClientPy { // --------------------------------------------------------------------------- static int tls_traverse(PyObject* self, visitproc visit, void* arg) { - TlsClientPy* w = nb::inst_ptr(nb::handle(self)); - Py_VISIT(w->on_open_cb.ptr()); - Py_VISIT(w->on_message_cb.ptr()); - return 0; + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + return 0; } static int tls_clear(PyObject* self) { - TlsClientPy* w = nb::inst_ptr(nb::handle(self)); - // Clear the C++ callbacks first so the lambda (which captures &*w) is - // dropped before we invalidate on_open_cb / on_message_cb. - w->client.SetOnOpen({}); - w->client.SetOnMessage({}); - w->on_open_cb = nb::object(); - w->on_message_cb = nb::object(); - return 0; + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + // Clear the C++ callbacks first so the lambda (which captures &*w) is + // dropped before we invalidate on_open_cb / on_message_cb. + w->client.SetOnOpen({}); + w->client.SetOnMessage({}); + w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + return 0; } // --------------------------------------------------------------------------- @@ -66,96 +66,94 @@ static int tls_clear(PyObject* self) { // --------------------------------------------------------------------------- static int plain_traverse(PyObject* self, visitproc visit, void* arg) { - PlainClientPy* w = nb::inst_ptr(nb::handle(self)); - Py_VISIT(w->on_open_cb.ptr()); - Py_VISIT(w->on_message_cb.ptr()); - return 0; + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + return 0; } static int plain_clear(PyObject* self) { - PlainClientPy* w = nb::inst_ptr(nb::handle(self)); - w->client.SetOnOpen({}); - w->client.SetOnMessage({}); - w->on_open_cb = nb::object(); - w->on_message_cb = nb::object(); - return 0; + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + w->client.SetOnOpen({}); + w->client.SetOnMessage({}); + w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + return 0; } // --------------------------------------------------------------------------- NB_MODULE(wish_ext, m) { - // Enable libevent thread-safety + // Enable libevent thread-safety #ifdef _WIN32 - evthread_use_windows_threads(); + evthread_use_windows_threads(); #else - evthread_use_pthreads(); + evthread_use_pthreads(); #endif - nb::class_(m, "WishHandler") - .def("send_text", &WishHandler::SendText) - .def("send_binary", &WishHandler::SendBinary); - - // ---- TlsClient -------------------------------------------------------- - static PyType_Slot tls_slots[] = { - {Py_tp_traverse, (void*)tls_traverse}, - {Py_tp_clear, (void*)tls_clear}, - {0, nullptr}, - }; - - nb::class_(m, "TlsClient", nb::type_slots(tls_slots)) - .def(nb::init()) - .def("init", [](TlsClientPy& self) { - return self.client.Init(); - }) - .def("set_on_open", [](TlsClientPy& self, nb::object cb) { - self.on_open_cb = cb; - // Capture self by pointer; lifetime is safe because the lambda - // lives inside self.client and is always cleared before self - // is destroyed (either by tp_clear or ~TlsClient). - self.client.SetOnOpen([&self](WishHandler* handler) { - nb::gil_scoped_acquire acquire; - self.on_open_cb(nb::cast(handler, nb::rv_policy::reference)); - }); - }) - .def("set_on_message", [](TlsClientPy& self, nb::object cb) { - self.on_message_cb = cb; - self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { - nb::gil_scoped_acquire acquire; - self.on_message_cb(opcode, msg); - }); - }) - .def("run", [](TlsClientPy& self) { self.client.Run(); }, - nb::call_guard()) - .def("stop", [](TlsClientPy& self) { self.client.Stop(); }); - - // ---- PlainClient ------------------------------------------------------ - static PyType_Slot plain_slots[] = { - {Py_tp_traverse, (void*)plain_traverse}, - {Py_tp_clear, (void*)plain_clear}, - {0, nullptr}, - }; - - nb::class_(m, "PlainClient", nb::type_slots(plain_slots)) - .def(nb::init()) - .def("init", [](PlainClientPy& self) { - return self.client.Init(); - }) - .def("set_on_open", [](PlainClientPy& self, nb::object cb) { - self.on_open_cb = cb; - self.client.SetOnOpen([&self](WishHandler* handler) { - nb::gil_scoped_acquire acquire; - self.on_open_cb(nb::cast(handler, nb::rv_policy::reference)); - }); - }) - .def("set_on_message", [](PlainClientPy& self, nb::object cb) { - self.on_message_cb = cb; - self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { - nb::gil_scoped_acquire acquire; - self.on_message_cb(opcode, msg); - }); - }) - .def("run", [](PlainClientPy& self) { self.client.Run(); }, - nb::call_guard()) - .def("stop", [](PlainClientPy& self) { self.client.Stop(); }); + nb::class_(m, "WishHandler") + .def("send_text", &WishHandler::SendText) + .def("send_binary", &WishHandler::SendBinary); + + // ---- TlsClient -------------------------------------------------------- + static PyType_Slot tls_slots[] = { + {Py_tp_traverse, (void*)tls_traverse}, + {Py_tp_clear, (void*)tls_clear}, + {0, nullptr}, + }; + + nb::class_(m, "TlsClient", nb::type_slots(tls_slots)) + .def(nb::init()) + .def("init", [](TlsClientPy& self) { + return self.client.Init(); + }) + .def("set_on_open", [](TlsClientPy& self, nb::object cb) { + self.on_open_cb = cb; + // Capture self by pointer; lifetime is safe because the lambda + // lives inside self.client and is always cleared before self + // is destroyed (either by tp_clear or ~TlsClient). + self.client.SetOnOpen([&self](WishHandler* handler) { + nb::gil_scoped_acquire acquire; + self.on_open_cb(nb::cast(handler, nb::rv_policy::reference)); + }); + }) + .def("set_on_message", [](TlsClientPy& self, nb::object cb) { + self.on_message_cb = cb; + self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { + nb::gil_scoped_acquire acquire; + self.on_message_cb(opcode, msg); + }); + }) + .def("run", [](TlsClientPy& self) { self.client.Run(); }, nb::call_guard()) + .def("stop", [](TlsClientPy& self) { self.client.Stop(); }); + + // ---- PlainClient ------------------------------------------------------ + static PyType_Slot plain_slots[] = { + {Py_tp_traverse, (void*)plain_traverse}, + {Py_tp_clear, (void*)plain_clear}, + {0, nullptr}, + }; + + nb::class_(m, "PlainClient", nb::type_slots(plain_slots)) + .def(nb::init()) + .def("init", [](PlainClientPy& self) { + return self.client.Init(); + }) + .def("set_on_open", [](PlainClientPy& self, nb::object cb) { + self.on_open_cb = cb; + self.client.SetOnOpen([&self](WishHandler* handler) { + nb::gil_scoped_acquire acquire; + self.on_open_cb(nb::cast(handler, nb::rv_policy::reference)); + }); + }) + .def("set_on_message", [](PlainClientPy& self, nb::object cb) { + self.on_message_cb = cb; + self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { + nb::gil_scoped_acquire acquire; + self.on_message_cb(opcode, msg); + }); + }) + .def("run", [](PlainClientPy& self) { self.client.Run(); }, nb::call_guard()) + .def("stop", [](PlainClientPy& self) { self.client.Stop(); }); } diff --git a/wish/python/wish/client.py b/wish/python/wish/client.py index fbfccdc..eec980e 100644 --- a/wish/python/wish/client.py +++ b/wish/python/wish/client.py @@ -18,6 +18,7 @@ def __init__(self, host, port, tls, ca_file="", cert_file="", key_file=""): self._loop = asyncio.get_running_loop() self._recv_queue = asyncio.Queue() self._open_future = self._loop.create_future() + self._run_future = None self._handler = None def on_open(handler): @@ -32,11 +33,19 @@ def on_message(opcode, msg): async def connect(self): # Run the C++ event loop in a background thread. - self._loop.run_in_executor(None, self._client.run) + # Keep the Future so we can await thread completion on close. + self._run_future = self._loop.run_in_executor(None, self._client.run) # Wait until the on_open callback fires await self._open_future return self + async def close(self): + """Stop the C++ event loop and wait for the background thread to exit.""" + self._client.stop() + if self._run_future is not None: + await self._run_future + self._run_future = None + async def send(self, data): """Sends data over the WiSH connection. If data is bytes, sends as binary, else text.""" if not self._handler: @@ -76,10 +85,7 @@ async def __aenter__(self): return self.conn async def __aexit__(self, exc_type, exc_val, exc_tb): - # We don't have explicit close method on TlsClient yet, - # but the connection would drop if the process exits, - # or we could add a shutdown mechanism. - pass + await self.conn.close() def connect(uri, ca_file="", cert_file="", key_file=""): return _ConnectContextManager(uri, ca_file, cert_file, key_file)