Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e34a2cb
Port orbmol-v2 (learnable electrostatics) from the reference codebase
timduignan May 7, 2026
d15591f
orbmol-v2 cleanup: preserve pair_repulsion BC, simplify absolute_energy
timduignan May 7, 2026
bca39cb
Add BC guard tests for orbmol-v2 port
timduignan May 7, 2026
8defeea
orbmol_v2: HF weights URL, bump nvalchemiops, add network smoke test
timduignan May 7, 2026
3a6fb37
Cap nvalchemi-toolkit-ops at <0.4
timduignan May 7, 2026
5ce1f66
Document orbmol-v2 in MODELS.md
timduignan May 7, 2026
0fb1861
Fix MODELS.md: bare-1/r Coulomb for non-periodic, drop size-consisten…
timduignan May 7, 2026
c3b0930
Polish docs and docstrings for orbmol-v2
timduignan May 7, 2026
7ac3c29
Ported changes from the internal repo
vsimkus May 7, 2026
d359837
Soften LatentSpinHead claims; document fp64 energy default
timduignan May 8, 2026
4f7220a
Soften charge/spin head wording; bump torch-sim-atomistic to 0.6.0
timduignan May 8, 2026
d5e3fae
Handle node/graph features/targets in from_ase_atoms_list
vsimkus May 8, 2026
7dd0e20
Update examples to orbmol_v2
vsimkus May 8, 2026
73aa674
Refocus README orbmol-v2 update on CoulombModule + headline benchmarks
timduignan May 12, 2026
290c0a0
Pass explicit n_node to scatter_mean (port of #3074)
timduignan May 22, 2026
23f7695
Compile-safe heads + drop in-place ops on Coulomb path (port of #3074)
timduignan May 22, 2026
df8f4f0
Enable full-model torch.compile for orbmol-v2 (port of #3074)
timduignan May 22, 2026
7638657
Switch orbmol-v2 default to teqabfhg (no per-atom spin head)
timduignan May 22, 2026
b5f1b4c
Docs: drop LatentSpinHead language from MODELS.md
timduignan May 22, 2026
64c2b12
Refresh README orbmol-v2 benchmarks for teqabfhg
timduignan May 22, 2026
61b7d78
scripts/speed.py: set charge/spin on random crystals (orbmol_v2 support)
timduignan May 22, 2026
6ae0776
Refresh README speed bullet with teqabfhg numbers
timduignan May 22, 2026
d808684
README speed bullet: v1 vs v2-teqabfhg table from BENCH
timduignan May 22, 2026
d878054
README speed bullet: 3-column table including v1 full-compile
timduignan May 22, 2026
f5661f2
finetune.py: detect conservative-vs-direct from model, not name string
timduignan May 25, 2026
166e75f
Address Vaidas review comments
timduignan May 25, 2026
b79f71d
Address Vaidas follow-up review comments
timduignan May 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions FINETUNING_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ Number of atoms: 3
```bash
python finetune.py \
--data_path /path/to/your/dataset.db \
--base_model orb_v3_conservative_omol \
--base_model orbmol_v2 \
--energy_loss_weight 0.1 \
--forces_loss_weight 1.0 \
--stress_loss_weight 0.0 \
Expand All @@ -164,7 +164,7 @@ python finetune.py \
```bash
python finetune.py \
--data_path /path/to/your/dataset.db \
--base_model orb_v3_conservative_omol \
--base_model orbmol_v2 \
--custom_reference_energies /path/to/reference_energies.json \
--energy_loss_weight 0.1 \
--forces_loss_weight 1.0
Expand All @@ -175,7 +175,7 @@ python finetune.py \
```bash
python finetune.py \
--data_path /path/to/your/dataset.db \
--base_model orb_v3_conservative_omol \
--base_model orbmol_v2 \
--custom_reference_energies /path/to/reference_energies.json \
--trainable_reference_energies \
--energy_loss_weight 0.1 \
Expand All @@ -187,7 +187,7 @@ python finetune.py \
```bash
python finetune.py \
--data_path /path/to/your/dataset.db \
--base_model orb_v3_conservative_omol \
--base_model orbmol_v2 \
--trainable_reference_energies \
--energy_loss_weight 0.1 \
--forces_loss_weight 1.0
Expand Down Expand Up @@ -258,7 +258,7 @@ Lines starting with `#` are treated as comments and ignored.

The script automatically handles the differences between conservative and direct models:

- **Conservative models** (e.g., `orb_v3_conservative_omol`):
- **Conservative models** (e.g., `orbmol_v2`):
- Use `grad_forces` and `grad_stress` as **loss-weight keys**
- Compute forces via automatic differentiation

Expand All @@ -276,7 +276,7 @@ If you prefer to write your own finetuning script, you can use the clean API dir
from orb_models.forcefield import pretrained

# Load model with custom configuration
model, atoms_adapter = pretrained.orb_v3_conservative_omol(
model, atoms_adapter = pretrained.orbmol_v2(
device='cuda',
precision='float32-high',
train=True,
Expand Down Expand Up @@ -319,7 +319,7 @@ import torch
from orb_models.forcefield import pretrained

# Load model architecture (set train=False for inference)
model, atoms_adapter = pretrained.orb_v3_conservative_omol(train=False)
model, atoms_adapter = pretrained.orbmol_v2(train=False)

# Load your finetuned checkpoint
model.load_state_dict(torch.load('path/to/finetuned_checkpoint.pt'))
Expand All @@ -331,7 +331,7 @@ You can also specify loss weights when loading for further finetuning:

```python
# Load for continued finetuning with different loss weights
model, atoms_adapter = pretrained.orb_v3_conservative_omol(
model, atoms_adapter = pretrained.orbmol_v2(
train=True,
loss_weights={'energy': 0.5, 'grad_forces': 20.0}
)
Expand All @@ -358,7 +358,7 @@ Finetuning on ORCA wB97M-V data with different reference scheme:
```bash
python finetune.py \
--data_path my_dataset.db \
--base_model orb_v3_conservative_omol \
--base_model orbmol_v2 \
--custom_reference_energies my_refs.json \
--energy_loss_weight 1.0 \
--forces_loss_weight 10.0 \
Expand All @@ -370,7 +370,7 @@ python finetune.py \
from orb_models.forcefield import pretrained
import torch

model, atoms_adapter = pretrained.orb_v3_conservative_omol(train=False)
model, atoms_adapter = pretrained.orbmol_v2(train=False)
model.load_state_dict(torch.load('checkpoints/my_finetuned_model.pt'))
# Reference energies from my_refs.json are now loaded!
```
Expand All @@ -384,7 +384,7 @@ from orb_models.forcefield import pretrained
from orb_models.dataset.ase_sqlite_dataset import AseSqliteDataset

# Load model with configuration
model, atoms_adapter = pretrained.orb_v3_conservative_omol(
model, atoms_adapter = pretrained.orbmol_v2(
device='cuda',
train=True,
train_reference_energies=False, # Fixed reference energies
Expand Down
20 changes: 18 additions & 2 deletions MODELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,27 @@ We provide several pretrained models that can be used to calculate energies, for
These models are a continuation of the [`orb-v3`](#v3-models) series trained on the [Open Molecules 2025 (OMol25)](https://arxiv.org/pdf/2505.08762) dataset—over 100M high-accuracy DFT calculations (ωB97M-V/def2-TZVPD) on diverse molecular systems including metal complexes, biomolecules, and electrolytes. Note: The training data does not contain periodic systems and these models have not been carefully tested on periodic systems.

There are two options:
* `orb-v3-conservative-omol`
* `orb-v3-direct-omol`
* `orbmol-v1-conservative`
* `orbmol-v1-direct`

See below for more explanation of this naming convention. Both models have `inf` neighbors, ensuring a continuous PES.

### OrbMol-v2 (learnable electrostatics)

* `orbmol-v2`

OrbMol-v2 extends the OrbMol architecture with **learnable per-atom electrostatics**: a `LatentChargeHead` predicts per-atom latent charges constrained to sum to the system total charge, and a `CoulombModule` adds a long-range Coulomb energy on top of the GNN, direct bare-1/r Coulomb sum for non-periodic systems, Particle Mesh Ewald via `nvalchemiops` for periodic systems. The energy head (`ChargeConditionedEnergyHead`) is conditioned on the per-atom charges. Similar to orbmol-v1, system-level total charge and spin are required.

Trained on OMol25 and OPoly26 (ωB97M-V/def2-TZVPD); supports both periodic and non-periodic systems. Stress is enabled via `model.enable_stress()` if needed.

```python
from orb_models.forcefield.pretrained import orbmol_v2
model, atoms_adapter = orbmol_v2(device="cuda")
# atoms.info["charge"] and atoms.info["spin"] (multiplicity, = 2S+1) must be set.
```

> **Caution:** While the model does predict per-atom charge values as a latent feature in the charge head, the model has not seen any per-atom charge values during training; these are emergent from optimisation against energies and forces alone. They should therefore be treated with caution: while in at least some cases they appear to correspond to the correct physical values, the reliability and generality of this correspondence is unclear and is the subject of ongoing investigations.

### [V3 Models](https://arxiv.org/abs/2504.06231)

V3 models use the following naming convention: ```orb-v3-X-Y-Z``` where:
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@ Alternatively, you can use Docker to run orb-models; [see instructions below](#d

### Updates

**May 2026**: Release of OrbMol-v2 — adds a `CoulombModule` for long-range electrostatics on top of the OrbMol architecture, using direct Coulomb summation for non-periodic systems and Particle Mesh Ewald (via `nvalchemiops`) for periodic. Trained on OMol25 and OPoly26 (ωB97M-V/def2-TZVPD); load with `pretrained.orbmol_v2(device="cuda")`. See [MODELS.md](MODELS.md) for the full architecture description.

* **Long-range electrostatics and learnable charges.** GSCDB138 Normalized Error Ratio drops from **6.05 → 1.62** (3.7× lower, comparable to a good DFT functional).
* **Full-model compilation.** `model.compile(...)` now wraps the full regressor for all models, giving ~1.7× speedup at 10k atoms on a single 80 GB GPU.

`model.predict(...)["energy"]` now returns **fp64** by default to preserve kJ/mol resolution against OMol-scale references (~1e4–1e5 eV). Pass `fp64_energy=False` to opt out.

**February 2026**: Improved GPU-accelerated graph construction with [ALCHEMI Toolkit-Ops](https://github.com/NVIDIA/nvalchemi-toolkit-ops) and batched simulation with [TorchSim](https://github.com/TorchSim/torch-sim):

* Alchemi-based graph construction (GPU-accelerated, up to 12x faster for large single systems, and sub-linear batch scaling delivering >100x graph construction speed-up for large batches of small systems)
* TorchSim wrapper for batched optimisation and simulation, see [usage with TorchSim](#usage-with-torchsim)
* Alchemi-based D3 dispersion correction module, see [D3 correction](#d3-correction)


**August 2025**: Release of the [OrbMol potentials](https://www.orbitalindustries.com/posts/orbmol-extending-orb-to-molecular-systems):

* Trained on the [Open Molecules 2025 (OMol25)](https://arxiv.org/pdf/2505.08762) dataset—over 100M high-accuracy DFT calculations (ωB97M-V/def2-TZVPD) on diverse molecular systems including metal complexes, biomolecules, and electrolytes.
Expand Down Expand Up @@ -189,7 +197,7 @@ from ase.build import molecule
from orb_models.forcefield import pretrained

device = "cpu" # or device="cuda"
orbff, atoms_adapter = pretrained.orb_v3_conservative_omol(
orbff, atoms_adapter = pretrained.orbmol_v2(
device=device,
precision="float32-high", # or "float32-highest" / "float64
)
Expand Down
4 changes: 2 additions & 2 deletions examples/NaClWaterMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def run_md_simulation(
# Set the calculator
# Note: If you encounter compilation errors (e.g., Triton issues on clusters),
# you can disable compilation by adding compile=False:
# orbff, atoms_adapter = pretrained.orb_v3_conservative_omol(device=device, compile=False)
orbff, atoms_adapter = pretrained.orb_v3_conservative_omol(device=device)
# orbff, atoms_adapter = pretrained.orbmol_v2(device=device, compile=False)
orbff, atoms_adapter = pretrained.orbmol_v2(device=device)
atoms.calc = ORBCalculator(orbff, atoms_adapter=atoms_adapter, device=device)

# Set the initial velocities
Expand Down
51 changes: 26 additions & 25 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from orb_models.common.training.util import get_optim, init_device
from orb_models.common.utils import seed_everything
from orb_models.forcefield import pretrained
from orb_models.forcefield.models.conservative_regressor import (
ConservativeForcefieldRegressor,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

Expand Down Expand Up @@ -439,26 +442,30 @@ def run(args):
# GPUs and does not appear to hurt training
precision = "float32-high"

# Prepare loss weights if specified
loss_weights = {}
is_conservative_model = "conservative" in args.base_model
# Instantiate model with configuration
base_model = args.base_model
model, atoms_adapter = getattr(pretrained, base_model)(
device=device,
precision=precision,
train=True,
train_reference_energies=args.trainable_reference_energies,
)

# Detect conservative vs direct from the instantiated model type.
is_conservative_model = isinstance(model, ConservativeForcefieldRegressor)

# Map CLI loss-weight flags onto the keys the instantiated model expects.
loss_weights: dict[str, float] = {}
if args.energy_loss_weight is not None:
loss_weights["energy"] = args.energy_loss_weight

if args.forces_loss_weight is not None:
# Key depends on model type
if is_conservative_model:
loss_weights["grad_forces"] = args.forces_loss_weight
else: # direct model
loss_weights["forces"] = args.forces_loss_weight
key = "grad_forces" if is_conservative_model else "forces"
loss_weights[key] = args.forces_loss_weight

if args.stress_loss_weight is not None:
# Key depends on model type
if is_conservative_model:
loss_weights["grad_stress"] = args.stress_loss_weight
else: # direct model
loss_weights["stress"] = args.stress_loss_weight
key = "grad_stress" if is_conservative_model else "stress"
loss_weights[key] = args.stress_loss_weight

if args.equigrad_loss_weight is not None:
if not is_conservative_model:
Expand All @@ -471,16 +478,7 @@ def run(args):
for key, val in loss_weights.items():
logging.info(f" {key}: {val}")
logging.info("=" * 60)

# Instantiate model with configuration
base_model = args.base_model
model, atoms_adapter = getattr(pretrained, base_model)(
device=device,
precision=precision,
train=True,
train_reference_energies=args.trainable_reference_energies,
loss_weights=loss_weights if loss_weights else None,
)
model.loss_weights.update(loss_weights)

# Handle custom reference energies if provided
if args.custom_reference_energies:
Expand Down Expand Up @@ -678,12 +676,15 @@ def main():
type=str,
help="Base model to finetune.",
choices=[
"orbmol_v2",
"orb_v3_conservative_omol",
"orb_v3_direct_omol",
"orbmol_v1_conservative",
"orbmol_v1_direct",
"orb_v3_conservative_inf_omat",
"orb_v3_conservative_20_omat",
"orb_v3_direct_inf_omat",
"orb_v3_direct_20_omat",
"orb_v3_conservative_omol",
"orb_v3_direct_omol",
"orb_v2",
],
)
Expand Down
3 changes: 3 additions & 0 deletions orb_models/common/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def loss(self, batch: T) -> ModelOutput:
"""Encodes to latents before message passing."""
raise NotImplementedError()

def prepare_for_inference(self) -> None:
"""Hook called before inference. Override to enable inference-only features."""


class RegressorModelMixin[T: AbstractAtomBatch](ModelMixin[T]):
"""Model Mixin for our regression models."""
Expand Down
26 changes: 8 additions & 18 deletions orb_models/common/models/segment_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def aggregate_nodes(
if reduction == "sum":
return scatter_sum(tensor, segments, dim=0, dim_size=count)
elif reduction == "mean":
return scatter_mean(tensor, segments, dim=0, dim_size=count)
return scatter_mean(tensor, segments, n_node, dim=0, dim_size=count)
elif reduction == "max":
return segment_max(tensor, segments, num_segments=count)
else:
Expand All @@ -61,11 +61,6 @@ def segment_max(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int
return scatter_max(data, segment_ids, dim=0, dim_size=num_segments)


def segment_mean(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int):
"""Computes index based mean over segments of a tensor."""
return scatter_mean(data, segment_ids, dim=0, dim_size=num_segments)


def segment_softmax(
data: torch.Tensor,
segment_ids: torch.Tensor,
Expand Down Expand Up @@ -222,6 +217,7 @@ def scatter_std(
def scatter_mean(
src: torch.Tensor,
index: torch.Tensor,
count: torch.Tensor,
dim: int = -1,
out: torch.Tensor | None = None,
dim_size: int | None = None,
Expand All @@ -231,6 +227,7 @@ def scatter_mean(
Args:
src (torch.Tensor): The source tensor.
index (torch.Tensor): The indices of elements to scatter.
count (torch.Tensor): Pre-computed group sizes (e.g. n_node).
dim (int, optional): The dimension along which to index. Defaults to -1.
out (Optional[torch.Tensor], optional): The output tensor. Defaults to None.
dim_size (Optional[int], optional): Size of the output tensor. Defaults to None.
Expand All @@ -241,20 +238,13 @@ def scatter_mean(
out = scatter_sum(src, index, dim, out, dim_size)
dim_size = out.size(dim)

index_dim = dim
if index_dim < 0:
index_dim = index_dim + src.dim()
if index.dim() <= index_dim:
index_dim = index.dim() - 1

ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, index_dim, None, dim_size)
count[count < 1] = 1
count = _broadcast(count, out, dim)
divisor = count.to(dtype=out.dtype)
divisor = divisor.clamp(min=1)
divisor = _broadcast(divisor, out, dim)
if out.is_floating_point():
out.true_divide_(count)
out.true_divide_(divisor)
else:
out.div_(count, rounding_mode="floor")
out.div_(divisor, rounding_mode="floor")
return out


Expand Down
Loading
Loading