Skip to content

Commit a307fd2

Browse files
Python: Fix keep_alive specifications, add tests for keep_alive (#1851)
* Fix keepalives * Add simple GC test * Add keepalive for snapshots api * Add more extensive keepalive test * Slightly API-breaking.. need to del everything * tmp, check sth * tmp check ci * Revert "tmp check ci" This reverts commit 1cf6973. * Revert "tmp, check sth" This reverts commit 93ed467. * Fix typing issues with load/store_chunk in Python
1 parent 2552e63 commit a307fd2

File tree

4 files changed

+188
-6
lines changed

4 files changed

+188
-6
lines changed

src/binding/python/Iteration.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,15 @@ void init_Iteration(py::module &m)
101101
py::return_value_policy::copy,
102102
// garbage collection: return value must be freed before
103103
// Iteration
104-
py::keep_alive<1, 0>()))
104+
py::keep_alive<0, 1>()))
105105
.def_property_readonly(
106106
"particles",
107107
py::cpp_function(
108108
[](Iteration &i) { return i.particles; },
109109
py::return_value_policy::copy,
110110
// garbage collection: return value must be freed before
111111
// Iteration
112-
py::keep_alive<1, 0>()));
112+
py::keep_alive<0, 1>()));
113113

114114
add_pickle(
115115
cl, [](openPMD::Series series, std::vector<std::string> const &group) {

src/binding/python/ParticleSpecies.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void init_ParticleSpecies(py::module &m)
5454
[](ParticleSpecies &ps) { return ps.particlePatches; },
5555
py::return_value_policy::copy,
5656
// garbage collection: return value must be freed before Series
57-
py::keep_alive<1, 0>()));
57+
py::keep_alive<0, 1>()));
5858
add_pickle(
5959
cl, [](openPMD::Series series, std::vector<std::string> const &group) {
6060
uint64_t const n_it = std::stoull(group.at(1));

src/binding/python/Series.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ not possible once it has been closed.
301301
throw std::runtime_error("Unreachable");
302302
},
303303
// copy + keepalive
304-
py::return_value_policy::copy)
304+
py::return_value_policy::copy,
305+
py::keep_alive<0, 1>())
305306
.def(
306307
"current_iteration",
307308
[](Snapshots &s) -> std::optional<IndexedIteration> {
@@ -315,6 +316,7 @@ not possible once it has been closed.
315316
return std::nullopt;
316317
}
317318
},
319+
py::keep_alive<0, 1>(),
318320
"Return the iteration that is currently being written to, if "
319321
"it "
320322
"exists.");
@@ -503,7 +505,7 @@ this method.
503505
[](Series &s) { return s.iterations; },
504506
py::return_value_policy::copy,
505507
// garbage collection: return value must be freed before Series
506-
py::keep_alive<1, 0>()))
508+
py::keep_alive<0, 1>()))
507509
.def(
508510
"read_iterations",
509511
[](Series &s) {
@@ -645,5 +647,18 @@ users to overwrite default options, while keeping any other ones.
645647
py::arg("comm"),
646648
docs_merge_json)
647649
#endif
648-
;
650+
.def("__del__", [](Series &s) {
651+
try
652+
{
653+
s.close();
654+
}
655+
catch (std::exception const &e)
656+
{
657+
std::cerr << "Error during close: " << e.what() << std::endl;
658+
}
659+
catch (...)
660+
{
661+
std::cerr << "Unknown error during close." << std::endl;
662+
}
663+
});
649664
}

test/python/unittest/API/APITest.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,6 +2026,9 @@ def makeAvailableChunksRoundTrip(self, ext):
20262026
# Cleaner: write.close()
20272027
# But let's keep this instance to test that that workflow stays
20282028
# functional.
2029+
# Need to delete everything as garbage collection will keep `write`
2030+
# alive as long as E_x is around.
2031+
del E_x
20292032
del write
20302033

