Skip to content

Commit 1747e16

Browse files
committed
Remove SamplingContext.view and helper View class
1 parent 24cb0f4 commit 1747e16

6 files changed

Lines changed: 161 additions & 166 deletions

File tree

pysages/backends/ase.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
build_data_querier,
1717
)
1818
from pysages.backends.utils import view
19-
from pysages.typing import Callable, NamedTuple
19+
from pysages.typing import Callable
2020
from pysages.utils import ToCPU, copy
2121

2222

@@ -29,7 +29,7 @@ class Sampler(Calculator):
2929
"""
3030

3131
def __init__(self, context, method_bundle, callback: Callable):
32-
initial_snapshot, initialize, mehod_update = method_bundle
32+
initial_snapshot, initialize, method_update = method_bundle
3333

3434
atoms = context.atoms
3535
self.implemented_properties = atoms.calc.implemented_properties
@@ -41,7 +41,7 @@ def __init__(self, context, method_bundle, callback: Callable):
4141
self.callback = callback
4242
self.snapshot = initial_snapshot
4343
self.state = initialize()
44-
self.update = mehod_update
44+
self.update = method_update
4545

4646
sig = signature(atoms.calc.calculate).parameters
4747
self._calculator = atoms.calc
@@ -151,10 +151,6 @@ def dimensionality():
151151
return helpers
152152

153153

154-
class View(NamedTuple):
155-
synchronize: Callable
156-
157-
158154
def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
159155
"""
160156
Entry point for the backend code, it gets called when the simulation
@@ -166,6 +162,5 @@ def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
166162
helpers = build_helpers(sampling_context, sampling_method)
167163
method_bundle = sampling_method.build(snapshot, helpers)
168164
sampler = Sampler(context, method_bundle, callback)
169-
sampling_context.view = View((lambda: None))
170165
sampling_context.run = context.run
171166
return sampler

