diff --git a/wish/python/src/wish_ext.cc b/wish/python/src/wish_ext.cc index 82fd4c4..f3d2a31 100644 --- a/wish/python/src/wish_ext.cc +++ b/wish/python/src/wish_ext.cc @@ -64,6 +64,10 @@ struct TlsClientPy { std::mutex stopped_mu; std::condition_variable stopped_cv; + // Guards against running cleanup twice (once from tp_finalize, once from + // tp_clear in the GC path). + std::atomic<bool> finalized{false}; + TlsClientPy(const std::string& ca, const std::string& cert, const std::string& key, const std::string& host, int port) : client(ca, cert, key, host, port) {}; @@ -79,27 +83,22 @@ struct PlainClientPy { std::mutex stopped_mu; std::condition_variable stopped_cv; + std::atomic<bool> finalized{false}; + PlainClientPy(const std::string& host, int port) : client(host, port) {} }; // --------------------------------------------------------------------------- -// tp_traverse / tp_clear for TlsClientPy +// tp_traverse / tp_clear / tp_finalize for TlsClientPy // --------------------------------------------------------------------------- -static int tls_traverse(PyObject* self, visitproc visit, void* arg) { - TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self)); - Py_VISIT(w->on_open_cb.ptr()); - Py_VISIT(w->on_message_cb.ptr()); - return 0; -} - -static int tls_clear(PyObject* self) { - TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self)); +// tls_do_cleanup: stop the event loop, wait for it to exit, then release all +// C++ callbacks. Safe to call multiple times (idempotent via finalized flag). +static void tls_do_cleanup(TlsClientPy* w) { + if (w->finalized.exchange(true, std::memory_order_acq_rel)) return; - // 1. Ask the event loop to stop. - // - // event_base_loopexit is thread-safe. + // 1. Ask the event loop to stop. event_base_loopexit is thread-safe. w->client.Stop(); // 2. Release the GIL and wait for Run() to return. 
@@ -128,26 +127,42 @@ static int tls_clear(PyObject* self) { w->handler_ref->ptr = nullptr; } w->handler_ref.reset(); - w->on_open_cb = nb::object(); - w->on_message_cb = nb::object(); - return 0; } -// --------------------------------------------------------------------------- -// tp_traverse / tp_clear for PlainClientPy -// --------------------------------------------------------------------------- +// tp_finalize is called for BOTH the normal refcount destruction path and the +// GC path (before tp_clear / tp_dealloc). Putting the cleanup here ensures +// it runs regardless of whether a reference cycle was involved. +static void tls_finalize(PyObject* self) { + TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self)); + tls_do_cleanup(w); + // Do NOT release on_open_cb / on_message_cb here: they are Python objects + // that tp_traverse must still be able to visit until tp_clear runs. +} -static int plain_traverse(PyObject* self, visitproc visit, void* arg) { - PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self)); +static int tls_traverse(PyObject* self, visitproc visit, void* arg) { + TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self)); Py_VISIT(w->on_open_cb.ptr()); Py_VISIT(w->on_message_cb.ptr()); return 0; } -static int plain_clear(PyObject* self) { - PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self)); +static int tls_clear(PyObject* self) { + TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self)); + // tls_do_cleanup is idempotent; if tp_finalize already ran this is a no-op. + tls_do_cleanup(w); + // Drop Python object references to break the cycle. 
+ w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + return 0; +} + +// --------------------------------------------------------------------------- +// tp_traverse / tp_clear / tp_finalize for PlainClientPy +// --------------------------------------------------------------------------- + +static void plain_do_cleanup(PlainClientPy* w) { + if (w->finalized.exchange(true, std::memory_order_acq_rel)) return; - // Same Stop-then-wait pattern as tls_clear. See comments there. w->client.Stop(); { @@ -160,12 +175,28 @@ static int plain_clear(PyObject* self) { w->client.SetOnOpen({}); w->client.SetOnClose({}); w->client.SetOnMessage({}); - // Invalidate the safe handle so Python code can no longer call through it. if (w->handler_ref) { std::lock_guard lock(w->handler_ref->mu); w->handler_ref->ptr = nullptr; } w->handler_ref.reset(); +} + +static void plain_finalize(PyObject* self) { + PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self)); + plain_do_cleanup(w); +} + +static int plain_traverse(PyObject* self, visitproc visit, void* arg) { + PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self)); + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + return 0; +} + +static int plain_clear(PyObject* self) { + PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self)); + plain_do_cleanup(w); w->on_open_cb = nb::object(); w->on_message_cb = nb::object(); return 0; @@ -189,6 +220,7 @@ NB_MODULE(wish_ext, m) { static PyType_Slot tls_slots[] = { {Py_tp_traverse, (void*)tls_traverse}, {Py_tp_clear, (void*)tls_clear}, + {Py_tp_finalize, (void*)tls_finalize}, {0, nullptr}, }; @@ -252,6 +284,7 @@ NB_MODULE(wish_ext, m) { static PyType_Slot plain_slots[] = { {Py_tp_traverse, (void*)plain_traverse}, {Py_tp_clear, (void*)plain_clear}, + {Py_tp_finalize, (void*)plain_finalize}, {0, nullptr}, };