|
| 1 | +"""Test tojax compatibility for all pretrained orb models.""" |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import numpy as np |
| 6 | +import pytest |
| 7 | +import torch |
| 8 | +from ase.build import bulk |
| 9 | +from tojax import tojax |
| 10 | +from torch.jit._state import disable as torch_jit_disable |
| 11 | +from torch.jit._state import enable as torch_jit_enable |
| 12 | + |
| 13 | +import orb_models.common.models.angular as _angular_mod |
| 14 | +from orb_models.common.atoms.batch.graph_batch import AtomGraphs |
| 15 | +from orb_models.forcefield.forcefield_adapter import ForcefieldAtomsAdapter |
| 16 | +from orb_models.forcefield.pretrained import ORB_PRETRAINED_MODELS |
| 17 | + |
| 18 | + |
@pytest.fixture(autouse=True, scope="module")
def _ensure_plain_spherical_harmonics():
    """Replace JIT-compiled _spherical_harmonics with a plain Python function.

    If another test module imported angular.py with JIT enabled,
    _spherical_harmonics is a ScriptFunction that rejects TensorWrappers.
    We re-evaluate the module source into a scratch namespace with JIT
    disabled and patch only the function — no importlib.reload, so existing
    classes (and their super() chains) are untouched.

    The original function object is put back once the module's tests are done.
    """
    original_fn = _angular_mod._spherical_harmonics
    if isinstance(original_fn, torch.jit.ScriptFunction):
        torch_jit_disable()
        try:
            # Seed the scratch namespace with the module's current globals so
            # the re-executed source sees the same names it was written against.
            scratch = dict(vars(_angular_mod))
            # Python source files are UTF-8 by default (PEP 3120); do not rely
            # on the platform's locale encoding when reading one.
            with open(_angular_mod.__file__, encoding="utf-8") as f:
                exec(compile(f.read(), _angular_mod.__file__, "exec"), scratch)
        finally:
            # Always re-enable TorchScript, even if reading or executing the
            # source failed, so a failure here cannot leave JIT switched off
            # for every test that runs afterwards.
            torch_jit_enable()
        _angular_mod._spherical_harmonics = scratch["_spherical_harmonics"]
    yield
    _angular_mod._spherical_harmonics = original_fn
| 39 | + |
| 40 | + |
# Model names treated as deprecated; they are left out of ACTIVE_MODELS below.
DEPRECATED_MODELS = {
    "orb-v1",
    "orb-d3-v1",
    "orb-d3-sm-v1",
    "orb-d3-xs-v1",
    "orb-v1-mptraj-only",
}

# Every registered pretrained model that is not deprecated, in sorted order so
# the parametrized test ids are stable.
ACTIVE_MODELS = sorted(set(ORB_PRETRAINED_MODELS) - DEPRECATED_MODELS)
# Active models whose name contains "conservative".
CONSERVATIVE_MODELS = set(filter(lambda name: "conservative" in name, ACTIVE_MODELS))
| 51 | + |
| 52 | + |
def _make_atoms():
    """Return bulk fcc Cu (a=3.6) with ``charge=0`` and ``spin=1`` stored in ``info``."""
    copper = bulk("Cu", "fcc", a=3.6)
    copper.info.update(charge=0, spin=1)
    return copper
