Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions wish/cpp/src/plain_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,9 @@ void PlainClient::Run() {

event_base_dispatch(base_);
}

void PlainClient::Stop() {
if (base_) {
event_base_loopexit(base_, nullptr);
}
}
1 change: 1 addition & 0 deletions wish/cpp/src/plain_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class PlainClient {
void SetOnOpen(OpenCallback cb);
void SetOnMessage(MessageCallback cb);
void Run();
void Stop();

private:
std::string host_;
Expand Down
6 changes: 6 additions & 0 deletions wish/cpp/src/tls_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,9 @@ void TlsClient::Run() {

event_base_dispatch(base_);
}

void TlsClient::Stop() {
if (base_) {
event_base_loopexit(base_, nullptr);
}
}
1 change: 1 addition & 0 deletions wish/cpp/src/tls_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TlsClient {
void SetOnOpen(OpenCallback cb);
void SetOnMessage(MessageCallback cb);
void Run();
void Stop();

private:
std::string ca_file_;
Expand Down
212 changes: 105 additions & 107 deletions wish/python/src/wish_ext.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include <event2/thread.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/function.h>
#include <event2/thread.h>
#include <nanobind/stl/string.h>

#include "wish_handler.h"
#include "tls_client.h"
#include "plain_client.h"
#include "tls_client.h"
#include "wish_handler.h"

namespace nb = nanobind;

Expand All @@ -21,141 +21,139 @@ namespace nb = nanobind;
// ---------------------------------------------------------------------------

struct TlsClientPy {
TlsClient client;
nb::object on_open_cb;
nb::object on_message_cb;
TlsClient client;
nb::object on_open_cb;
nb::object on_message_cb;

TlsClientPy(const std::string& ca, const std::string& cert,
const std::string& key, const std::string& host, int port)
: client(ca, cert, key, host, port) {}
TlsClientPy(const std::string& ca, const std::string& cert,
const std::string& key, const std::string& host, int port)
: client(ca, cert, key, host, port) {}
};

struct PlainClientPy {
PlainClient client;
nb::object on_open_cb;
nb::object on_message_cb;
PlainClient client;
nb::object on_open_cb;
nb::object on_message_cb;

PlainClientPy(const std::string& host, int port)
: client(host, port) {}
PlainClientPy(const std::string& host, int port)
: client(host, port) {}
};

// ---------------------------------------------------------------------------
// tp_traverse / tp_clear for TlsClientPy
// ---------------------------------------------------------------------------

static int tls_traverse(PyObject* self, visitproc visit, void* arg) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
}

static int tls_clear(PyObject* self) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
// Clear the C++ callbacks first so the lambda (which captures &*w) is
// dropped before we invalidate on_open_cb / on_message_cb.
w->client.SetOnOpen({});
w->client.SetOnMessage({});
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
// Clear the C++ callbacks first so the lambda (which captures &*w) is
// dropped before we invalidate on_open_cb / on_message_cb.
w->client.SetOnOpen({});
w->client.SetOnMessage({});
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
}

// ---------------------------------------------------------------------------
// tp_traverse / tp_clear for PlainClientPy
// ---------------------------------------------------------------------------

static int plain_traverse(PyObject* self, visitproc visit, void* arg) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
Py_VISIT(w->on_open_cb.ptr());
Py_VISIT(w->on_message_cb.ptr());
return 0;
}

static int plain_clear(PyObject* self) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
w->client.SetOnOpen({});
w->client.SetOnMessage({});
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));
w->client.SetOnOpen({});
w->client.SetOnMessage({});
w->on_open_cb = nb::object();
w->on_message_cb = nb::object();
return 0;
}

// ---------------------------------------------------------------------------

NB_MODULE(wish_ext, m) {
// Enable libevent thread-safety
// Enable libevent thread-safety
#ifdef _WIN32
evthread_use_windows_threads();
evthread_use_windows_threads();
#else
evthread_use_pthreads();
evthread_use_pthreads();
#endif

nb::class_<WishHandler>(m, "WishHandler")
.def("send_text", &WishHandler::SendText)
.def("send_binary", &WishHandler::SendBinary);

// ---- TlsClient --------------------------------------------------------
static PyType_Slot tls_slots[] = {
{Py_tp_traverse, (void*)tls_traverse},
{Py_tp_clear, (void*)tls_clear},
{0, nullptr},
};

nb::class_<TlsClientPy>(m, "TlsClient", nb::type_slots(tls_slots))
.def(nb::init<const std::string&, const std::string&,
const std::string&, const std::string&, int>())
.def("init", [](TlsClientPy& self) {
return self.client.Init();
})
.def("set_on_open", [](TlsClientPy& self, nb::object cb) {
self.on_open_cb = cb;
// Capture self by pointer; lifetime is safe because the lambda
// lives inside self.client and is always cleared before self
// is destroyed (either by tp_clear or ~TlsClient).
self.client.SetOnOpen([&self](WishHandler* handler) {
nb::gil_scoped_acquire acquire;
self.on_open_cb(nb::cast(handler, nb::rv_policy::reference));
});
})
.def("set_on_message", [](TlsClientPy& self, nb::object cb) {
self.on_message_cb = cb;
self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) {
nb::gil_scoped_acquire acquire;
self.on_message_cb(opcode, msg);
});
})
.def("run", [](TlsClientPy& self) { self.client.Run(); },
nb::call_guard<nb::gil_scoped_release>())
.def("stop", [](TlsClientPy& self) { self.client.Stop(); });