20312034
read = io.Series(
@@ -2340,6 +2343,170 @@ def testScalarHdf5Fields(self):
23402343
self.assertEqual(loaded_from_scalar, np.array([45]))
23412344
series_read_again.close()
23422345

2346+
def testKeepaliveComponentExtraction(self):
2347+
"""Test that keepalive specifications
2348+
guard root objects from garbage collection."""
2349+
self.testKeepaliveMeshComponent()
2350+
self.testKeepaliveParticlePosition()
2351+
self.testKeepaliveParticlePatches()
2352+
2353+
def testKeepaliveMeshComponent(self):
2354+
"""Test keepalive for mesh component extraction."""
2355+
for ext in tested_file_extensions:
2356+
self.backend_keepalive_mesh_component(ext)
2357+
2358+
def testKeepaliveParticlePosition(self):
2359+
"""Test keepalive for particle position component extraction."""
2360+
for ext in tested_file_extensions:
2361+
self.backend_keepalive_particle_position(ext)
2362+
2363+
def testKeepaliveParticlePatches(self):
2364+
"""Test keepalive for particle patches component extraction."""
2365+
for ext in tested_file_extensions:
2366+
self.backend_keepalive_particle_patches(ext)
2367+
2368+
def backend_keepalive_mesh_component(self, file_ending):
2369+
"""Helper function that tests keepalive
2370+
for mesh component extraction."""
2371+
import gc
2372+
2373+
filename = "unittest_py_keepalive_mesh." + file_ending
2374+
path = filename
2375+
2376+
def get_component_only():
2377+
series = io.Series(path, io.Access.create_linear)
2378+
backend = series.backend
2379+
iteration = series.snapshots()[0]
2380+
mesh = iteration.meshes["E"]
2381+
component = mesh["x"]
2382+
2383+
mesh.axis_labels = ["x", "y"]
2384+
component.reset_dataset(io.Dataset(np.dtype("float"), [10, 10]))
2385+
2386+
del iteration
2387+
del mesh
2388+
del series
2389+
gc.collect()
2390+
2391+
return component, backend
2392+
2393+
component, backend = get_component_only()
2394+
gc.collect()
2395+
2396+
component[:, :] = np.reshape(
2397+
np.arange(100, dtype=np.dtype("float")),
2398+
[10, 10]
2399+
)
2400+
2401+
component.series_flush()
2402+
if backend == "ADIOS2":
2403+
del component
2404+
gc.collect()
2405+
2406+
read = io.Series(path, io.Access.read_only)
2407+
loaded = read.snapshots()[0].meshes["E"]["x"][:]
2408+
read.flush()
2409+
np.testing.assert_array_equal(
2410+
loaded,
2411+
np.reshape(np.arange(100, dtype=np.dtype("float")), [10, 10])
2412+
)
2413+
2414+
def backend_keepalive_particle_position(self, file_ending):
2415+
"""Helper function that tests keepalive
2416+
for particle position component extraction."""
2417+
import gc
2418+
2419+
filename = "unittest_py_keepalive_particle." + file_ending
2420+
path = filename
2421+
num_particles = 100
2422+
2423+
def get_component_only():
2424+
series = io.Series(path, io.Access.create_linear)
2425+
backend = series.backend
2426+
iteration = series.snapshots()[0]
2427+
particles = iteration.particles["electrons"]
2428+
position = particles["position"]["x"]
2429+
2430+
position.reset_dataset(
2431+
io.Dataset(np.dtype("float"), [num_particles]))
2432+
2433+
del iteration
2434+
del particles
2435+
del series
2436+
gc.collect()
2437+
2438+
return position, backend
2439+
2440+
position, backend = get_component_only()
2441+
gc.collect()
2442+
2443+
position[:] = np.arange(num_particles, dtype=np.dtype("float"))
2444+
2445+
position.series_flush()
2446+
if backend == "ADIOS2":
2447+
del position
2448+
gc.collect()
2449+
2450+
read = io.Series(path, io.Access.read_only)
2451+
loaded = read.snapshots()[0] \
2452+
.particles["electrons"]["position"]["x"][:]
2453+
read.flush()
2454+
np.testing.assert_array_equal(
2455+
loaded,
2456+
np.arange(num_particles, dtype=np.dtype("float"))
2457+
)
2458+
2459+
def backend_keepalive_particle_patches(self, file_ending):
2460+
"""Helper function that tests keepalive
2461+
for particle patches extraction."""
2462+
import gc
2463+
2464+
filename = "unittest_py_keepalive_patches." + file_ending
2465+
path = filename
2466+
2467+
def get_component_only():
2468+
series = io.Series(path, io.Access.create_linear)
2469+
backend = series.backend
2470+
iteration = series.snapshots()[0]
2471+
particles = iteration.particles["electrons"]
2472+
2473+
dset = io.Dataset(np.dtype(np.float32), [30])
2474+
position_x = particles["position"]["x"]
2475+
position_x.reset_dataset(dset)
2476+
position_x[:] = np.arange(30, dtype=np.float32)
2477+
2478+
dset = io.Dataset(np.dtype("uint64"), [2])
2479+
num_particles_comp = particles.particle_patches["numParticles"]
2480+
num_particles_comp.reset_dataset(dset)
2481+
num_particles_comp.store(0, np.uint64(10))
2482+
num_particles_comp.store(1, np.uint64(20))
2483+
2484+
del iteration
2485+
del particles
2486+
del series
2487+
gc.collect()
2488+
2489+
return num_particles_comp, backend
2490+
2491+
component, backend = get_component_only()
2492+
gc.collect()
2493+
2494+
component.store(0, np.uint64(50))
2495+
2496+
component.series_flush()
2497+
if backend == "ADIOS2":
2498+
del component
2499+
gc.collect()
2500+
2501+
read = io.Series(path, io.Access.read_only)
2502+
loaded = read.snapshots()[0] \
2503+
.particles["electrons"].particle_patches["numParticles"].load()
2504+
read.flush()
2505+
np.testing.assert_array_equal(
2506+
loaded[0],
2507+
np.uint64(50)
2508+
)
2509+
23432510

23442511
if __name__ == '__main__':
23452512
unittest.main()

0 commit comments

Comments
 (0)