Skip to content

Commit 47e4aa4

Browse files
committed
Add support for Qbox as backend
1 parent 40632b4 commit 47e4aa4

4 files changed

Lines changed: 241 additions & 5 deletions

File tree

pysages/backends/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: MIT
22
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
33

4-
from .contexts import JaxMDContext, JaxMDContextState # noqa: E402, F401
4+
from .contexts import ( # noqa: E402, F401
5+
JaxMDContext,
6+
JaxMDContextState,
7+
QboxContextGenerator,
8+
)
59
from .core import SamplingContext, supported_backends # noqa: E402, F401

pysages/backends/contexts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,12 @@ def finalize(qb):
159159
initial_state += state + b"\n</fpmd:simulation>"
160160
super().__setattr__("initial_state", et.fromstring(initial_state))
161161

162-
def process_input(self, entries: Union[str, Iterable[str]], target=r"\[qbox\] "):
162+
def process_input(self, entries: Union[str, Iterable[str]], target=r"\[qbox\] ", timeout=None):
163163
qb = self.instance
164164
state = b""
165165
for entry in splitlines(entries):
166166
qb.sendline(entry)
167-
qb.expect(target)
167+
qb.expect(target, timeout=timeout)
168168
state += qb.before
169169
# We add tags to ensure that the state corresponds to a valid xml section
170170
super().__setattr__("state", et.fromstring(b"<root>\n" + state + b"\n</root>"))

pysages/backends/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from importlib import import_module
55

6-
from pysages.backends.contexts import JaxMDContext
6+
from pysages.backends.contexts import JaxMDContext, QboxContextGenerator
77
from pysages.typing import Callable, Optional
88

99

