Skip to content

Commit fc44123

Browse files
Fix reference counting in Python (#1775)
* Add failing test
* Move write part to own function (this triggers the error)
* Fix refcounting in Python bindings
* CI fixes
1 parent 779a81a commit fc44123

File tree

3 files changed

+69
-13
lines changed

3 files changed

+69
-13
lines changed

src/IO/AbstractIOHandlerImpl.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -268,7 +268,10 @@ std::future<void> AbstractIOHandlerImpl::flush()
268268
i.writable->parent,
269269
"->",
270270
i.writable,
271-
"] WRITE_DATASET");
271+
"] WRITE_DATASET, offset=",
272+
[&parameter]() { return vec_as_string(parameter.offset); },
273+
", extent=",
274+
[&parameter]() { return vec_as_string(parameter.extent); });
272275
writeDataset(i.writable, parameter);
273276
break;
274277
}

src/binding/python/RecordComponent.cpp

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
#include <pybind11/detail/common.h>
2323
#include <pybind11/numpy.h>
2424
#include <pybind11/pybind11.h>
25+
#include <pybind11/pytypes.h>
2526
#include <pybind11/stl.h>
2627

2728
#include "openPMD/Dataset.hpp"
@@ -322,13 +323,16 @@ struct StoreChunkFromPythonArray
322323
Offset const &offset,
323324
Extent const &extent)
324325
{
325-
// here, we increase a reference on the user-passed data so that
326+
a.inc_ref();
327+
void *data = a.mutable_data();
328+
// here, we store an owning handle in the lambda capture so that
326329
// temporary and lost-scope variables stay alive until we flush
327330
// note: this does not yet prevent the user, as in C++, to build
328331
// a race condition by manipulating the data that was passed
329-
a.inc_ref();
330-
void *data = a.mutable_data();
331-
std::shared_ptr<T> shared((T *)data, [a](T *) { a.dec_ref(); });
332+
std::shared_ptr<T> shared(
333+
(T *)data, [owning_handle = a.cast<py::object>()](T *) {
334+
// no-op
335+
});
332336
r.storeChunk(std::move(shared), offset, extent);
333337
}
334338

@@ -343,13 +347,15 @@ struct LoadChunkIntoPythonArray
343347
Offset const &offset,
344348
Extent const &extent)
345349
{
346-
// here, we increase a reference on the user-passed data so that
350+
void *data = a.mutable_data();
351+
// here, we store an owning handle in the lambda capture so that
347352
// temporary and lost-scope variables stay alive until we flush
348353
// note: this does not yet prevent the user, as in C++, to build
349354
// a race condition by manipulating the data that was passed
350-
a.inc_ref();
351-
void *data = a.mutable_data();
352-
std::shared_ptr<T> shared((T *)data, [a](T *) { a.dec_ref(); });
355+
std::shared_ptr<T> shared(
356+
(T *)data, [owning_handle = a.cast<py::object>()](T *) {
357+
// no-op
358+
});
353359
r.loadChunk(std::move(shared), offset, extent);
354360
}
355361

@@ -365,14 +371,15 @@ struct LoadChunkIntoPythonBuffer
365371
Offset const &offset,
366372
Extent const &extent)
367373
{
368-
// here, we increase a reference on the user-passed data so that
374+
void *data = buffer_info.ptr;
375+
// here, we store an owning handle in the lambda capture so that
369376
// temporary and lost-scope variables stay alive until we flush
370377
// note: this does not yet prevent the user, as in C++, to build
371378
// a race condition by manipulating the data that was passed
372-
buffer.inc_ref();
373-
void *data = buffer_info.ptr;
374379
std::shared_ptr<T> shared(
375-
(T *)data, [buffer](T *) { buffer.dec_ref(); });
380+
(T *)data, [owning_handle = buffer.cast<py::object>()](T *) {
381+
// no-op
382+
});
376383
r.loadChunk(std::move(shared), offset, extent);
377384
}
378385

test/python/unittest/API/APITest.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,52 @@ def tearDown(self):
8282
del self.__particle_series
8383
del self.__series
8484

85+
# This function exhibits a bug in the old use of refcounting.
86+
def refcountingCreateData(self):
87+
series = io.Series(
88+
"../samples/refcounting.json",
89+
io.Access.create_linear,
90+
)
91+
92+
for i in range(10):
93+
current_iteration = series.snapshots()[i]
94+
95+
# First, write an E mesh.
96+
E = current_iteration.meshes["E"]
97+
E.axis_labels = ["x", "y"]
98+
for dim in ["x", "y"]:
99+
component = E[dim]
100+
component.reset_dataset(
101+
io.Dataset(np.dtype("float"), [10, 10]))
102+
component[:, :] = np.reshape(
103+
np.arange(i * 100, (i + 1) * 100, dtype=np.dtype("float")),
104+
[10, 10],
105+
)
106+
107+
# Now, write some e particles.
108+
e = current_iteration.particles["e"]
109+
for dim in ["x", "y"]:
110+
# Do not bother with a positionOffset
111+
position_offset = e["positionOffset"][dim]
112+
position_offset.make_constant(0)
113+
114+
position = e["position"][dim]
115+
position.reset_dataset(io.Dataset(np.dtype("float"), [100]))
116+
position[:] = np.arange(
117+
i * 100, (i + 1) * 100, dtype=np.dtype("float")
118+
)
119+
120+
def testRefCounting(self):
121+
self.refcountingCreateData()
122+
123+
read = io.Series("../samples/refcounting.json", io.Access.read_linear)
124+
iteration = read.snapshots()[0]
125+
pos_x = iteration.particles["e"]["position"]["x"]
126+
loaded = pos_x[:]
127+
read.flush()
128+
self.assertTrue(np.allclose(
129+
loaded, np.arange(0, 100, dtype=np.dtype("float"))))
130+
85131
def testFieldData(self):
86132
""" Testing serial IO on a pure field dataset. """
87133

0 commit comments

Comments
 (0)