From 23c39844eadd4e326f3c592ef2867e2fbe9859b9 Mon Sep 17 00:00:00 2001 From: Johan Mabille Date: Wed, 25 Mar 2026 11:11:05 +0100 Subject: [PATCH] Stores comms created by create_comm in a map --- src/xcomm.cpp | 55 ++++++++++++++++++++++++++++++++++++++++----------- src/xcomm.hpp | 21 +++++++++++++++++--- 2 files changed, 62 insertions(+), 14 deletions(-) diff --git a/src/xcomm.cpp b/src/xcomm.cpp index 83be80e7..1ae915b1 100644 --- a/src/xcomm.cpp +++ b/src/xcomm.cpp @@ -48,10 +48,6 @@ namespace xpyt { } - xcomm::~xcomm() - { - } - std::string xcomm::comm_id() const { return m_comm.id(); @@ -79,7 +75,16 @@ namespace xpyt void xcomm::on_close(const python_callback_type& callback) { - m_comm.on_close(cpp_callback(callback)); + m_comm.on_close(cpp_close_callback(callback)); + } + + void xcomm::on_close_cleanup(close_callback_type callback) + { + m_close_callback = std::move(callback); + m_comm.on_close([this](const xeus::xmessage&) + { + m_close_callback(comm_id()); + }); } const xeus::xtarget* xcomm::target(const py::object& target_name) const @@ -116,6 +121,18 @@ namespace xpyt }; } + auto xcomm::cpp_close_callback(const python_callback_type& py_callback) const -> cpp_callback_type + { + return [this, py_callback](const xeus::xmessage& msg) + { + XPYT_HOLDING_GIL(py_callback(cppmessage_to_pymessage(msg))) + if (m_close_callback) + { + m_close_callback(comm_id()); + } + }; + } + void xcomm_manager::register_target(const py::str& target_name, const py::object& callback) { auto target_callback = [callback] (xeus::xcomm&& comm, const xeus::xmessage& msg) @@ -128,6 +145,20 @@ namespace xpyt ); } + void xcomm_manager::register_comm(py::object pycomm) + { + // pycomm was created with initial refcount = 0 + // Therefore we need to increment it to avoid its + // deletion by the garbage collector + pycomm.inc_ref(); + xcomm& comm = pycomm.cast(); + comm.on_close_cleanup([this](std::string id) + { + XPYT_HOLDING_GIL(m_comms.erase(id)) + }); + m_comms[comm.comm_id()] = pycomm; + } + /*************** * comm module * ***************/ @@ -150,18 +181,20 @@ namespace xpyt py::class_(comm_module, "CommManager") .def(py::init<>()) - .def("register_target", &xcomm_manager::register_target); - - comm_module.def("create_comm", [&comm_module](py::args objs, py::kwargs kw) { - return comm_module.attr("Comm")(*objs, **kw); - }); + .def("register_target", &xcomm_manager::register_target) + .def("register_comm", &xcomm_manager::register_comm); comm_module.def("get_comm_manager", [&comm_module]() { static py::object comm_manager = comm_module.attr("CommManager")(); - return comm_manager; }); + comm_module.def("create_comm", [&comm_module](py::args objs, py::kwargs kw) { + py::object comm = comm_module.attr("Comm")(*objs, **kw); + comm_module.attr("get_comm_manager")().attr("register_comm")(comm); + return comm; + }); + return comm_module; } diff --git a/src/xcomm.hpp b/src/xcomm.hpp index f6909c89..de4af03f 100644 --- a/src/xcomm.hpp +++ b/src/xcomm.hpp @@ -21,14 +21,18 @@ namespace xpyt { public: + using close_callback_type = std::function; using python_callback_type = std::function; using cpp_callback_type = std::function; using buffers_sequence = xeus::buffer_sequence; xcomm(const py::object& target_name, const py::object& data, const py::object& metadata, const py::object& buffers, const py::kwargs& kwargs); xcomm(xeus::xcomm&& comm); - xcomm(xcomm&& comm) = default; - virtual ~xcomm(); + xcomm(xcomm&& comm) = delete; + xcomm& operator=(xcomm&& rhs) = delete; + xcomm(const xcomm&) = delete; + xcomm& operator=(xcomm& rhs) = delete; + ~xcomm() = default; std::string comm_id() const; bool kernel() const; @@ -38,6 +42,8 @@ namespace xpyt void on_msg(const python_callback_type& callback); void on_close(const python_callback_type& callback); + void on_close_cleanup(close_callback_type callback); + private: // Warning: this function creates and register the target with a dummy @@ -46,15 +52,24 @@ namespace xpyt const xeus::xtarget* target(const py::object& target_name) const; xeus::xguid id(const py::kwargs& kwargs) const; cpp_callback_type cpp_callback(const python_callback_type& callback) const; + cpp_callback_type cpp_close_callback(const python_callback_type& callback) const; xeus::xcomm m_comm; + close_callback_type m_close_callback; }; - struct xcomm_manager + class xcomm_manager { + public: + xcomm_manager() = default; void register_target(const py::str& target_name, const py::object& callback); + void register_comm(py::object comm); + + private: + + std::map m_comms; }; py::module get_comm_module();