Skip to content

Commit 7dfdd99

Browse files
committed
Ported changes from the internal repo
1 parent c3b0930 commit 7dfdd99

17 files changed

Lines changed: 616 additions & 650 deletions

MODELS.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ We provide several pretrained models that can be used to calculate energies, for
77
These models are a continuation of the [`orb-v3`](#v3-models) series trained on the [Open Molecules 2025 (OMol25)](https://arxiv.org/pdf/2505.08762) dataset—over 100M high-accuracy DFT calculations (ωB97M-V/def2-TZVPD) on diverse molecular systems including metal complexes, biomolecules, and electrolytes. Note: The training data does not contain periodic systems and these models have not been carefully tested on periodic systems.
88

99
There are two options:
10-
* `orb-v3-conservative-omol`
11-
* `orb-v3-direct-omol`
10+
* `orbmol-v1-conservative`
11+
* `orbmol-v1-direct`
1212

1313
See below for more explanation of this naming convention. Both models have `inf` neighbors, ensuring a continuous PES.
1414

@@ -18,7 +18,7 @@ See below for more explanation of this naming convention. Both models have `inf`
1818

1919
OrbMol-v2 extends the OrbMol architecture with **learnable per-atom electrostatics**: a `LatentChargeHead` predicts per-atom partial charges (constrained to sum to the system total charge), a `LatentSpinHead` predicts per-atom spins (constrained to sum to 2S = `spin_multiplicity − 1`), and a `CoulombModule` adds a long-range Coulomb energy on top of the GNN — direct bare-1/r Coulomb sum for non-periodic systems, Particle Mesh Ewald via `nvalchemiops` for periodic systems. The energy head (`ChargeConditionedEnergyHead`) is conditioned on the predicted charges and spins per-atom.
2020

21-
Trained on OMol25 (ωB97M-V/def2-TZVPD); supports both periodic and non-periodic systems. Stress is enabled via `model.enable_stress()` if needed.
21+
Trained on OMol25 and OPoly26 (ωB97M-V/def2-TZVPD); supports both periodic and non-periodic systems. Stress is enabled via `model.enable_stress()` if needed.
2222

2323
```python
2424
from orb_models.forcefield.pretrained import orbmol_v2

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,18 @@ Alternatively, you can use Docker to run orb-models; [see instructions below](#d
2121

2222
### Updates
2323

24+
**May 2026**: Release of OrbMol-v2 — extends the OrbMol architecture with learnable per-atom electrostatics:
25+
26+
* New `LatentChargeHead` and `LatentSpinHead` predict per-atom charges and spins (constrained to sum to the system total charge / 2S = `spin_multiplicity − 1`), and a `CoulombModule` adds long-range Coulomb energy on top of the GNN — direct Coulomb sum for non-periodic systems, Particle Mesh Ewald via `nvalchemiops` for periodic.
27+
* The energy head (`ChargeConditionedEnergyHead`) is conditioned on the predicted charges and spins per atom.
28+
* Trained on OMol25 and OPoly26 (ωB97M-V/def2-TZVPD); load with `pretrained.orbmol_v2(device="cuda")`.
29+
2430
**February 2026**: Improved GPU-accelerated graph construction with [ALCHEMI Toolkit-Ops](https://github.com/NVIDIA/nvalchemi-toolkit-ops) and batched simulation with [TorchSim](https://github.com/TorchSim/torch-sim):
2531

2632
* Alchemi-based graph construction (GPU-accelerated, up to 12x faster for large single systems, and sub-linear batch scaling delivering >100x graph construction speed-up for large batches of small systems)
2733
* TorchSim wrapper for batched optimisation and simulation, see [usage with TorchSim](#usage-with-torchsim)
2834
* Alchemi-based D3 dispersion correction module, see [D3 correction](#d3-correction)
2935

30-
**May 2026**: Release of OrbMol-v2 — extends the OrbMol architecture with learnable per-atom electrostatics:
31-
32-
* New `LatentChargeHead` and `LatentSpinHead` predict per-atom charges and spins (constrained to sum to the system total charge / 2S = `spin_multiplicity − 1`), and a `CoulombModule` adds long-range Coulomb energy on top of the GNN — bare 1/r direct sum for non-periodic systems, Particle Mesh Ewald via `nvalchemiops` for periodic.
33-
* The energy head (`ChargeConditionedEnergyHead`) is conditioned on the predicted charges and spins per atom.
34-
* Trained on OMol25 (ωB97M-V/def2-TZVPD); load with `pretrained.orbmol_v2(device="cuda")`.
3536

3637
**August 2025**: Release of the [OrbMol potentials](https://www.orbitalindustries.com/posts/orbmol-extending-orb-to-molecular-systems):
3738

@@ -195,7 +196,7 @@ from ase.build import molecule
195196
from orb_models.forcefield import pretrained
196197

197198
device = "cpu" # or device="cuda"
198-
orbff, atoms_adapter = pretrained.orb_v3_conservative_omol(
199+
orbff, atoms_adapter = pretrained.orbmol_v1_conservative(
199200
device=device,
200201
precision="float32-high", # or "float32-highest" / "float64
201202
)

orb_models/common/models/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def loss(self, batch: T) -> ModelOutput:
3636
"""Encodes to latents before message passing."""
3737
raise NotImplementedError()
3838

39+
def prepare_for_inference(self) -> None:
40+
"""Hook called before inference. Override to enable inference-only features."""
41+
3942

4043
class RegressorModelMixin[T: AbstractAtomBatch](ModelMixin[T]):
4144
"""Model Mixin for our regression models."""

orb_models/forcefield/models/conservative_regressor.py

Lines changed: 80 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class ConservativeForcefieldRegressor(base.RegressorModelMixin[AtomGraphs]):
4040
- "grad_forces"
4141
- "grad_stress"
4242
which weight the gradient based losses of forces/stress respectively.
43-
distill_direct_heads: Whether to distill the direct heads into the conservative heads.
43+
coulomb_module: Optional CoulombModule for long-range electrostatics.
44+
When present, a latent_charges head must also be in heads.
4445
**kwargs: Additional kwargs, used for backwards compatibility of deprecated arguments.
4546
"""
4647

@@ -55,12 +56,10 @@ def __init__(
5556
heads: Mapping[str, ForcefieldHead | ConfidenceHead],
5657
model: MoleculeGNS,
5758
loss_weights: dict[str, float] | None = None,
58-
distill_direct_heads: bool = False,
5959
online_normalisation: bool = True,
6060
level_of_theory: str | None = None,
6161
forces_loss_type: Literal["mae", "mse", "huber_0.01", "condhuber_0.01"] = "condhuber_0.01",
6262
pair_repulsion: bool = False,
63-
pair_repulsion_node_aggregation: str = "mean",
6463
has_stress: bool = True,
6564
coulomb_module: CoulombModule | None = None,
6665
**kwargs,
@@ -85,7 +84,6 @@ def __init__(
8584
_validate_heads_and_loss_weights(heads, nongrad_loss_weights)
8685

8786
self.loss_weights = loss_weights
88-
self.distill_direct_heads = distill_direct_heads
8987
self.forces_loss_type = forces_loss_type
9088

9189
self.model = model
@@ -95,11 +93,7 @@ def __init__(
9593

9694
self.pair_repulsion = pair_repulsion
9795
if self.pair_repulsion:
98-
self.pair_repulsion_fn = ZBLBasis(
99-
p=6,
100-
compute_gradients=False,
101-
node_aggregation=pair_repulsion_node_aggregation,
102-
)
96+
self.pair_repulsion_fn = ZBLBasis(p=6, compute_gradients=False, node_aggregation="sum")
10397

10498
self.coulomb_module = coulomb_module
10599
if self.coulomb_module is not None:
@@ -115,11 +109,21 @@ def __init__(
115109
self.forces_target = PROPERTIES[self.forces_name]
116110
self.grad_forces_name = f"{self.grad_prefix}_{self.forces_name}"
117111

118-
# Stress names are always derived (from level_of_theory); has_stress toggles computation
119-
self.stress_name: str = f"stress-{level_of_theory}" if level_of_theory else "stress"
120-
self.stress_target: PropertyDefinition = PROPERTIES[self.stress_name]
121-
self.grad_stress_name: str = f"{self.grad_prefix}_{self.stress_name}"
112+
# Stress is optional since only periodic systems have it
122113
self.has_stress = has_stress
114+
if self.has_stress:
115+
self.stress_name: str | None = (
116+
f"stress-{level_of_theory}" if level_of_theory else "stress"
117+
)
118+
self.stress_target: PropertyDefinition | None = PROPERTIES[self.stress_name]
119+
self.grad_stress_name: str | None = f"{self.grad_prefix}_{self.stress_name}"
120+
else:
121+
self.stress_name = None
122+
self.stress_target = None
123+
self.grad_stress_name = None
124+
assert self.has_stress == (self.grad_stress_name is not None), (
125+
"grad_stress_name must be set if has_stress is True"
126+
)
123127

124128
self.grad_rotation_name = "rotational_grad"
125129

@@ -129,8 +133,17 @@ def __init__(
129133
self.extra_properties.append(heads[name].target.fullname)
130134

131135
def enable_stress(self) -> None:
132-
"""Enable stress computation."""
136+
"""Enable stress computation. No-op if already enabled."""
137+
if self.has_stress:
138+
return
133139
self.has_stress = True
140+
self.stress_name = "stress"
141+
self.stress_target = PROPERTIES["stress"]
142+
self.grad_stress_name = f"{self.grad_prefix}_{self.stress_name}"
143+
144+
def prepare_for_inference(self) -> None:
145+
"""Enable stress for inference — always available via autograd."""
146+
self.enable_stress()
134147

135148
def disable_stress(self) -> None:
136149
"""Disable stress computation."""
@@ -146,11 +159,11 @@ def properties(self):
146159
self.grad_rotation_name,
147160
]
148161
if self.has_stress:
162+
assert self.grad_stress_name is not None, (
163+
"grad_stress_name must be set if has_stress is True"
164+
)
149165
props.append(self.grad_stress_name)
150-
for name in self.extra_properties:
151-
if not self.has_stress and "stress" in name:
152-
continue
153-
props.append(name)
166+
props.extend(self.extra_properties)
154167
return props
155168

156169
def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor]:
@@ -167,54 +180,43 @@ def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor]:
167180
node_features = out["node_features"]
168181

169182
# Predict per-atom charges/spins BEFORE energy head so they can
170-
# be used as conditioning features (ChargeConditionedEnergyHead) and by CoulombModule.
183+
# be used as conditioning features in ChargeConditionedEnergyHead and CoulombModule.
171184
latent_charges = None
172185
if "latent_charges" in self.heads:
173186
latent_charges = self.heads["latent_charges"](node_features, batch)
187+
174188
latent_spins = None
175189
if "latent_spins" in self.heads:
176190
latent_spins = self.heads["latent_spins"](node_features, batch)
177191

178192
energy_head = self.heads[self.energy_name]
179-
is_charge_conditioned = isinstance(energy_head, ChargeConditionedEnergyHead)
180-
if is_charge_conditioned:
181-
# ChargeConditionedEnergyHead.forward returns physical interaction energy directly.
182-
assert latent_charges is not None, (
183-
"ChargeConditionedEnergyHead requires a 'latent_charges' head"
184-
)
185-
raw_energy = energy_head(
193+
energy_head = cast(ForcefieldHead, energy_head)
194+
if isinstance(energy_head, ChargeConditionedEnergyHead):
195+
interaction_energy = energy_head(
186196
node_features,
187197
batch,
188198
per_atom_charges=latent_charges,
189199
per_atom_spins=latent_spins,
190200
)
191-
if self.pair_repulsion:
192-
raw_energy = raw_energy + self.pair_repulsion_fn(batch)["energy"]
193201
else:
194-
energy_head = cast(ForcefieldHead, energy_head)
195-
base_energy = energy_head(node_features, batch)
196-
raw_energy = energy_head.denormalize(base_energy, batch)
197-
if self.pair_repulsion:
198-
raw_energy = raw_energy + self.pair_repulsion_fn(batch)["energy"]
202+
assert latent_spins is None, "Latent spins are predicted but not used."
203+
interaction_energy = energy_head(node_features, batch)
204+
if self.pair_repulsion:
205+
interaction_energy += self.pair_repulsion_fn(batch)["energy"]
199206

200-
# Long-range Coulomb (predicted charges only — spins not used).
201207
coulomb_explicit_forces = None
202208
coulomb_explicit_virial = None
203209
if self.coulomb_module is not None:
204210
assert latent_charges is not None, "CoulombModule requires a LatentChargeHead"
205211
coulomb_energy, coulomb_explicit_forces, coulomb_explicit_virial = self.coulomb_module(
206212
latent_charges, batch
207213
)
208-
raw_energy = raw_energy + coulomb_energy
214+
interaction_energy += coulomb_energy
209215

210-
# Store final energy in `out` (interaction-units for ChargeConditioned, normalized otherwise).
211-
if is_charge_conditioned:
212-
out[self.energy_name] = raw_energy
213-
else:
214-
out[self.energy_name] = energy_head.normalize(raw_energy, batch, online=False)
216+
out[self.energy_name] = interaction_energy
215217

216218
forces, stress, rotational_grad = compute_gradient_forces_and_stress(
217-
energy=raw_energy,
219+
energy=interaction_energy,
218220
positions=batch.node_features["positions"],
219221
displacement=batch.system_features["stress_displacement"],
220222
cell=batch.system_features["cell"],
@@ -225,11 +227,11 @@ def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor]:
225227

226228
# Add explicit/spatial Coulomb force/stress corrections (see CoulombModule docstring).
227229
if self.coulomb_module is not None:
228-
assert coulomb_explicit_forces is not None
229-
assert coulomb_explicit_virial is not None
230+
assert coulomb_explicit_forces is not None, "Explicit/spatial forces are not computed"
231+
assert coulomb_explicit_virial is not None, "Explicit/spatial virial is not computed"
230232
forces = forces + coulomb_explicit_forces
231233
if self.has_stress:
232-
assert stress is not None
234+
assert stress is not None, "has_stress is True but stress is None"
233235
cell_3d = batch.system_features["cell"].view(-1, 3, 3)
234236
volume = torch.linalg.det(cell_3d).abs()
235237
coulomb_stress_3x3 = -coulomb_explicit_virial / volume.view(-1, 1, 1)
@@ -250,26 +252,40 @@ def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor]:
250252

251253
return out
252254

253-
def predict(self, batch: AtomGraphs, split: bool = False) -> dict[str, torch.Tensor]:
254-
"""Predict energy, forces, and stress."""
255+
def predict(
256+
self,
257+
batch: AtomGraphs,
258+
split: bool = False,
259+
fp64_energy: bool = True,
260+
) -> dict[str, torch.Tensor]:
261+
"""Predict energy, forces, and stress.
262+
263+
Args:
264+
batch: Input batch.
265+
split: If True, split predictions per graph.
266+
fp64_energy: If True (default), return absolute energy in fp64;
267+
required to preserve kJ/mol resolution since reference
268+
energies can be as high as ~1e4-1e5 eV. If False, returns
269+
energy in the input dtype.
270+
"""
255271
preds = self(batch)
256272

257273
out = {}
258-
energy_head = self.heads[self.energy_name]
259-
if isinstance(energy_head, ChargeConditionedEnergyHead):
260-
# preds[energy_name] is interaction energy in physical units; add reference in fp64.
261-
out[self.energy_name] = energy_head.absolute_energy(preds[self.energy_name], batch)
262-
else:
263-
energy_head = cast(ForcefieldHead, energy_head)
264-
out[self.energy_name] = energy_head.denormalize(preds[self.energy_name], batch)
274+
energy_head = cast(EnergyHead, self.heads[self.energy_name])
275+
out[self.energy_name] = energy_head.absolute_energy(
276+
preds[self.energy_name], batch, fp64=fp64_energy
277+
)
265278
out[self.grad_forces_name] = preds[self.grad_forces_name]
266279
if self.has_stress:
280+
assert self.grad_stress_name is not None, (
281+
"grad_stress_name must be set if has_stress is True"
282+
)
267283
out[self.grad_stress_name] = preds[self.grad_stress_name]
268284
out[self.grad_rotation_name] = preds[self.grad_rotation_name]
269285
for name in self.extra_properties:
270286
head = self.heads[name]
271287
if isinstance(head, ForcefieldHead):
272-
out[name] = head.denormalize(preds[name], batch)
288+
out[name] = preds[name]
273289
elif isinstance(head, ConfidenceHead):
274290
out[name] = torch.softmax(preds[name], dim=-1)
275291
else:
@@ -287,7 +303,6 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
287303

288304
energy_pred = out[self.energy_name]
289305
raw_grad_forces_pred = out[self.grad_forces_name]
290-
grad_forces_pred = self.grad_forces_normalizer(raw_grad_forces_pred, online=False)
291306

292307
# metrics
293308
metrics: dict = {}
@@ -309,7 +324,7 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
309324

310325
# Conservative forces
311326
loss_out = forces_loss_function(
312-
pred=grad_forces_pred,
327+
raw_pred=raw_grad_forces_pred,
313328
raw_target=batch.node_targets[self.forces_name],
314329
raw_gold_target=batch.node_targets[self.forces_name],
315330
name=self.forces_name,
@@ -326,10 +341,13 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
326341

327342
# Conservative stress (optional)
328343
if self.has_stress and self.grad_stress_name in out:
344+
assert self.stress_name is not None, "stress_name must be set if has_stress is True"
345+
assert self.grad_stress_name is not None, (
346+
"grad_stress_name must be set if has_stress is True"
347+
)
329348
raw_grad_stress_pred = out[self.grad_stress_name]
330-
grad_stress_pred = self.grad_stress_normalizer(raw_grad_stress_pred, online=False)
331349
loss_out = stress_loss_function(
332-
pred=grad_stress_pred,
350+
raw_pred=raw_grad_stress_pred,
333351
raw_target=batch.system_targets[self.stress_name],
334352
raw_gold_target=batch.system_targets[self.stress_name],
335353
name=self.stress_name,
@@ -342,23 +360,15 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
342360
metrics.update({f"{self.grad_prefix}-{k}": v for k, v in loss_out.log.items()})
343361

344362
# Direct forces / stress predictions
345-
for grad_name, grad_pred in [
346-
(self.grad_forces_name, raw_grad_forces_pred),
347-
] + (
348-
[(self.grad_stress_name, out[self.grad_stress_name])]
349-
if self.has_stress and self.grad_stress_name in out
350-
else []
363+
for grad_name in [self.grad_forces_name] + (
364+
[self.grad_stress_name] if self.has_stress and self.grad_stress_name in out else []
351365
):
366+
assert grad_name is not None
352367
direct_name = grad_name.replace(self.grad_prefix + "_", "")
353368
if direct_name in self.extra_properties:
354369
direct_head = cast(ForcefieldHead, self.heads[direct_name])
355370
direct_pred = out[direct_name]
356-
if self.distill_direct_heads:
357-
loss_out = direct_head.loss(
358-
direct_pred, batch, alternative_target=grad_pred.detach()
359-
)
360-
else:
361-
loss_out = direct_head.loss(direct_pred, batch)
371+
loss_out = direct_head.loss(direct_pred, batch)
362372
loss = self.loss_weights[direct_name] * loss_out.loss
363373
total_loss += loss
364374
metrics.update(loss_out.log)

0 commit comments

Comments
 (0)