| 59 | + |
| 60 | + |
# NOTE: This fails for conservative models because tojax does not support torch.autograd.grad
@pytest.mark.parametrize("model_name", ACTIVE_MODELS)
def test_tojax_outputs_match_pytorch(model_name):
    """Check that ``tojax(model)`` reproduces the PyTorch forward predictions."""
    model, adapter = ORB_PRETRAINED_MODELS[model_name](device="cpu", compile=False)
    model.eval()

    # One batch per backend, both built from the same structure.
    structure = _make_atoms()
    torch_batch = adapter.from_ase_atoms(structure, device="cpu")
    jax_batch = adapter.from_ase_atoms(structure, device="cpu")

    jax_model = tojax(model)

    with torch.enable_grad():
        expected = model(torch_batch)

    actual = jax_model(jax_batch)

    # Only compare model predictions, not internal intermediates like
    # node_features / edge_features where float32 rounding differences
    # accumulate across message-passing layers and many edges.
    INTERNAL_KEYS = {"node_features", "edge_features"}
    prediction_keys = set(expected).difference(INTERNAL_KEYS)
    missing = prediction_keys.difference(actual)
    assert not missing, f"Keys missing from JAX output: {missing}"

    tolerances = {"atol": 1e-5, "rtol": 1e-4}
    for key in prediction_keys:
        reference = expected[key].detach().float().cpu().numpy()
        candidate = np.asarray(actual[key], dtype=np.float32)
        np.testing.assert_allclose(
            candidate,
            reference,
            err_msg=f"Output mismatch for '{model_name}' key '{key}'",
            **tolerances,
        )
| 97 | + |
| 98 | + |
def _from_kups_atoms(adapter: ForcefieldAtomsAdapter, data: dict[str, torch.Tensor]) -> AtomGraphs:
    """Assemble an AtomGraphs from the flat kups AtomGraphInput tensors.

    Per-system node and edge counts are rebuilt from ``data["batch"]`` (edges
    are assigned to the system of their sender), and edge vectors are
    recomputed from positions plus the cell-offset shifts.
    """
    edge_index = data["edge_index"]
    senders = edge_index[0]
    receivers = edge_index[1]
    cell = data["cell"]
    num_systems = cell.shape[0]
    node_system = data["batch"]
    edge_system = node_system[senders]

    def _count_per_system(index):
        # Histogram of `index` over the systems, accumulated with scatter_add_.
        return torch.zeros(num_systems, dtype=torch.int64).scatter_add_(
            0, index, torch.ones_like(index)
        )

    n_node = _count_per_system(node_system)
    nedges = _count_per_system(edge_system)

    atomic_numbers = data["atomic_numbers"]
    one_hot_numbers = torch.nn.functional.one_hot(atomic_numbers.long(), 118).float()

    system_features: dict[str, torch.Tensor] = {"cell": cell, "pbc": data["pbc"]}
    # Charge and spin are optional inputs; forward them only when supplied.
    for source_key, feature_key in (("charge", "total_charge"), ("spin", "spin_multiplicity")):
        if source_key in data:
            system_features[feature_key] = data[source_key]

    # Cartesian shift per edge: (1 x 3 unit offsets) @ (3 x 3 cell of the edge's system).
    unit_shifts = data["cell_offsets"]
    shifts = torch.bmm(unit_shifts.unsqueeze(1), cell[edge_system]).squeeze(1)
    positions = data["pos"]
    vectors = positions[receivers] - positions[senders] + shifts

    assert adapter.radius is not None, "Adapter radius must be set"
    node_features = {
        "positions": positions,
        "atomic_numbers": atomic_numbers,
        "atomic_numbers_embedding": one_hot_numbers,
        "atom_identity": torch.arange(positions.shape[0], dtype=torch.int64),
    }
    edge_features = {"unit_shifts": unit_shifts, "vectors": vectors}
    return AtomGraphs(
        senders=senders,
        receivers=receivers,
        n_node=n_node,
        n_edge=nedges,
        node_features=node_features,
        system_features=system_features,
        edge_features=edge_features,
        node_targets={},
        edge_targets={},
        system_targets={},
        system_id=None,
        fix_atoms=None,
        tags=None,
        radius=adapter.radius,
        max_num_neighbors=nedges,
        half_supercell=False,
    )