@@ -38,6 +38,8 @@ def __init__(
3838
self._backend_name = "lammps"
3939
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
4040
self._backend_name = "openmm"
41+
elif isinstance(context, QboxContextGenerator):
42+
self._backend_name = "qbox"
4143

4244
if self._backend_name is None:
4345
backends = ", ".join(supported_backends())
@@ -74,4 +76,4 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
7476

7577

7678
def supported_backends():
77-
return ("ase", "hoomd", "jax-md", "lammps", "openmm")
79+
return ("ase", "hoomd", "jax-md", "lammps", "openmm", "qbox")

pysages/backends/qbox.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# SPDX-License-Identifier: MIT
2+
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
3+
4+
"""
5+
This module defines the Sampler class, which is a LAMMPS fix that enables any PySAGES
6+
SamplingMethod to be hooked to a LAMMPS simulation instance.
7+
"""
8+
9+
from jax import jit
10+
from jax import numpy as np
11+
12+
from pysages.backends.core import SamplingContext
13+
from pysages.backends.snapshot import (
14+
Box,
15+
HelperMethods,
16+
Snapshot,
17+
SnapshotMethods,
18+
build_data_querier,
19+
)
20+
from pysages.typing import Callable, Optional
21+
from pysages.utils import PatternMatcher, contains, identity, last, parse_array
22+
23+
24+
class Sampler:
25+
"""
26+
Allows performing enhanced sampling simulations with Qbox as a backend.
27+
28+
Parameters
29+
----------
30+
31+
context: QboxContext
32+
Contains a running instance of a Qbox simulation to which the PySAGES sampling
33+
machinery will be hooked.
34+
35+
sampling_method: SamplingMethod
36+
The sampling method to be used.
37+
38+
callbacks: Optional[Callback]
39+
Some methods define callbacks for logging, but it can also be user-defined.
40+
"""
41+
42+
def __init__(self, context, sampling_method, callback: Optional[Callable]):
43+
self.context = context
44+
self.callback = callback
45+
46+
self.snapshot = self.take_snapshot()
47+
helpers, bias = build_helpers(context, sampling_method)
48+
_, initialize, method_update = sampling_method.build(self.snapshot, helpers)
49+
50+
# Initialize external forces for each atom
51+
for name in atom_names(context):
52+
# Initialize with zero force
53+
cmd = f"extforce define atomic {name} {name} 0.0 0.0 0.0"
54+
context.process_input(cmd)
55+
56+
self.state = initialize()
57+
self._update_box = lambda: self.snapshot.box
58+
self._method_update = method_update
59+
self._bias = bias
60+
61+
def _pack_snapshot(self, masses, ids, box, dt):
62+
"""Returns the dynamic properties of the system."""
63+
positions = atom_property(self.context, "position")
64+
velocities = atom_property(self.context, "velocity")
65+
forces = atom_property(self.context, "force")
66+
return Snapshot(positions, (velocities, masses), forces, ids, None, box, dt)
67+
68+
def _update_snapshot(self):
69+
"""Updates the snapshot with the latest properties from Qbox."""
70+
snapshot = self.snapshot
71+
_, masses = snapshot.vel_mass
72+
return self._pack_snapshot(masses, snapshot.ids, self._update_box(), snapshot.dt)
73+
74+
def restore(self, prev_snapshot):
75+
"""Replaces this sampler's snapshot with `prev_snapshot`."""
76+
context = self.context
77+
names = atom_names(context)
78+
positions = prev_snapshot.positions
79+
velocities, _ = prev_snapshot.vel_mass
80+
81+
for name, x, v in zip(names, positions, velocities):
82+
cmd = f"move {name} to {x[0]} {x[1]} {x[2]} {v[0]} {v[1]} {v[2]}"
83+
context.process_input(cmd)
84+
85+
# Recompute ground-state energies and forces.
86+
# NOTE: Check in the future how to use Qbox `load` and `save` commands to also
87+
# include the electronic wave function data.
88+
context.process_input(f"run 0 {context.niter} {context.nitscf}")
89+
self.snapshot = self._update_snapshot()
90+
91+
def take_snapshot(self):
92+
"""Returns a copy of the current snapshot of the system."""
93+
masses = atom_masses(self.context)
94+
ids = np.arange(len(masses))
95+
snapshot_box = Box(*box(self.context))
96+
dt = timestep(self.context)
97+
return self._pack_snapshot(masses, ids, snapshot_box, dt)
98+
99+
def update(self, timestep):
100+
"""Update the sampling method state and apply bias."""
101+
self.snapshot = self._update_snapshot()
102+
self.state = self._method_update(self.snapshot, self.state)
103+
self._bias(self.snapshot, self.state)
104+
if self.callback:
105+
self.callback(self.snapshot, self.state, timestep)
106+
107+
def run(self, nsteps: int):
108+
"""Run the Qbox simulation for nsteps."""
109+
cmd = f"run 1 {self.context.niter} {self.context.nitscf}"
110+
for step in range(nsteps):
111+
# Send run command to Qbox for a single step
112+
self.context.process_input(cmd)
113+
# Update sampling method state after each step
114+
self.update(step)
115+
116+
117+
def build_snapshot_methods(sampling_method):
118+
"""
119+
Builds methods for retrieving snapshot properties in a format useful for collective
120+
variable calculations.
121+
"""
122+
123+
def positions(snapshot):
124+
return snapshot.positions
125+
126+
def indices(snapshot):
127+
return snapshot.ids
128+
129+
def momenta(snapshot):
130+
V, M = snapshot.vel_mass
131+
return (M * V).flatten()
132+
133+
def masses(snapshot):
134+
_, M = snapshot.vel_mass
135+
return M
136+
137+
return SnapshotMethods(positions, indices, jit(momenta), masses)
138+
139+
140+
def build_helpers(context, sampling_method):
141+
"""
142+
Builds helper methods used for restoring snapshots and biasing a simulation.
143+
"""
144+
# Precompute atom names since they won't change
145+
names = atom_names(context)
146+
147+
def extforce_cmd(name, force):
148+
return f"extforce set {name} {force[0]} {force[1]} {force[2]}"
149+
150+
def bias(snapshot, state):
151+
"""Adds the computed bias to the forces using Qbox's extforce command."""
152+
if state.bias is None:
153+
return
154+
# Generate and send all extforce commands
155+
context.process_input(extforce_cmd(name, force) for name, force in zip(names, state.bias))
156+
157+
snapshot_methods = build_snapshot_methods(sampling_method)
158+
flags = sampling_method.snapshot_flags
159+
helpers = HelperMethods(build_data_querier(snapshot_methods, flags), lambda: 3)
160+
161+
return helpers, bias
162+
163+
164+
species_name = PatternMatcher(r"atom\s+[^\s]+\s+([^\s]+)")
165+
atom_name = PatternMatcher(r"atom\s+([^\s]+)")
166+
167+
168+
def is_cmd(name, elem):
169+
return contains(elem.text, r"^\s*" + name)
170+
171+
172+
def species_masses(context):
173+
species = context.initial_state.iter("species")
174+
return {elem.attrib["name"]: float(elem.find("mass").text) for elem in species}
175+
176+
177+
def static_atom_property(context, prop: Callable):
178+
cmds = context.initial_state.iter("cmd")
179+
return [prop(elem.text) for elem in cmds if is_cmd("atom", elem)]
180+
181+
182+
def atom_property(context, prop: str):
183+
atomset = last(context.state.iter("atomset"))
184+
if atomset is None:
185+
context.process_input("run 0 0 0")
186+
atomset = last(context.state.iter("atomset"))
187+
return parse_array(" ".join(elem.text for elem in atomset.iter(prop)))
188+
189+
190+
def atom_masses(context):
191+
masses_mapping = species_masses(context)
192+
atom_mass = identity(lambda s: masses_mapping[species_name(s)])
193+
return np.array(static_atom_property(context, atom_mass)).reshape(-1, 1)
194+
195+
196+
def atom_names(context):
197+
return static_atom_property(context, atom_name)
198+
199+
200+
def box(context):
201+
elem = last(context.state.iter("unit_cell"))
202+
if elem is None:
203+
context.process_input("print cell")
204+
elem = context.state.find("unit_cell")
205+
cell_vecs = " ".join(elem.attrib.values())
206+
H = parse_array(cell_vecs, transpose=True)
207+
origin = np.array([0.0, 0.0, 0.0])
208+
return Box(H, origin)
209+
210+
211+
def timestep(context):
212+
context.process_input("print dt")
213+
elem = context.state.find("cmd")
214+
return float(elem.tail.strip("\ndt= "))
215+
216+
217+
def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwargs):
218+
"""
219+
Sets up and returns a Sampler which enables performing enhanced sampling simulations.
220+
221+
This function takes a `sampling_context` that has its context attribute as an instance
222+
of a `QboxContext,` and creates a `Sampler` object that connects the PySAGES
223+
sampling method to the Qbox simulation. It also modifies the `sampling_context`'s
224+
`view` and `run` attributes to call the Qbox `run` command.
225+
"""
226+
context = sampling_context.context
227+
sampler = Sampler(context, sampling_context.method, callback)
228+
sampling_context.run = sampler.run
229+
230+
return sampler

0 commit comments

Comments
 (0)