77
88import hoomd
99from hoomd import md
10- from hoomd .dlext import (
11- AccessLocation ,
12- AccessMode ,
13- DLExtSampler ,
14- SystemView ,
15- images ,
16- net_forces ,
17- positions_types ,
18- rtags ,
19- velocities_masses ,
20- )
10+ from hoomd .dlext import AccessLocation , AccessMode , DLExtSampler , SystemView
2111from jax import jit
2212from jax import numpy as np
2313from jax .dlpack import from_dlpack
2414
15+ from pysages .backends import snapshot as pbs
2516from pysages .backends .core import SamplingContext
2617from pysages .backends .snapshot import (
2718 Box ,
3021 SnapshotMethods ,
3122 build_data_querier ,
3223)
33- from pysages .backends .snapshot import restore as _restore
34- from pysages .typing import Callable
35- from pysages .utils import check_device_array , copy
24+ from pysages .typing import Callable , Optional
25+ from pysages .utils import copy
26+
27+ SamplerBase = DLExtSampler
3628
3729# TODO: Figure out a way to automatically tie the lifetime of Sampler
3830# objects to the contexts they bind to
3931CONTEXTS_SAMPLERS = {}
4032
4133
4234if getattr (hoomd , "__version__" , "" ).startswith ("2." ):
43- SamplerBase = DLExtSampler
4435
4536 def is_on_gpu (context ):
4637 return context .on_gpu ()
@@ -61,10 +52,7 @@ def remove_half_step_hook(context):
6152 context .integrator .cpp_integrator .removeHalfStepHook ()
6253
6354else :
64- if hasattr (hoomd .dlext , "__version__" ):
65- SamplerBase = DLExtSampler
66-
67- else :
55+ if not hasattr (hoomd .dlext , "__version__" ):
6856
6957 class SamplerBase (DLExtSampler , md .HalfStepHook ):
7058 def __init__ (self , sysview , update , location , mode ):
@@ -91,41 +79,71 @@ def remove_half_step_hook(context):
9179 context .operations .integrator .half_step_hook = None
9280
9381
82+ if hasattr (AccessLocation , "OnDevice" ):
83+
84+ def default_location ():
85+ return AccessLocation .OnDevice
86+
87+ else :
88+
89+ def default_location ():
90+ return AccessLocation .OnHost
91+
92+
9493class Sampler (SamplerBase ):
95- def __init__ ( self , sysview , method_bundle , bias , callback : Callable , restore ):
96- initial_snapshot , initialize , method_update = method_bundle
94+ """
95+ HOOMD-blue HalfStepHook that connects PySAGES sampling methods to HOOMD-blue simulations.
9796
98- def update (positions , vel_mass , rtags , images , forces , timestep ):
99- snapshot = self ._pack_snapshot (positions , vel_mass , forces , rtags , images )
100- self .state = method_update (snapshot , self .state )
101- self .bias (snapshot , self .state )
102- if self .callback :
103- self .callback (snapshot , self .state , timestep )
97+ Parameters
98+ ----------
10499
105- super ().__init__ (sysview , update , default_location (), AccessMode .Read )
106- self .state = initialize ()
107- self .bias = bias
108- self .box = initial_snapshot .box
100+ context: hoomd.Simulation
101+ The HOOMD-blue simulation instance to which the PySAGES sampling machinery will be hooked.
102+
103+ sampling_method: pysages.methods.SamplingMethod
104+ The sampling method used.
105+
106+ callbacks: Optional[Callback]
107+ A callback. Some methods define one for logging, but it can also be user-defined.
108+
109+ location: hoomd.dlext.AccessLocation
110+ Device where the simulation data will be retrieved.
111+ """
112+
113+ def __init__ (
114+ self , context , sampling_method , callback : Optional [Callable ], location = default_location ()
115+ ):
116+ self .context = context
109117 self .callback = callback
110- self .dt = initial_snapshot . dt
111- self ._restore = restore
118+ self .location = location
119+ self .view = SystemView ( get_system ( context ))
112120
113- def restore (self , prev_snapshot ):
114- def restore_callback (positions , vel_mass , rtags , images , forces , n ):
115- snapshot = self ._pack_snapshot (positions , vel_mass , forces , rtags , images )
116- self ._restore (snapshot , prev_snapshot )
121+ self .box = Box (* get_global_box (self .view ))
122+ self .dt = get_timestep (self .context )
123+ self .update_box = lambda : self .box # NOTE: extend for NPT simulations
117124
118- self . forward_data ( restore_callback , default_location (), AccessMode . Overwrite , 0 )
125+ super (). __init__ ( self . view , self . _update_callback , location , AccessMode . Read )
119126
120- def take_snapshot (self ):
121- container = []
127+ # Create initial snapshot and setup sampling method
128+ snapshot = self .take_snapshot () # sets `self.snapshot`
129+ helpers , restore , bias = build_helpers (context , sampling_method )
130+ _ , initialize , method_update = sampling_method .build (snapshot , helpers )
122131
123- def snapshot_callback (positions , vel_mass , rtags , images , forces , n ):
124- snapshot = self ._pack_snapshot (positions , vel_mass , forces , rtags , images )
125- container .append (copy (snapshot ))
132+ # Initialize state and store method references
133+ self .state = initialize ()
134+ self ._restore = restore
135+ self ._method_update = method_update
136+ self ._bias = bias
137+
138+ def restore (self , prev_snapshot ):
139+ """Restore the simulation state from a previous snapshot."""
140+ self .snapshot = prev_snapshot
141+ self .forward_data (self ._restore_callback , self .location , AccessMode .Overwrite , 0 )
126142
127- self .forward_data (snapshot_callback , default_location (), AccessMode .Read , 0 )
128- return container [0 ]
143+ def take_snapshot (self ):
144+ """Take a snapshot of the current simulation state."""
145+ self .forward_data (self ._snapshot_callback , self .location , AccessMode .Read , 0 )
146+ return self .snapshot
129147
130148 def _pack_snapshot (self , positions , vel_mass , forces , rtags , images ):
131149 return Snapshot (
@@ -134,44 +152,25 @@ def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
134152 from_dlpack (forces ),
135153 from_dlpack (rtags ),
136154 from_dlpack (images ),
137- self .box ,
155+ self .update_box () ,
138156 self .dt ,
139157 )
140158
159+ # NOTE: The order of the callbacks arguments do not match that of the `Snapshot` attributes
160+ def _restore_callback (self , positions , vel_mass , rtags , images , forces , _ ):
161+ snapshot = self ._pack_snapshot (positions , vel_mass , forces , rtags , images )
162+ self ._restore (snapshot , self .snapshot )
141163
142- if hasattr (AccessLocation , "OnDevice" ):
164+ def _snapshot_callback (self , positions , vel_mass , rtags , images , forces , _ ):
165+ snapshot = self ._pack_snapshot (positions , vel_mass , forces , rtags , images )
166+ self .snapshot = copy (snapshot )
143167
144- def default_location ():
145- return AccessLocation .OnDevice
146-
147- else :
148-
149- def default_location ():
150- return AccessLocation .OnHost
151-
152-
153- def take_snapshot (sampling_context , location = default_location ()):
154- context = sampling_context .context
155- sysview = sampling_context .view
156- positions = copy (from_dlpack (positions_types (sysview , location , AccessMode .Read )))
157- vel_mass = copy (from_dlpack (velocities_masses (sysview , location , AccessMode .Read )))
158- forces = copy (from_dlpack (net_forces (sysview , location , AccessMode .ReadWrite )))
159- ids = copy (from_dlpack (rtags (sysview , location , AccessMode .Read )))
160- imgs = copy (from_dlpack (images (sysview , location , AccessMode .Read )))
161-
162- check_device_array (positions ) # currently, we only support `DeviceArray`s
163-
164- box = sysview .particle_data .getGlobalBox ()
165- L = box .getL ()
166- xy = box .getTiltFactorXY ()
167- xz = box .getTiltFactorXZ ()
168- yz = box .getTiltFactorYZ ()
169- lo = box .getLo ()
170- H = ((L .x , xy * L .y , xz * L .z ), (0.0 , L .y , yz * L .z ), (0.0 , 0.0 , L .z ))
171- origin = (lo .x , lo .y , lo .z )
172- dt = get_integrator (context ).dt
173-
174- return Snapshot (positions , vel_mass , forces , ids , imgs , Box (H , origin ), dt )
168+ def _update_callback (self , positions , vel_mass , rtags , images , forces , timestep ):
169+ snapshot = self ._pack_snapshot (positions , vel_mass , forces , rtags , images )
170+ self .state = self ._method_update (snapshot , self .state )
171+ self ._bias (snapshot , self .state , self .view .synchronize )
172+ if self .callback :
173+ self .callback (snapshot , self .state , timestep )
175174
176175
177176def build_snapshot_methods (sampling_method ):
@@ -235,30 +234,40 @@ def dimensionality():
235234
236235 snapshot_methods = build_snapshot_methods (sampling_method )
237236 flags = sampling_method .snapshot_flags
238- restore = partial (_restore , view )
237+ restore = partial (pbs . restore , view )
239238 helpers = HelperMethods (build_data_querier (snapshot_methods , flags ), dimensionality )
240239
241240 return helpers , restore , bias
242241
243242
243+ def get_global_box (system_view ):
244+ """Get the box and origin of a HOOMD-blue simulation."""
245+ box = system_view .particle_data .getGlobalBox ()
246+ L = box .getL ()
247+ xy = box .getTiltFactorXY ()
248+ xz = box .getTiltFactorXZ ()
249+ yz = box .getTiltFactorYZ ()
250+ lo = box .getLo ()
251+ H = ((L .x , xy * L .y , xz * L .z ), (0.0 , L .y , yz * L .z ), (0.0 , 0.0 , L .z ))
252+ origin = (lo .x , lo .y , lo .z )
253+ return H , origin
254+
255+
256+ def get_timestep (context ):
257+ """Get the timestep of a HOOMD-blue simulation."""
258+ return get_integrator (context ).dt
259+
260+
244261def bind (sampling_context : SamplingContext , callback : Callable , ** kwargs ):
245262 context = sampling_context .context
246263 sampling_method = sampling_context .method
247- sysview = SystemView (get_system (context ))
248- sampling_context .view = sysview
249- sampling_context .run = get_run_method (context )
250- helpers , restore , bias = build_helpers (context , sampling_method )
251264
252- with sysview :
253- snapshot = take_snapshot (sampling_context )
254-
255- method_bundle = sampling_method .build (snapshot , helpers )
256- sync_and_bias = partial (bias , sync_backend = sysview .synchronize )
257- sampler = Sampler (sysview , method_bundle , sync_and_bias , callback , restore )
265+ sampler = Sampler (context , sampling_method , callback )
258266 set_half_step_hook (context , sampler )
259-
260267 CONTEXTS_SAMPLERS [context ] = sampler
261268
269+ sampling_context .run = get_run_method (context )
270+
262271 return sampler
263272
264273
0 commit comments