| 154 | + |
| 155 | + |
def _to_kups_atoms(atomgraph: AtomGraphs) -> dict[str, torch.Tensor]:
    """Flatten an AtomGraphs into the AtomGraphInput dict used by kups/tojax.

    ``charge`` and ``spin`` fall back to per-system zeros when the graph has no
    ``total_charge`` / ``spin_multiplicity`` system feature.
    """
    nodes_per_system = atomgraph.n_node
    total_atoms = int(nodes_per_system.sum())
    num_systems = nodes_per_system.shape[0]
    device = nodes_per_system.device

    node_features = atomgraph.node_features
    system_features = atomgraph.system_features

    def _zeros_per_system():
        return torch.zeros(num_systems, dtype=torch.long, device=device)

    # System index of every atom: system i repeated n_node[i] times.
    system_ids = torch.arange(num_systems, device=device)
    batch_index = system_ids.repeat_interleave(nodes_per_system, output_size=total_atoms)

    charge = system_features.get("total_charge", _zeros_per_system()).view(-1)
    spin = system_features.get("spin_multiplicity", _zeros_per_system()).view(-1)

    return {
        "pos": node_features["positions"],
        "atomic_numbers": node_features["atomic_numbers"],
        "cell": system_features["cell"],
        "pbc": system_features["pbc"],
        "edge_index": torch.stack([atomgraph.senders, atomgraph.receivers]),
        "cell_offsets": atomgraph.edge_features["unit_shifts"],
        "batch": batch_index,
        "charge": charge,
        "spin": spin,
    }
| 178 | + |
| 179 | + |
def _make_predict_fn(adapter, model):
    """Return a tojax-traceable function: kups AtomGraphInput -> {energy, forces, stress}.

    ``energy`` is always returned; ``forces`` and ``stress`` are included only
    when ``model.predict`` produces them.
    """
    optional_keys = ("forces", "stress")

    def predict_fn(data):
        prediction = model.predict(_from_kups_atoms(adapter, data), split=False)
        selected = {"energy": prediction["energy"]}
        selected.update({key: prediction[key] for key in optional_keys if key in prediction})
        return selected

    return predict_fn
