From 9c5ac63438f0c8d78599624b5692efa628958c36 Mon Sep 17 00:00:00 2001 From: Takeshi Yoshino <4511440+tyoshino@users.noreply.github.com> Date: Wed, 22 Apr 2026 06:41:27 +0000 Subject: [PATCH] fix(wish/python): fix the Python bindings to be garbage-collectable --- wish/python/src/wish_ext.cc | 151 +++++++++++++++++++++++++++++------- 1 file changed, 124 insertions(+), 27 deletions(-) diff --git a/wish/python/src/wish_ext.cc b/wish/python/src/wish_ext.cc index 58fc668..2b92b54 100644 --- a/wish/python/src/wish_ext.cc +++ b/wish/python/src/wish_ext.cc @@ -9,6 +9,80 @@ namespace nb = nanobind; +// --------------------------------------------------------------------------- +// Wrapper structs +// +// Storing the Python callbacks as direct nb::object members (rather than +// inside a std::function closure) makes them visible to Python's cyclic GC +// via tp_traverse / tp_clear, breaking cycles such as: +// +// WishConnection → TlsClient (Python) → on_open_ lambda +// → nb::object → Python closure → WishConnection +// --------------------------------------------------------------------------- + +struct TlsClientPy { + TlsClient client; + nb::object on_open_cb; + nb::object on_message_cb; + + TlsClientPy(const std::string& ca, const std::string& cert, + const std::string& key, const std::string& host, int port) + : client(ca, cert, key, host, port) {} +}; + +struct PlainClientPy { + PlainClient client; + nb::object on_open_cb; + nb::object on_message_cb; + + PlainClientPy(const std::string& host, int port) + : client(host, port) {} +}; + +// --------------------------------------------------------------------------- +// tp_traverse / tp_clear for TlsClientPy +// --------------------------------------------------------------------------- + +static int tls_traverse(PyObject* self, visitproc visit, void* arg) { + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + return 0; +} + +static int tls_clear(PyObject* self) { + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + // Clear the C++ callbacks first so the lambda (which captures &*w) is + // dropped before we invalidate on_open_cb / on_message_cb. + w->client.SetOnOpen({}); + w->client.SetOnMessage({}); + w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + return 0; +} + +// --------------------------------------------------------------------------- +// tp_traverse / tp_clear for PlainClientPy +// --------------------------------------------------------------------------- + +static int plain_traverse(PyObject* self, visitproc visit, void* arg) { + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + return 0; +} + +static int plain_clear(PyObject* self) { + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + w->client.SetOnOpen({}); + w->client.SetOnMessage({}); + w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + return 0; +} + +// --------------------------------------------------------------------------- + NB_MODULE(wish_ext, m) { // Enable libevent thread-safety #ifdef _WIN32 @@ -18,47 +92,70 @@ NB_MODULE(wish_ext, m) { #endif nb::class_(m, "WishHandler") - .def("send_text", &WishHandler::SendText) + .def("send_text", &WishHandler::SendText) .def("send_binary", &WishHandler::SendBinary); - nb::class_(m, "TlsClient") - .def(nb::init()) - .def("init", &TlsClient::Init) - .def("set_on_open", [](TlsClient& self, nb::object cb) { - // Store the Python callable in a shared_ptr so it survives GC. - // The callback will be invoked from the libevent thread, so we - // must acquire the GIL before touching any Python objects. - auto cb_ptr = std::make_shared(cb); - self.SetOnOpen([cb_ptr](WishHandler* handler) { + // ---- TlsClient -------------------------------------------------------- + static PyType_Slot tls_slots[] = { + {Py_tp_traverse, (void*)tls_traverse}, + {Py_tp_clear, (void*)tls_clear}, + {0, nullptr}, + }; + + nb::class_(m, "TlsClient", nb::type_slots(tls_slots)) + .def(nb::init()) + .def("init", [](TlsClientPy& self) { + return self.client.Init(); + }) + .def("set_on_open", [](TlsClientPy& self, nb::object cb) { + self.on_open_cb = cb; + // Capture self by pointer; lifetime is safe because the lambda + // lives inside self.client and is always cleared before self + // is destroyed (either by tp_clear or ~TlsClient). + self.client.SetOnOpen([&self](WishHandler* handler) { nb::gil_scoped_acquire acquire; - (*cb_ptr)(nb::cast(handler, nb::rv_policy::reference)); + self.on_open_cb(nb::cast(handler, nb::rv_policy::reference)); }); }) - .def("set_on_message", [](TlsClient& self, nb::object cb) { - auto cb_ptr = std::make_shared(cb); - self.SetOnMessage([cb_ptr](uint8_t opcode, const std::string& msg) { + .def("set_on_message", [](TlsClientPy& self, nb::object cb) { + self.on_message_cb = cb; + self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { nb::gil_scoped_acquire acquire; - (*cb_ptr)(opcode, msg); + self.on_message_cb(opcode, msg); }); }) - .def("run", &TlsClient::Run, nb::call_guard()); + .def("run", [](TlsClientPy& self) { self.client.Run(); }, + nb::call_guard()) + .def("stop", [](TlsClientPy& self) { self.client.Stop(); }); - nb::class_(m, "PlainClient") + // ---- PlainClient ------------------------------------------------------ + static PyType_Slot plain_slots[] = { + {Py_tp_traverse, (void*)plain_traverse}, + {Py_tp_clear, (void*)plain_clear}, + {0, nullptr}, + }; + + nb::class_(m, "PlainClient", nb::type_slots(plain_slots)) .def(nb::init()) - .def("init", &PlainClient::Init) - .def("set_on_open", [](PlainClient& self, nb::object cb) { - auto cb_ptr = std::make_shared(cb); - self.SetOnOpen([cb_ptr](WishHandler* handler) { + .def("init", [](PlainClientPy& self) { + return self.client.Init(); + }) + .def("set_on_open", [](PlainClientPy& self, nb::object cb) { + self.on_open_cb = cb; + self.client.SetOnOpen([&self](WishHandler* handler) { nb::gil_scoped_acquire acquire; - (*cb_ptr)(nb::cast(handler, nb::rv_policy::reference)); + self.on_open_cb(nb::cast(handler, nb::rv_policy::reference)); }); }) - .def("set_on_message", [](PlainClient& self, nb::object cb) { - auto cb_ptr = std::make_shared(cb); - self.SetOnMessage([cb_ptr](uint8_t opcode, const std::string& msg) { + .def("set_on_message", [](PlainClientPy& self, nb::object cb) { + self.on_message_cb = cb; + self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { nb::gil_scoped_acquire acquire; - (*cb_ptr)(opcode, msg); + self.on_message_cb(opcode, msg); }); }) - .def("run", &PlainClient::Run, nb::call_guard()); + .def("run", [](PlainClientPy& self) { self.client.Run(); }, + nb::call_guard()) + .def("stop", [](PlainClientPy& self) { self.client.Stop(); }); }