pysages/backends/core.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,12 @@ def __init__(
4545

4646
self.context = context
4747
self.method = sampling_method
48-
self.view = None
4948
self.run = None
5049

5150
backend = import_module("." + self._backend_name, package="pysages.backends")
5251
self.sampler = backend.bind(self, callback, **kwargs)
5352

54-
# `self.view` and `self.run` *must* be set by the backend bind function.
55-
assert self.view is not None
53+
# `self.run` *must* be set by the backend bind function.
5654
assert self.run is not None
5755

5856
@property

pysages/backends/hoomd.py

Lines changed: 99 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,12 @@
77

88
import hoomd
99
from hoomd import md
10-
from hoomd.dlext import (
11-
AccessLocation,
12-
AccessMode,
13-
DLExtSampler,
14-
SystemView,
15-
images,
16-
net_forces,
17-
positions_types,
18-
rtags,
19-
velocities_masses,
20-
)
10+
from hoomd.dlext import AccessLocation, AccessMode, DLExtSampler, SystemView
2111
from jax import jit
2212
from jax import numpy as np
2313
from jax.dlpack import from_dlpack
2414

15+
from pysages.backends import snapshot as pbs
2516
from pysages.backends.core import SamplingContext
2617
from pysages.backends.snapshot import (
2718
Box,
@@ -30,17 +21,17 @@
3021
SnapshotMethods,
3122
build_data_querier,
3223
)
33-
from pysages.backends.snapshot import restore as _restore
34-
from pysages.typing import Callable
35-
from pysages.utils import check_device_array, copy
24+
from pysages.typing import Callable, Optional
25+
from pysages.utils import copy
26+
27+
SamplerBase = DLExtSampler
3628

3729
# TODO: Figure out a way to automatically tie the lifetime of Sampler
3830
# objects to the contexts they bind to
3931
CONTEXTS_SAMPLERS = {}
4032

4133

4234
if getattr(hoomd, "__version__", "").startswith("2."):
43-
SamplerBase = DLExtSampler
4435

4536
def is_on_gpu(context):
4637
return context.on_gpu()
@@ -61,10 +52,7 @@ def remove_half_step_hook(context):
6152
context.integrator.cpp_integrator.removeHalfStepHook()
6253

6354
else:
64-
if hasattr(hoomd.dlext, "__version__"):
65-
SamplerBase = DLExtSampler
66-
67-
else:
55+
if not hasattr(hoomd.dlext, "__version__"):
6856

6957
class SamplerBase(DLExtSampler, md.HalfStepHook):
7058
def __init__(self, sysview, update, location, mode):
@@ -91,41 +79,71 @@ def remove_half_step_hook(context):
9179
context.operations.integrator.half_step_hook = None
9280

9381

82+
if hasattr(AccessLocation, "OnDevice"):
83+
84+
def default_location():
85+
return AccessLocation.OnDevice
86+
87+
else:
88+
89+
def default_location():
90+
return AccessLocation.OnHost
91+
92+
9493
class Sampler(SamplerBase):
95-
def __init__(self, sysview, method_bundle, bias, callback: Callable, restore):
96-
initial_snapshot, initialize, method_update = method_bundle
94+
"""
95+
HOOMD-blue HalfStepHook that connects PySAGES sampling methods to HOOMD-blue simulations.
9796
98-
def update(positions, vel_mass, rtags, images, forces, timestep):
99-
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
100-
self.state = method_update(snapshot, self.state)
101-
self.bias(snapshot, self.state)
102-
if self.callback:
103-
self.callback(snapshot, self.state, timestep)
97+
Parameters
98+
----------
10499
105-
super().__init__(sysview, update, default_location(), AccessMode.Read)
106-
self.state = initialize()
107-
self.bias = bias
108-
self.box = initial_snapshot.box
100+
context: hoomd.Simulation
101+
The HOOMD-blue simulation instance to which the PySAGES sampling machinery will be hooked.
102+
103+
sampling_method: pysages.methods.SamplingMethod
104+
The sampling method used.
105+
106+
callbacks: Optional[Callback]
107+
A callback. Some methods define one for logging, but it can also be user-defined.
108+
109+
location: hoomd.dlext.AccessLocation
110+
Device where the simulation data will be retrieved.
111+
"""
112+
113+
def __init__(
114+
self, context, sampling_method, callback: Optional[Callable], location=default_location()
115+
):
116+
self.context = context
109117
self.callback = callback
110-
self.dt = initial_snapshot.dt
111-
self._restore = restore
118+
self.location = location
119+
self.view = SystemView(get_system(context))
112120

113-
def restore(self, prev_snapshot):
114-
def restore_callback(positions, vel_mass, rtags, images, forces, n):
115-
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
116-
self._restore(snapshot, prev_snapshot)
121+
self.box = Box(*get_global_box(self.view))
122+
self.dt = get_timestep(self.context)
123+
self.update_box = lambda: self.box # NOTE: extend for NPT simulations
117124

118-
self.forward_data(restore_callback, default_location(), AccessMode.Overwrite, 0)
125+
super().__init__(self.view, self._update_callback, location, AccessMode.Read)
119126

120-
def take_snapshot(self):
121-
container = []
127+
# Create initial snapshot and setup sampling method
128+
snapshot = self.take_snapshot() # sets `self.snapshot`
129+
helpers, restore, bias = build_helpers(context, sampling_method)
130+
_, initialize, method_update = sampling_method.build(snapshot, helpers)
122131

123-
def snapshot_callback(positions, vel_mass, rtags, images, forces, n):
124-
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
125-
container.append(copy(snapshot))
132+
# Initialize state and store method references
133+
self.state = initialize()
134+
self._restore = restore
135+
self._method_update = method_update
136+
self._bias = bias
137+
138+
def restore(self, prev_snapshot):
139+
"""Restore the simulation state from a previous snapshot."""
140+
self.snapshot = prev_snapshot
141+
self.forward_data(self._restore_callback, self.location, AccessMode.Overwrite, 0)
126142

127-
self.forward_data(snapshot_callback, default_location(), AccessMode.Read, 0)
128-
return container[0]
143+
def take_snapshot(self):
144+
"""Take a snapshot of the current simulation state."""
145+
self.forward_data(self._snapshot_callback, self.location, AccessMode.Read, 0)
146+
return self.snapshot
129147

130148
def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
131149
return Snapshot(
@@ -134,44 +152,25 @@ def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
134152
from_dlpack(forces),
135153
from_dlpack(rtags),
136154
from_dlpack(images),
137-
self.box,
155+
self.update_box(),
138156
self.dt,
139157
)
140158

159+
# NOTE: The order of the callbacks arguments do not match that of the `Snapshot` attributes
160+
def _restore_callback(self, positions, vel_mass, rtags, images, forces, _):
161+
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
162+
self._restore(snapshot, self.snapshot)
141163

142-
if hasattr(AccessLocation, "OnDevice"):
164+
def _snapshot_callback(self, positions, vel_mass, rtags, images, forces, _):
165+
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
166+
self.snapshot = copy(snapshot)
143167

144-
def default_location():
145-
return AccessLocation.OnDevice
146-
147-
else:
148-
149-
def default_location():
150-
return AccessLocation.OnHost
151-
152-
153-
def take_snapshot(sampling_context, location=default_location()):
154-
context = sampling_context.context
155-
sysview = sampling_context.view
156-
positions = copy(from_dlpack(positions_types(sysview, location, AccessMode.Read)))
157-
vel_mass = copy(from_dlpack(velocities_masses(sysview, location, AccessMode.Read)))
158-
forces = copy(from_dlpack(net_forces(sysview, location, AccessMode.ReadWrite)))
159-
ids = copy(from_dlpack(rtags(sysview, location, AccessMode.Read)))
160-
imgs = copy(from_dlpack(images(sysview, location, AccessMode.Read)))
161-
162-
check_device_array(positions) # currently, we only support `DeviceArray`s
163-
164-
box = sysview.particle_data.getGlobalBox()
165-
L = box.getL()
166-
xy = box.getTiltFactorXY()
167-
xz = box.getTiltFactorXZ()
168-
yz = box.getTiltFactorYZ()
169-
lo = box.getLo()
170-
H = ((L.x, xy * L.y, xz * L.z), (0.0, L.y, yz * L.z), (0.0, 0.0, L.z))
171-
origin = (lo.x, lo.y, lo.z)
172-
dt = get_integrator(context).dt
173-
174-
return Snapshot(positions, vel_mass, forces, ids, imgs, Box(H, origin), dt)
168+
def _update_callback(self, positions, vel_mass, rtags, images, forces, timestep):
169+
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
170+
self.state = self._method_update(snapshot, self.state)
171+
self._bias(snapshot, self.state, self.view.synchronize)
172+
if self.callback:
173+
self.callback(snapshot, self.state, timestep)
175174

176175

177176
def build_snapshot_methods(sampling_method):
@@ -235,30 +234,40 @@ def dimensionality():
235234

236235
snapshot_methods = build_snapshot_methods(sampling_method)
237236
flags = sampling_method.snapshot_flags
238-
restore = partial(_restore, view)
237+
restore = partial(pbs.restore, view)
239238
helpers = HelperMethods(build_data_querier(snapshot_methods, flags), dimensionality)
240239

241240
return helpers, restore, bias
242241

243242

243+
def get_global_box(system_view):
244+
"""Get the box and origin of a HOOMD-blue simulation."""
245+
box = system_view.particle_data.getGlobalBox()
246+
L = box.getL()
247+
xy = box.getTiltFactorXY()
248+
xz = box.getTiltFactorXZ()
249+
yz = box.getTiltFactorYZ()
250+
lo = box.getLo()
251+
H = ((L.x, xy * L.y, xz * L.z), (0.0, L.y, yz * L.z), (0.0, 0.0, L.z))
252+
origin = (lo.x, lo.y, lo.z)
253+
return H, origin
254+
255+
256+
def get_timestep(context):
257+
"""Get the timestep of a HOOMD-blue simulation."""
258+
return get_integrator(context).dt
259+
260+
244261
def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
245262
context = sampling_context.context
246263
sampling_method = sampling_context.method
247-
sysview = SystemView(get_system(context))
248-
sampling_context.view = sysview
249-
sampling_context.run = get_run_method(context)
250-
helpers, restore, bias = build_helpers(context, sampling_method)
251264

252-
with sysview:
253-
snapshot = take_snapshot(sampling_context)
254-
255-
method_bundle = sampling_method.build(snapshot, helpers)
256-
sync_and_bias = partial(bias, sync_backend=sysview.synchronize)
257-
sampler = Sampler(sysview, method_bundle, sync_and_bias, callback, restore)
265+
sampler = Sampler(context, sampling_method, callback)
258266
set_half_step_hook(context, sampler)
259-
260267
CONTEXTS_SAMPLERS[context] = sampler
261268

269+
sampling_context.run = get_run_method(context)
270+
262271
return sampler
263272

264273

pysages/backends/jax-md.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
SnapshotMethods,
1414
build_data_querier,
1515
)
16-
from pysages.typing import Callable, NamedTuple
16+
from pysages.typing import Callable
1717
from pysages.utils import check_device_array, copy
1818

1919

@@ -117,10 +117,6 @@ def run(timesteps):
117117
return run
118118

119119

120-
class View(NamedTuple):
121-
synchronize: Callable
122-
123-
124120
def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
125121
context = sampling_context.context
126122
sampling_method = sampling_context.method
@@ -129,7 +125,6 @@ def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
129125
helpers = build_helpers(context, sampling_method)
130126
method_bundle = sampling_method.build(snapshot, helpers)
131127
sampler = Sampler(method_bundle, context_state, callback)
132-
sampling_context.view = View((lambda: None))
133128
sampling_context.run = build_runner(
134129
context, sampler, jit_compile=kwargs.get("jit_compile", True)
135130
)

0 commit comments

Comments
 (0)