| 194 | + |
| 195 | + |
@pytest.mark.parametrize("model_name", ACTIVE_MODELS)
def test_kups_compatibility(model_name):
    """Test energy, forces, and stress using the kups AtomGraphInput format.

    Since tojax cannot translate torch.autograd.grad, this test follows the
    pattern from the tojax export example (export_orb.py) where model + graph
    construction are packaged into a single function that accepts a flat
    AtomGraphInput dict — the same format kups uses.

    Unlike test_tojax_outputs_match_pytorch which passes the model directly
    to tojax(model), this test tojax-traces a function that builds an
    AtomGraphs from raw tensors and calls model.predict(). For conservative
    models, forces and stress are computed via jax.grad of the energy
    (replacing torch.autograd.grad with JAX's own autodiff).
    """

    load_fn = ORB_PRETRAINED_MODELS[model_name]
    model, adapter = load_fn(device="cpu", compile=False)
    model.eval()
    # Conservative models take the jax.grad branch below; all others are
    # compared directly on the outputs of the traced predict function.
    is_conservative = model_name in CONSERVATIVE_MODELS

    # Flat kups-style input that will be handed to the JAX side.
    atoms = _make_atoms()
    batch = adapter.from_ase_atoms(atoms, device="cpu")
    data = _to_kups_atoms(batch)

    # PyTorch reference (computed on a separately built batch of the same atoms)
    ref_batch = adapter.from_ase_atoms(atoms, device="cpu")
    with torch.enable_grad():
        torch_out = model.predict(ref_batch)

    jax_data = tojax(data)

    if is_conservative:
        # Energy + forces/stress via jax.grad (torch.autograd.grad can't be translated)
        predict_fn = _make_predict_fn(adapter, model)
        # Only the energy is traced; the lambda's `data` parameter shadows the
        # outer `data` on purpose and receives the JAX input dict.
        jax_energy_fn = tojax(lambda data: predict_fn(data)["energy"])
        batch_indices = jax_data["batch"]

        def _energy_with_strain(pos, strain):
            # Summed energy of the structure after applying a symmetrised strain.
            # Cell and positions are both mapped as x -> x + x @ sym_strain,
            # with each atom using the strain of its own system.
            sym_strain = 0.5 * (strain + jnp.swapaxes(strain, -1, -2))
            deformed_cell = jax_data["cell"] + jnp.einsum(
                "bij,bjk->bik", jax_data["cell"], sym_strain
            )
            deformed_pos = pos + jnp.einsum("ni,nij->nj", pos, sym_strain[batch_indices])
            return jax_energy_fn({**jax_data, "pos": deformed_pos, "cell": deformed_cell}).sum()

        # Differentiate w.r.t. positions and strain, evaluated at zero strain:
        # dE/dpos is the negated force, dE/dstrain is the virial used for stress.
        energy_and_grads = jax.value_and_grad(_energy_with_strain, argnums=(0, 1))
        jax_energy, (neg_forces, virial) = energy_and_grads(
            jax_data["pos"], jnp.zeros_like(jax_data["cell"])
        )

        np.testing.assert_allclose(
            np.asarray(jax_energy),
            torch_out[model.energy_name].detach().float().numpy(),
            atol=1e-5,
            rtol=1e-4,
            err_msg=f"energy mismatch for '{model_name}'",
        )
        np.testing.assert_allclose(
            np.asarray(-neg_forces, dtype=np.float32),
            torch_out[model.grad_forces_name].detach().float().numpy(),
            atol=1e-4,
            rtol=1e-4,
            err_msg=f"force mismatch for '{model_name}'",
        )
        if model.has_stress:
            # Stress = virial / cell volume, with volume = |det(cell)| per system.
            volume = jnp.abs(jnp.linalg.det(jax_data["cell"]))
            jax_stress_3x3 = np.asarray(virial / volume[:, None, None])
            # Pack the 3x3 tensor into six components in the order
            # (xx, yy, zz, yz, xz, xy), averaging the off-diagonal pairs.
            jax_stress = np.stack(
                [
                    jax_stress_3x3[..., 0, 0],
                    jax_stress_3x3[..., 1, 1],
                    jax_stress_3x3[..., 2, 2],
                    (jax_stress_3x3[..., 1, 2] + jax_stress_3x3[..., 2, 1]) / 2,
                    (jax_stress_3x3[..., 0, 2] + jax_stress_3x3[..., 2, 0]) / 2,
                    (jax_stress_3x3[..., 0, 1] + jax_stress_3x3[..., 1, 0]) / 2,
                ],
                axis=-1,
            ).astype(np.float32)
            # Reshape to whatever shape the PyTorch reference stress has.
            np.testing.assert_allclose(
                jax_stress.reshape(torch_out[model.grad_stress_name].shape),
                torch_out[model.grad_stress_name].detach().float().numpy(),
                atol=1e-4,
                rtol=1e-4,
                err_msg=f"stress mismatch for '{model_name}'",
            )
    else:
        # Direct models: energy, forces, stress from tojax'd predict
        jax_predict_fn = tojax(_make_predict_fn(adapter, model))
        jax_out = jax_predict_fn(jax_data)

        np.testing.assert_allclose(
            np.asarray(jax_out["energy"]),
            torch_out["energy"].detach().float().numpy(),
            atol=1e-5,
            rtol=1e-4,
            err_msg=f"energy mismatch for '{model_name}'",
        )
        # Forces and stress are compared only when the traced predict returned them.
        if "forces" in jax_out:
            np.testing.assert_allclose(
                np.asarray(jax_out["forces"], dtype=np.float32),
                torch_out["forces"].detach().float().numpy(),
                atol=1e-4,
                rtol=1e-4,
                err_msg=f"force mismatch for '{model_name}'",
            )
        if "stress" in jax_out:
            np.testing.assert_allclose(
                np.asarray(jax_out["stress"], dtype=np.float32),
                torch_out["stress"].detach().float().numpy(),
                atol=1e-4,
                rtol=1e-4,
                err_msg=f"stress mismatch for '{model_name}'",
            )
0 commit comments