Skip to content

Commit a191e14

Browse files
committed
Make all models compatible with tojax
1 parent 5bf6979 commit a191e14

6 files changed

Lines changed: 325 additions & 18 deletions

File tree

orb_models/common/models/nn_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def forward(self, x: torch.Tensor, online: bool | None = None) -> torch.Tensor:
330330

331331
mu = self.bn.running_mean # type: ignore
332332
sigma = torch.sqrt(self.bn.running_var) # type: ignore
333-
if sigma < 1e-6:
333+
if self.training and sigma < 1e-6:
334334
raise ValueError("ScalarNormalizer has ~zero std.")
335335

336336
return (x - mu) / sigma # type: ignore

orb_models/common/models/segment_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def aggregate_nodes(
4040
https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility"""
4141
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
4242
torch.use_deterministic_algorithms(True)
43-
segments = torch.arange(count, device=device).repeat_interleave(n_node, output_size=tensor.shape[0])
43+
segments = torch.arange(count, device=device).repeat_interleave(
44+
n_node, output_size=tensor.shape[0]
45+
)
4446
if reduction == "sum":
4547
return scatter_sum(tensor, segments, dim=0, dim_size=count)
4648
elif reduction == "mean":

orb_models/forcefield/models/forcefield_utils.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ def maybe_remove_net_force_and_torque(
5656

5757
if remove_torque:
5858
force_pred = _selectively_remove_net_torque_for_nonpbc_systems(
59-
force_pred, batch.positions, batch.system_features["cell"], batch.n_node
59+
force_pred,
60+
batch.positions,
61+
batch.system_features["cell"],
62+
batch.n_node,
63+
batch.node_batch_index,
6064
)
6165

6266
return force_pred
@@ -67,6 +71,7 @@ def _selectively_remove_net_torque_for_nonpbc_systems(
6771
positions: torch.Tensor,
6872
cell: torch.Tensor,
6973
n_node: torch.Tensor,
74+
node_batch_index: torch.Tensor,
7075
):
7176
"""Remove net torque from non-PBC-system forces, but preserve PBC-system forces.
7277
@@ -77,20 +82,10 @@ def _selectively_remove_net_torque_for_nonpbc_systems(
7782
n_node: The number of nodes per graph, of shape (n_batch,).
7883
"""
7984
nopbc_graph = torch.all(cell == 0.0, dim=(1, 2))
80-
if torch.any(nopbc_graph):
81-
if torch.all(nopbc_graph):
82-
pred = _remove_net_torque(positions, pred, n_node)
83-
else:
84-
# Handle a mixed batch of pbc and non-pbc systems
85-
batch_indices = torch.repeat_interleave(
86-
torch.arange(cell.size(0), device=n_node.device), n_node
87-
)
88-
nopbc_atom = nopbc_graph[batch_indices]
89-
adjusted_pred_non_pbc = _remove_net_torque(
90-
positions[nopbc_atom], pred[nopbc_atom], n_node[nopbc_graph]
91-
)
92-
pred = pred.clone()
93-
pred[nopbc_atom] = adjusted_pred_non_pbc
85+
nopbc_atom = nopbc_graph[node_batch_index]
86+
87+
adjusted = _remove_net_torque(positions, pred, n_node)
88+
pred = torch.where(nopbc_atom.unsqueeze(-1), adjusted, pred)
9489

9590
return pred
9691

orb_models/forcefield/models/pair_repulsion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _polynomial_cutoff_with_derivative(self, r, r_max, p):
193193
tuple: (envelope, derivative)
194194
"""
195195
# Convert p to float for calculations
196-
p_float = float(p)
196+
p_float = p.to(r.dtype)
197197

198198
# Mask for r < r_max
199199
mask = (r < r_max).float()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dev = [
4747
"mypy>=1.13",
4848
"torch_dftd",
4949
"torch-sim-atomistic>=0.5.1",
50+
"tojax",
5051
]
5152

5253
[tool.setuptools.dynamic]
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
"""Test tojax compatibility for all pretrained orb models."""
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import numpy as np
6+
import pytest
7+
import torch
8+
from ase.build import bulk
9+
from tojax import tojax
10+
from torch.jit._state import disable as torch_jit_disable
11+
from torch.jit._state import enable as torch_jit_enable
12+
13+
import orb_models.common.models.angular as _angular_mod
14+
from orb_models.common.atoms.batch.graph_batch import AtomGraphs
15+
from orb_models.forcefield.forcefield_adapter import ForcefieldAtomsAdapter
16+
from orb_models.forcefield.pretrained import ORB_PRETRAINED_MODELS
17+
18+
19+
@pytest.fixture(autouse=True, scope="module")
20+
def _ensure_plain_spherical_harmonics():
21+
"""Replace JIT-compiled _spherical_harmonics with a plain Python function.
22+
23+
If another test module imported angular.py with JIT enabled,
24+
_spherical_harmonics is a ScriptFunction that rejects TensorWrappers.
25+
We re-evaluate the module source into a scratch namespace with JIT
26+
disabled and patch only the function — no importlib.reload, so existing
27+
classes (and their super() chains) are untouched.
28+
"""
29+
original_fn = _angular_mod._spherical_harmonics
30+
if isinstance(original_fn, torch.jit.ScriptFunction):
31+
torch_jit_disable()
32+
scratch = dict(vars(_angular_mod))
33+
with open(_angular_mod.__file__) as f:
34+
exec(compile(f.read(), _angular_mod.__file__, "exec"), scratch)
35+
_angular_mod._spherical_harmonics = scratch["_spherical_harmonics"]
36+
torch_jit_enable()
37+
yield
38+
_angular_mod._spherical_harmonics = original_fn
39+
40+
41+
DEPRECATED_MODELS = {
42+
"orb-v1",
43+
"orb-d3-v1",
44+
"orb-d3-sm-v1",
45+
"orb-d3-xs-v1",
46+
"orb-v1-mptraj-only",
47+
}
48+
49+
ACTIVE_MODELS = sorted(name for name in ORB_PRETRAINED_MODELS if name not in DEPRECATED_MODELS)
50+
CONSERVATIVE_MODELS = {name for name in ACTIVE_MODELS if "conservative" in name}
51+
52+
53+
def _make_atoms():
54+
55+
atoms = bulk("Cu", "fcc", a=3.6)
56+
atoms.info["charge"] = 0
57+
atoms.info["spin"] = 1
58+
return atoms
59+
60+
61+
# NOTE: This fails for conservative models because tojax does not support torch.autograd.grad
62+
@pytest.mark.parametrize("model_name", ACTIVE_MODELS)
63+
def test_tojax_outputs_match_pytorch(model_name):
64+
load_fn = ORB_PRETRAINED_MODELS[model_name]
65+
model, adapter = load_fn(device="cpu", compile=False)
66+
model.eval()
67+
68+
atoms = _make_atoms()
69+
torch_batch = adapter.from_ase_atoms(atoms, device="cpu")
70+
jax_batch = adapter.from_ase_atoms(atoms, device="cpu")
71+
72+
jax_model = tojax(model)
73+
74+
with torch.enable_grad():
75+
torch_out = model(torch_batch)
76+
77+
jax_out = jax_model(jax_batch)
78+
79+
# Only compare model predictions, not internal intermediates like
80+
# node_features / edge_features where float32 rounding differences
81+
# accumulate across message-passing layers and many edges.
82+
INTERNAL_KEYS = {"node_features", "edge_features"}
83+
prediction_keys = set(torch_out) - INTERNAL_KEYS
84+
missing = prediction_keys - set(jax_out)
85+
assert not missing, f"Keys missing from JAX output: {missing}"
86+
87+
for key in prediction_keys:
88+
torch_val = torch_out[key].detach().float().cpu().numpy()
89+
jax_val = np.asarray(jax_out[key], dtype=np.float32)
90+
np.testing.assert_allclose(
91+
jax_val,
92+
torch_val,
93+
atol=1e-5,
94+
rtol=1e-4,
95+
err_msg=f"Output mismatch for '{model_name}' key '{key}'",
96+
)
97+
98+
99+
def _from_kups_atoms(adapter: ForcefieldAtomsAdapter, data: dict[str, torch.Tensor]) -> AtomGraphs:
100+
"""Build AtomGraphs from a kups AtomGraphInput dict."""
101+
senders = data["edge_index"][0]
102+
receivers = data["edge_index"][1]
103+
n_systems = data["cell"].shape[0]
104+
batch = data["batch"]
105+
src_batch = batch[senders]
106+
107+
n_node = torch.zeros(n_systems, dtype=torch.int64).scatter_add_(
108+
0, batch, torch.ones_like(batch)
109+
)
110+
nedges = torch.zeros(n_systems, dtype=torch.int64).scatter_add_(
111+
0, src_batch, torch.ones_like(src_batch)
112+
)
113+
114+
atomic_numbers = data["atomic_numbers"]
115+
atomic_numbers_embedding = torch.nn.functional.one_hot(atomic_numbers.long(), 118).float()
116+
117+
system_features: dict[str, torch.Tensor] = {
118+
"cell": data["cell"],
119+
"pbc": data["pbc"],
120+
}
121+
if "charge" in data:
122+
system_features["total_charge"] = data["charge"]
123+
if "spin" in data:
124+
system_features["spin_multiplicity"] = data["spin"]
125+
126+
cells_per_edge = data["cell"][src_batch]
127+
shifts = torch.bmm(data["cell_offsets"].unsqueeze(1), cells_per_edge).squeeze(1)
128+
vectors = data["pos"][receivers] - data["pos"][senders] + shifts
129+
130+
assert adapter.radius is not None, "Adapter radius must be set"
131+
return AtomGraphs(
132+
senders=senders,
133+
receivers=receivers,
134+
n_node=n_node,
135+
n_edge=nedges,
136+
node_features={
137+
"positions": data["pos"],
138+
"atomic_numbers": atomic_numbers,
139+
"atomic_numbers_embedding": atomic_numbers_embedding,
140+
"atom_identity": torch.arange(data["pos"].shape[0], dtype=torch.int64),
141+
},
142+
system_features=system_features,
143+
edge_features={"unit_shifts": data["cell_offsets"], "vectors": vectors},
144+
node_targets={},
145+
edge_targets={},
146+
system_targets={},
147+
system_id=None,
148+
fix_atoms=None,
149+
tags=None,
150+
radius=adapter.radius,
151+
max_num_neighbors=nedges,
152+
half_supercell=False,
153+
)
154+
155+
156+
def _to_kups_atoms(atomgraph: AtomGraphs) -> dict[str, torch.Tensor]:
157+
"""Convert to the flat AtomGraphInput dict used by kups/tojax."""
158+
n_atoms = int(atomgraph.n_node.sum())
159+
n_systems = atomgraph.n_node.shape[0]
160+
device = atomgraph.n_node.device
161+
return {
162+
"pos": atomgraph.node_features["positions"],
163+
"atomic_numbers": atomgraph.node_features["atomic_numbers"],
164+
"cell": atomgraph.system_features["cell"],
165+
"pbc": atomgraph.system_features["pbc"],
166+
"edge_index": torch.stack([atomgraph.senders, atomgraph.receivers]),
167+
"cell_offsets": atomgraph.edge_features["unit_shifts"],
168+
"batch": torch.arange(n_systems, device=device).repeat_interleave(
169+
atomgraph.n_node, output_size=n_atoms
170+
),
171+
"charge": atomgraph.system_features.get(
172+
"total_charge", torch.zeros(n_systems, dtype=torch.long, device=device)
173+
).view(-1),
174+
"spin": atomgraph.system_features.get(
175+
"spin_multiplicity", torch.zeros(n_systems, dtype=torch.long, device=device)
176+
).view(-1),
177+
}
178+
179+
180+
def _make_predict_fn(adapter, model):
181+
"""Build a tojax-traceable predict function: kups AtomGraphInput -> {energy, forces, stress}."""
182+
183+
def predict_fn(data):
184+
graph = _from_kups_atoms(adapter, data)
185+
result = model.predict(graph, split=False)
186+
out = {"energy": result["energy"]}
187+
if "forces" in result:
188+
out["forces"] = result["forces"]
189+
if "stress" in result:
190+
out["stress"] = result["stress"]
191+
return out
192+
193+
return predict_fn
194+
195+
196+
@pytest.mark.parametrize("model_name", ACTIVE_MODELS)
197+
def test_kups_compatibility(model_name):
198+
"""Test energy, forces, and stress using the kups AtomGraphInput format.
199+
200+
Since tojax cannot translate torch.autograd.grad, this test follows the
201+
pattern from the tojax export example (export_orb.py) where model + graph
202+
construction are packaged into a single function that accepts a flat
203+
AtomGraphInput dict — the same format kups uses.
204+
205+
Unlike test_tojax_outputs_match_pytorch which passes the model directly
206+
to tojax(model), this test tojax-traces a function that builds an
207+
AtomGraphs from raw tensors and calls model.predict(). For conservative
208+
models, forces and stress are computed via jax.grad of the energy
209+
(replacing torch.autograd.grad with JAX's own autodiff).
210+
"""
211+
212+
load_fn = ORB_PRETRAINED_MODELS[model_name]
213+
model, adapter = load_fn(device="cpu", compile=False)
214+
model.eval()
215+
is_conservative = model_name in CONSERVATIVE_MODELS
216+
217+
atoms = _make_atoms()
218+
batch = adapter.from_ase_atoms(atoms, device="cpu")
219+
data = _to_kups_atoms(batch)
220+
221+
# PyTorch reference
222+
ref_batch = adapter.from_ase_atoms(atoms, device="cpu")
223+
with torch.enable_grad():
224+
torch_out = model.predict(ref_batch)
225+
226+
jax_data = tojax(data)
227+
228+
if is_conservative:
229+
# Energy + forces/stress via jax.grad (torch.autograd.grad can't be translated)
230+
predict_fn = _make_predict_fn(adapter, model)
231+
jax_energy_fn = tojax(lambda data: predict_fn(data)["energy"])
232+
batch_indices = jax_data["batch"]
233+
234+
def _energy_with_strain(pos, strain):
235+
sym_strain = 0.5 * (strain + jnp.swapaxes(strain, -1, -2))
236+
deformed_cell = jax_data["cell"] + jnp.einsum(
237+
"bij,bjk->bik", jax_data["cell"], sym_strain
238+
)
239+
deformed_pos = pos + jnp.einsum("ni,nij->nj", pos, sym_strain[batch_indices])
240+
return jax_energy_fn({**jax_data, "pos": deformed_pos, "cell": deformed_cell}).sum()
241+
242+
energy_and_grads = jax.value_and_grad(_energy_with_strain, argnums=(0, 1))
243+
jax_energy, (neg_forces, virial) = energy_and_grads(
244+
jax_data["pos"], jnp.zeros_like(jax_data["cell"])
245+
)
246+
247+
np.testing.assert_allclose(
248+
np.asarray(jax_energy),
249+
torch_out[model.energy_name].detach().float().numpy(),
250+
atol=1e-5,
251+
rtol=1e-4,
252+
err_msg=f"energy mismatch for '{model_name}'",
253+
)
254+
np.testing.assert_allclose(
255+
np.asarray(-neg_forces, dtype=np.float32),
256+
torch_out[model.grad_forces_name].detach().float().numpy(),
257+
atol=1e-4,
258+
rtol=1e-4,
259+
err_msg=f"force mismatch for '{model_name}'",
260+
)
261+
if model.has_stress:
262+
volume = jnp.abs(jnp.linalg.det(jax_data["cell"]))
263+
jax_stress_3x3 = np.asarray(virial / volume[:, None, None])
264+
jax_stress = np.stack(
265+
[
266+
jax_stress_3x3[..., 0, 0],
267+
jax_stress_3x3[..., 1, 1],
268+
jax_stress_3x3[..., 2, 2],
269+
(jax_stress_3x3[..., 1, 2] + jax_stress_3x3[..., 2, 1]) / 2,
270+
(jax_stress_3x3[..., 0, 2] + jax_stress_3x3[..., 2, 0]) / 2,
271+
(jax_stress_3x3[..., 0, 1] + jax_stress_3x3[..., 1, 0]) / 2,
272+
],
273+
axis=-1,
274+
).astype(np.float32)
275+
np.testing.assert_allclose(
276+
jax_stress.reshape(torch_out[model.grad_stress_name].shape),
277+
torch_out[model.grad_stress_name].detach().float().numpy(),
278+
atol=1e-4,
279+
rtol=1e-4,
280+
err_msg=f"stress mismatch for '{model_name}'",
281+
)
282+
else:
283+
# Direct models: energy, forces, stress from tojax'd predict
284+
jax_predict_fn = tojax(_make_predict_fn(adapter, model))
285+
jax_out = jax_predict_fn(jax_data)
286+
287+
np.testing.assert_allclose(
288+
np.asarray(jax_out["energy"]),
289+
torch_out["energy"].detach().float().numpy(),
290+
atol=1e-5,
291+
rtol=1e-4,
292+
err_msg=f"energy mismatch for '{model_name}'",
293+
)
294+
if "forces" in jax_out:
295+
np.testing.assert_allclose(
296+
np.asarray(jax_out["forces"], dtype=np.float32),
297+
torch_out["forces"].detach().float().numpy(),
298+
atol=1e-4,
299+
rtol=1e-4,
300+
err_msg=f"force mismatch for '{model_name}'",
301+
)
302+
if "stress" in jax_out:
303+
np.testing.assert_allclose(
304+
np.asarray(jax_out["stress"], dtype=np.float32),
305+
torch_out["stress"].detach().float().numpy(),
306+
atol=1e-4,
307+
rtol=1e-4,
308+
err_msg=f"stress mismatch for '{model_name}'",
309+
)

0 commit comments

Comments
 (0)