// ---- PlainClient ------------------------------------------------------
static PyType_Slot plain_slots[] = {
{Py_tp_traverse, (void*)plain_traverse},
{Py_tp_clear, (void*)plain_clear},
{0, nullptr},
};

nb::class_<PlainClientPy>(m, "PlainClient", nb::type_slots(plain_slots))
.def(nb::init<const std::string&, int>())
.def("init", [](PlainClientPy& self) {
return self.client.Init();
})
.def("set_on_open", [](PlainClientPy& self, nb::object cb) {
self.on_open_cb = cb;
self.client.SetOnOpen([&self](WishHandler* handler) {
nb::gil_scoped_acquire acquire;
self.on_open_cb(nb::cast(handler, nb::rv_policy::reference));
});
})
.def("set_on_message", [](PlainClientPy& self, nb::object cb) {
self.on_message_cb = cb;
self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) {
nb::gil_scoped_acquire acquire;
self.on_message_cb(opcode, msg);
});
})
.def("run", [](PlainClientPy& self) { self.client.Run(); },
nb::call_guard<nb::gil_scoped_release>())
.def("stop", [](PlainClientPy& self) { self.client.Stop(); });
nb::class_<WishHandler>(m, "WishHandler")
.def("send_text", &WishHandler::SendText)
.def("send_binary", &WishHandler::SendBinary);

// ---- TlsClient --------------------------------------------------------
static PyType_Slot tls_slots[] = {
{Py_tp_traverse, (void*)tls_traverse},
{Py_tp_clear, (void*)tls_clear},
{0, nullptr},
};

nb::class_<TlsClientPy>(m, "TlsClient", nb::type_slots(tls_slots))
.def(nb::init<const std::string&, const std::string&,
const std::string&, const std::string&, int>())
.def("init", [](TlsClientPy& self) {
return self.client.Init();
})
.def("set_on_open", [](TlsClientPy& self, nb::object cb) {
self.on_open_cb = cb;
// Capture self by pointer; lifetime is safe because the lambda
// lives inside self.client and is always cleared before self
// is destroyed (either by tp_clear or ~TlsClient).
self.client.SetOnOpen([&self](WishHandler* handler) {
nb::gil_scoped_acquire acquire;
self.on_open_cb(nb::cast(handler, nb::rv_policy::reference));
});
})
.def("set_on_message", [](TlsClientPy& self, nb::object cb) {
self.on_message_cb = cb;
self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) {
nb::gil_scoped_acquire acquire;
self.on_message_cb(opcode, msg);
});
})
.def("run", [](TlsClientPy& self) { self.client.Run(); }, nb::call_guard<nb::gil_scoped_release>())
.def("stop", [](TlsClientPy& self) { self.client.Stop(); });

// ---- PlainClient ------------------------------------------------------
static PyType_Slot plain_slots[] = {
{Py_tp_traverse, (void*)plain_traverse},
{Py_tp_clear, (void*)plain_clear},
{0, nullptr},
};

nb::class_<PlainClientPy>(m, "PlainClient", nb::type_slots(plain_slots))
.def(nb::init<const std::string&, int>())
.def("init", [](PlainClientPy& self) {
return self.client.Init();
})
.def("set_on_open", [](PlainClientPy& self, nb::object cb) {
self.on_open_cb = cb;
self.client.SetOnOpen([&self](WishHandler* handler) {
nb::gil_scoped_acquire acquire;
self.on_open_cb(nb::cast(handler, nb::rv_policy::reference));
});
})
.def("set_on_message", [](PlainClientPy& self, nb::object cb) {
self.on_message_cb = cb;
self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) {
nb::gil_scoped_acquire acquire;
self.on_message_cb(opcode, msg);
});
})
.def("run", [](PlainClientPy& self) { self.client.Run(); }, nb::call_guard<nb::gil_scoped_release>())
.def("stop", [](PlainClientPy& self) { self.client.Stop(); });
}
16 changes: 11 additions & 5 deletions wish/python/wish/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, host, port, tls, ca_file="", cert_file="", key_file=""):
self._loop = asyncio.get_running_loop()
self._recv_queue = asyncio.Queue()
self._open_future = self._loop.create_future()
self._run_future = None
self._handler = None

def on_open(handler):
Expand All @@ -32,11 +33,19 @@ def on_message(opcode, msg):

async def connect(self):
# Run the C++ event loop in a background thread.
self._loop.run_in_executor(None, self._client.run)
# Keep the Future so we can await thread completion on close.
self._run_future = self._loop.run_in_executor(None, self._client.run)
# Wait until the on_open callback fires
await self._open_future
return self

async def close(self):
"""Stop the C++ event loop and wait for the background thread to exit."""
self._client.stop()
if self._run_future is not None:
await self._run_future
self._run_future = None

async def send(self, data):
"""Sends data over the WiSH connection. If data is bytes, sends as binary, else text."""
if not self._handler:
Expand Down Expand Up @@ -76,10 +85,7 @@ async def __aenter__(self):
return self.conn

async def __aexit__(self, exc_type, exc_val, exc_tb):
# We don't have explicit close method on TlsClient yet,
# but the connection would drop if the process exits,
# or we could add a shutdown mechanism.
pass
await self.conn.close()

def connect(uri, ca_file="", cert_file="", key_file=""):
return _ConnectContextManager(uri, ca_file, cert_file, key_file)