From 56d28001462ddaeccd5f222e61b3cfbe36143798 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Thu, 23 May 2024 14:11:08 -0500 Subject: [PATCH] Update box information at each timestep --- pysages/backends/hoomd.py | 21 +++++++++++++++------ pysages/backends/openmm.py | 23 +++++++++++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/pysages/backends/hoomd.py b/pysages/backends/hoomd.py index a458d7a1..2b86eac4 100644 --- a/pysages/backends/hoomd.py +++ b/pysages/backends/hoomd.py @@ -48,8 +48,8 @@ def is_on_gpu(context): def get_integrator(context): return context.integrator - def get_run_method(context): - return hoomd.run + def get_run_method(_context): + return hoomd.run # pylint: disable=E1101 def get_system(context): return context.system @@ -109,6 +109,9 @@ def update(positions, vel_mass, rtags, images, forces, timestep): self.callback = callback self.dt = initial_snapshot.dt self._restore = restore + # NOTE: Add optimization: Reuse the initial box information when + # no constant-pressure methods are used + self._update_box = lambda: get_global_box(sysview) def restore(self, prev_snapshot): def restore_callback(positions, vel_mass, rtags, images, forces, n): @@ -134,7 +137,7 @@ def _pack_snapshot(self, positions, vel_mass, forces, rtags, images): from_dlpack(forces), from_dlpack(rtags), from_dlpack(images), - self.box, + self._update_box(), self.dt, ) @@ -161,6 +164,14 @@ def take_snapshot(sampling_context, location=default_location()): check_device_array(positions) # currently, we only support `DeviceArray`s + box = get_global_box(sysview) + dt = get_integrator(context).dt + + return Snapshot(positions, vel_mass, forces, ids, imgs, box, dt) + + +def get_global_box(sysview): + """Get the box and origin of a HOOMD-blue simulation.""" box = sysview.particle_data.getGlobalBox() L = box.getL() xy = box.getTiltFactorXY() @@ -169,9 +180,7 @@ def take_snapshot(sampling_context, location=default_location()): lo = box.getLo() H = ((L.x, xy * L.y, xz * L.z), (0.0, L.y, yz * L.z), (0.0, 0.0, L.z)) origin = (lo.x, lo.y, lo.z) - dt = get_integrator(context).dt - - return Snapshot(positions, vel_mass, forces, ids, imgs, Box(H, origin), dt) + return Box(H, origin) def build_snapshot_methods(sampling_method): diff --git a/pysages/backends/openmm.py b/pysages/backends/openmm.py index 0c5f8e44..fe623b79 100644 --- a/pysages/backends/openmm.py +++ b/pysages/backends/openmm.py @@ -30,7 +30,7 @@ class Sampler: - def __init__(self, method_bundle, bias, callback: Callable, restore): + def __init__(self, method_bundle, context, bias, callback: Callable, restore): initial_snapshot, initialize, method_update = method_bundle self.state = initialize() self.bias = bias @@ -38,8 +38,10 @@ def __init__(self, method_bundle, bias, callback: Callable, restore): self.snapshot = initial_snapshot self._restore = restore self._update = method_update + self._update_box = lambda: get_box(context) def update(self, timestep=0): + self.snapshot = self.update_snapshot() self.state = self._update(self.snapshot, self.state) self.bias(self.snapshot, self.state) if self.callback: @@ -51,6 +53,10 @@ def restore(self, prev_snapshot): def take_snapshot(self): return copy(self.snapshot) + def update_snapshot(self): + box = self._update_box() + return self.snapshot._replace(box=box) + def is_on_gpu(view: ContextView): return view.device_type() == DeviceType.GPU @@ -73,16 +79,21 @@ def take_snapshot(sampling_context): check_device_array(positions) # currently, we only support `DeviceArray`s + box = get_box(context) + dt = context.getIntegrator().getStepSize() / unit.picosecond + + # OpenMM doesn't have images + return Snapshot(positions, vel_mass, forces, ids, None, box, dt) + + +def get_box(context): box_vectors = context.getSystem().getDefaultPeriodicBoxVectors() a = box_vectors[0].value_in_unit(unit.nanometer) b = box_vectors[1].value_in_unit(unit.nanometer) c = box_vectors[2].value_in_unit(unit.nanometer) H = ((a[0], b[0], c[0]), (a[1], b[1], c[1]), (a[2], b[2], c[2])) origin = (0.0, 0.0, 0.0) - dt = context.getIntegrator().getStepSize() / unit.picosecond - - # OpenMM doesn't have images - return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt) + return Box(H, origin) def identity(x): @@ -199,6 +210,6 @@ def bind(sampling_context: SamplingContext, callback: Callable, **kwargs): snapshot = take_snapshot(sampling_context) method_bundle = sampling_method.build(snapshot, helpers) sync_and_bias = partial(bias, sync_backend=sampling_context.view.synchronize) - sampler = Sampler(method_bundle, sync_and_bias, callback, restore) + sampler = Sampler(method_bundle, context, sync_and_bias, callback, restore) force.set_callback_in(context, sampler.update) return sampler