Skip to content

Commit 3971486

Browse files
committed
Stores comms created by create_comm in a map
1 parent db7770f commit 3971486

2 files changed

Lines changed: 74 additions & 11 deletions

File tree

src/xcomm.cpp

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,19 @@ namespace xpyt
4848
{
4949
}
5050

51-
xcomm::~xcomm()
51+
xcomm::xcomm(xcomm&& comm)
52+
: m_comm(std::move(comm.m_comm))
5253
{
54+
on_close_cleanup(std::move(comm.m_close_callback));
55+
comm.m_close_callback = close_callback_type{};
56+
}
57+
58+
xcomm& xcomm::operator=(xcomm&& rhs)
59+
{
60+
m_comm = std::move(rhs.m_comm);
61+
on_close_cleanup(std::move(rhs.m_close_callback));
62+
rhs.m_close_callback = close_callback_type{};
63+
return *this;
5364
}
5465

5566
std::string xcomm::comm_id() const
@@ -79,7 +90,16 @@ namespace xpyt
7990

8091
void xcomm::on_close(const python_callback_type& callback)
8192
{
82-
m_comm.on_close(cpp_callback(callback));
93+
m_comm.on_close(cpp_close_callback(callback));
94+
}
95+
96+
void xcomm::on_close_cleanup(close_callback_type callback)
97+
{
98+
m_close_callback = std::move(callback);
99+
m_comm.on_close([this](const xeus::xmessage&)
100+
{
101+
m_close_callback(comm_id());
102+
});
83103
}
84104

85105
const xeus::xtarget* xcomm::target(const py::object& target_name) const
@@ -116,6 +136,18 @@ namespace xpyt
116136
};
117137
}
118138

139+
auto xcomm::cpp_close_callback(const python_callback_type& py_callback) const -> cpp_callback_type
140+
{
141+
return [this, py_callback](const xeus::xmessage& msg)
142+
{
143+
XPYT_HOLDING_GIL(py_callback(cppmessage_to_pymessage(msg)))
144+
if (m_close_callback)
145+
{
146+
m_close_callback(comm_id());
147+
}
148+
};
149+
}
150+
119151
void xcomm_manager::register_target(const py::str& target_name, const py::object& callback)
120152
{
121153
auto target_callback = [callback] (xeus::xcomm&& comm, const xeus::xmessage& msg)
@@ -128,6 +160,20 @@ namespace xpyt
128160
);
129161
}
130162

163+
void xcomm_manager::register_comm(py::object pycomm)
164+
{
165+
// pycomm was created with initial refcount = 0
166+
// Therefore we need to increment it to avoid its
167+
// deletion by the garbage collector
168+
pycomm.inc_ref();
169+
xcomm& comm = pycomm.cast<xcomm&>();
170+
comm.on_close_cleanup([this](std::string id)
171+
{
172+
XPYT_HOLDING_GIL(m_comms.erase(id))
173+
});
174+
m_comms[comm.comm_id()] = pycomm;
175+
}
176+
131177
/***************
132178
* comm module *
133179
***************/
@@ -150,18 +196,20 @@ namespace xpyt
150196

151197
py::class_<xcomm_manager>(comm_module, "CommManager")
152198
.def(py::init<>())
153-
.def("register_target", &xcomm_manager::register_target);
154-
155-
comm_module.def("create_comm", [&comm_module](py::args objs, py::kwargs kw) {
156-
return comm_module.attr("Comm")(*objs, **kw);
157-
});
199+
.def("register_target", &xcomm_manager::register_target)
200+
.def("register_comm", &xcomm_manager::register_comm);
158201

159202
comm_module.def("get_comm_manager", [&comm_module]() {
160203
static py::object comm_manager = comm_module.attr("CommManager")();
161-
162204
return comm_manager;
163205
});
164206

207+
comm_module.def("create_comm", [&comm_module](py::args objs, py::kwargs kw) {
208+
py::object comm = comm_module.attr("Comm")(*objs, **kw);
209+
comm_module.attr("get_comm_manager")().attr("register_comm")(comm);
210+
return comm;
211+
});
212+
165213
return comm_module;
166214
}
167215

src/xcomm.hpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,18 @@ namespace xpyt
2121
{
2222
public:
2323

24+
using close_callback_type = std::function<void(std::string)>;
2425
using python_callback_type = std::function<void(py::object)>;
2526
using cpp_callback_type = std::function<void(const xeus::xmessage&)>;
2627
using buffers_sequence = xeus::buffer_sequence;
2728

2829
xcomm(const py::object& target_name, const py::object& data, const py::object& metadata, const py::object& buffers, const py::kwargs& kwargs);
2930
xcomm(xeus::xcomm&& comm);
30-
xcomm(xcomm&& comm) = default;
31-
virtual ~xcomm();
31+
xcomm(xcomm&& comm);
32+
xcomm& operator=(xcomm&& rhs);
33+
xcomm(const xcomm&) = delete;
34+
xcomm& operator=(xcomm& rhs) = delete;
35+
~xcomm() = default;
3236

3337
std::string comm_id() const;
3438
bool kernel() const;
@@ -38,6 +42,8 @@ namespace xpyt
3842
void on_msg(const python_callback_type& callback);
3943
void on_close(const python_callback_type& callback);
4044

45+
void on_close_cleanup(close_callback_type callback);
46+
4147
private:
4248

4349
// Warning: this function creates and register the target with a dummy
@@ -46,15 +52,24 @@ namespace xpyt
4652
const xeus::xtarget* target(const py::object& target_name) const;
4753
xeus::xguid id(const py::kwargs& kwargs) const;
4854
cpp_callback_type cpp_callback(const python_callback_type& callback) const;
55+
cpp_callback_type cpp_close_callback(const python_callback_type& callback) const;
4956

5057
xeus::xcomm m_comm;
58+
close_callback_type m_close_callback;
5159
};
5260

53-
struct xcomm_manager
61+
class xcomm_manager
5462
{
63+
public:
64+
5565
xcomm_manager() = default;
5666

5767
void register_target(const py::str& target_name, const py::object& callback);
68+
void register_comm(py::object comm);
69+
70+
private:
71+
72+
std::map<std::string, py::object> m_comms;
5873
};
5974

6075
py::module get_comm_module();

0 commit comments

Comments
 (0)