Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 58 additions & 25 deletions wish/python/src/wish_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ struct TlsClientPy {
std::mutex stopped_mu;
std::condition_variable stopped_cv;

// Guards against running cleanup twice (once from tp_finalize, once from
// tp_clear in the GC path).
std::atomic<bool> finalized{false};

TlsClientPy(const std::string& ca, const std::string& cert,
const std::string& key, const std::string& host, int port)
: client(ca, cert, key, host, port) {};
Expand All @@ -79,27 +83,22 @@ struct PlainClientPy {
std::mutex stopped_mu;
std::condition_variable stopped_cv;

std::atomic<bool> finalized{false};

PlainClientPy(const std::string& host, int port)
: client(host, port) {}
};

// ---------------------------------------------------------------------------
// tp_traverse / tp_clear for TlsClientPy
// tp_traverse / tp_clear / tp_finalize for TlsClientPy
// ---------------------------------------------------------------------------

static int tls_traverse(PyObject* self, visitproc visit, void* arg) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
}

static int tls_clear(PyObject* self) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
// tls_do_cleanup: stop the event loop, wait for it to exit, then release all
// C++ callbacks. Safe to call multiple times (idempotent via finalized flag).
static void tls_do_cleanup(TlsClientPy* w) {
if (w->finalized.exchange(true, std::memory_order_acq_rel)) return;

// 1. Ask the event loop to stop.
//
// event_base_loopexit is thread-safe.
// 1. Ask the event loop to stop. event_base_loopexit is thread-safe.
w->client.Stop();

// 2. Release the GIL and wait for Run() to return.
Expand Down Expand Up @@ -128,26 +127,42 @@ static int tls_clear(PyObject* self) {
w->handler_ref->ptr = nullptr;
}
w->handler_ref.reset();
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
}

// ---------------------------------------------------------------------------
// tp_traverse / tp_clear for PlainClientPy
// ---------------------------------------------------------------------------
// tp_finalize is called for BOTH the normal refcount destruction path and the
// GC path (before tp_clear / tp_dealloc). Putting the cleanup here ensures
// it runs regardless of whether a reference cycle was involved.
static void tls_finalize(PyObject* self) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
tls_do_cleanup(w);
// Do NOT release on_open_cb / on_message_cb here: they are Python objects
// that tp_traverse must still be able to visit until tp_clear runs.
}

static int plain_traverse(PyObject* self, visitproc visit, void* arg) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
static int tls_traverse(PyObject* self, visitproc visit, void* arg) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
}

static int plain_clear(PyObject* self) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
static int tls_clear(PyObject* self) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
// tls_do_cleanup is idempotent; if tp_finalize already ran this is a no-op.
tls_do_cleanup(w);
// Drop Python object references to break the cycle.
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
}

// ---------------------------------------------------------------------------
// tp_traverse / tp_clear / tp_finalize for PlainClientPy
// ---------------------------------------------------------------------------

static void plain_do_cleanup(PlainClientPy* w) {
if (w->finalized.exchange(true, std::memory_order_acq_rel)) return;

// Same Stop-then-wait pattern as tls_clear. See comments there.
w->client.Stop();

{
Expand All @@ -160,12 +175,28 @@ static int plain_clear(PyObject* self) {
w->client.SetOnOpen({});
w->client.SetOnClose({});
w->client.SetOnMessage({});
// Invalidate the safe handle so Python code can no longer call through it.
if (w->handler_ref) {
std::lock_guard<std::mutex> lock(w->handler_ref->mu);
w->handler_ref->ptr = nullptr;
}
w->handler_ref.reset();
}

static void plain_finalize(PyObject* self) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
plain_do_cleanup(w);
}

static int plain_traverse(PyObject* self, visitproc visit, void* arg) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
}

static int plain_clear(PyObject* self) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
plain_do_cleanup(w);
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
Expand All @@ -189,6 +220,7 @@ NB_MODULE(wish_ext, m) {
static PyType_Slot tls_slots[] = {
{Py_tp_traverse, (void*)tls_traverse},
{Py_tp_clear, (void*)tls_clear},
{Py_tp_finalize, (void*)tls_finalize},
{0, nullptr},
};

Expand Down Expand Up @@ -252,6 +284,7 @@ NB_MODULE(wish_ext, m) {
static PyType_Slot plain_slots[] = {
{Py_tp_traverse, (void*)plain_traverse},
{Py_tp_clear, (void*)plain_clear},
{Py_tp_finalize, (void*)plain_finalize},
{0, nullptr},
};

Expand Down