Skip to content

Commit d70d203

Browse files
harveydevereuxoerc0122ElliottKasoar
authored
Add SevenNet to training (#662)
Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> Co-authored-by: Elliott Kasoar <45317199+ElliottKasoar@users.noreply.github.com>
1 parent 01d45a1 commit d70d203

8 files changed

Lines changed: 398 additions & 11 deletions

File tree

docs/source/user_guide/command_line.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ Training and fine-tuning MLIPs
659659
------------------------------
660660

661661
.. note::
662-
Currently only MACE and Nequip models are supported.
662+
Currently MACE, Nequip, and SevenNet models are supported.
663663

664664
Models can be trained by passing an architecture and an architecture-specific configuration file as options to the ``janus train`` command. The configuration file will be passed to the corresponding MLIP's command line interface. For example, to train a MACE MLIP:
665665

@@ -692,14 +692,19 @@ For MACE, training will create ``logs``, ``checkpoints`` and ``results`` sub-dir
692692
Instructions for writing a MACE ``config.yml`` file can be found in the `MACE Readme <https://github.com/ACEsuit/mace?tab=readme-ov-file#training>`_ and the `MACE run_train CLI <https://github.com/ACEsuit/mace/blob/main/mace/cli/run_train.py>`_.
693693

694694

695-
Training Nequip MLIPS
695+
Training Nequip MLIPs
696696
+++++++++++++++++++++
697697

698698
Configuration of Nequip training is outlined in the `Nequip user guide <https://nequip.readthedocs.io/en/latest/guide/guide.html>`_. In particular note that the configuration file must have a ``.yaml`` extension.
699699

700700
The contents of the results directory depend on the options selected in the configuration file, but may typically contain model checkpoint, ``.ckpt``, files and a metrics directory.
701701

702702

703+
Training SevenNet MLIPs
704+
+++++++++++++++++++++++
705+
706+
The `SevenNet documentation <https://sevennet.readthedocs.io/en/latest/>`_ contains information on training SevenNet MLIPs. The SevenNet `tutorial repository <https://github.com/MDIL-SNU/sevennet_tutorial/tree/main>`_ also contains some example ``.yaml`` configuration files for training and fine-tuning.
707+
703708
Preprocessing training data
704709
----------------------------
705710

janus_core/cli/train.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,25 @@ def train(
122122
"""Fine-tuning requested but there is no checkpoint or
123123
package specified in your config."""
124124
)
125+
case "sevennet":
126+
continue_section = config["train"].get("continue")
127+
if continue_section is None and fine_tune:
128+
raise ValueError(
129+
"""Fine-tuning requested but there is no continue
130+
section in your config."""
131+
)
132+
model = continue_section.get("checkpoint")
133+
if model is None:
134+
raise ValueError(
135+
"No model specified as a checkpoint for fine-tuning."
136+
)
137+
if not fine_tune and continue_section is not None:
138+
raise ValueError(
139+
"""Fine-tuning not requested but a continue
140+
section is in your config. Please use
141+
--fine-tune"""
142+
)
143+
125144
case _:
126145
raise ValueError(f"Unsupported Architecture ({arch})")
127146

janus_core/training/train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from argparse import ArgumentParser
56
from typing import Any
67

78
import yaml
@@ -93,6 +94,15 @@ def train(
9394
)
9495
foundation_model = model["checkpoint_path"]
9596

97+
case "sevennet":
98+
from sevenn.main.sevenn import cmd_parser_train, run
99+
100+
parser = ArgumentParser()
101+
cmd_parser_train(parser)
102+
mlip_args = parser.parse_args(
103+
[str(mlip_config), "--working_dir", str(file_prefix), "-s"]
104+
)
105+
96106
case _:
97107
raise ValueError(f"{arch} is currently unsupported in train.")
98108

tests/data/sevennet_fine_tune.yml

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
model:
2+
chemical_species: auto
3+
4+
cutoff: 2.0
5+
irreps_manual:
6+
- 128x0e
7+
- 128x0e+64x1e+32x2e+32x3e
8+
- 128x0e+64x1e+32x2e+32x3e
9+
- 128x0e+64x1e+32x2e+32x3e
10+
- 128x0e+64x1e+32x2e+32x3e
11+
- 128x0e
12+
channel: 128
13+
lmax: 3
14+
num_convolution_layer: 5
15+
is_parity: false
16+
radial_basis:
17+
radial_basis_name: bessel
18+
bessel_basis_num: 8
19+
cutoff_function:
20+
cutoff_function_name: poly_cut
21+
poly_cut_p_value: 6
22+
23+
act_radial: silu
24+
weight_nn_hidden_neurons:
25+
- 64
26+
- 64
27+
act_scalar:
28+
e: silu
29+
o: tanh
30+
act_gate:
31+
e: silu
32+
o: tanh
33+
34+
train_denominator: false
35+
train_shift_scale: false
36+
use_bias_in_linear: false
37+
38+
readout_as_fcn: false
39+
self_connection_type: linear
40+
interaction_type: nequip
41+
42+
train:
43+
random_seed: 1
44+
is_train_stress: True
45+
epoch: 1
46+
47+
48+
49+
optimizer: 'adam'
50+
optim_param:
51+
lr: 0.005
52+
scheduler: 'exponentiallr'
53+
scheduler_param:
54+
gamma: 0.99
55+
56+
force_loss_weight: 0.1
57+
stress_loss_weight: 1e-06
58+
59+
per_epoch: 1
60+
61+
62+
63+
error_record:
64+
- ['Energy', 'RMSE']
65+
- ['Force', 'RMSE']
66+
- ['Stress', 'RMSE']
67+
- ['TotalLoss', 'None']
68+
69+
continue:
70+
reset_optimizer: True
71+
reset_scheduler: True
72+
reset_epoch: True
73+
checkpoint: 'tests/models/extra/SevenNet_l3i5.pth'
74+
75+
use_statistic_values_of_checkpoint: True
76+
77+
data:
78+
batch_size: 4
79+
data_divide_ratio: 0.1
80+
81+
shift: 'per_atom_energy_mean'
82+
scale: 'force_rms'
83+
84+
85+
86+
data_format: 'ase'
87+
data_format_args:
88+
index: ':'
89+
90+
91+
92+
load_dataset_path: ['tests/data/mlip_train.xyz']

tests/data/sevennet_train.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
model:
2+
chemical_species: 'Auto'
3+
cutoff: 2.0
4+
channel: 4
5+
lmax: 1
6+
num_convolution_layer: 1
7+
8+
weight_nn_hidden_neurons: [4, 4]
9+
radial_basis:
10+
radial_basis_name: 'bessel'
11+
bessel_basis_num: 8
12+
cutoff_function:
13+
cutoff_function_name: 'poly_cut'
14+
poly_cut_p_value: 6
15+
16+
act_gate: {'e': 'silu', 'o': 'tanh'}
17+
act_scalar: {'e': 'silu', 'o': 'tanh'}
18+
19+
is_parity: False
20+
21+
self_connection_type: 'nequip'
22+
23+
conv_denominator: "avg_num_neigh"
24+
train_denominator: False
25+
train_shift_scale: False
26+
27+
train:
28+
random_seed: 1
29+
is_train_stress: True
30+
epoch: 2
31+
optimizer: 'adam'
32+
optim_param:
33+
lr: 0.005
34+
scheduler: 'exponentiallr'
35+
scheduler_param:
36+
gamma: 0.99
37+
force_loss_weight: 0.1
38+
stress_loss_weight: 1e-06
39+
per_epoch: 1
40+
error_record:
41+
- ['Energy', 'RMSE']
42+
- ['Force', 'RMSE']
43+
- ['Stress', 'RMSE']
44+
- ['TotalLoss', 'None']
45+
46+
data:
47+
batch_size: 4
48+
data_divide_ratio: 0.1
49+
50+
shift: 'per_atom_energy_mean'
51+
scale: 'force_rms'
52+
data_format: 'ase'
53+
data_format_args:
54+
index: ':'
55+
load_dataset_path: ['tests/data/mlip_train.xyz']

tests/models/extra_models.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@
55
from argparse import ArgumentParser
66
from pathlib import Path
77
from urllib.request import urlretrieve
8+
from warnings import warn
9+
10+
11+
def try_retrieve(url: str, filename: Path):
12+
"""Attempt to retrieve from url.
13+
14+
Parameters
15+
----------
16+
url
17+
The url to attempt to retrieve from.
18+
filename
19+
The local filename of the retrieved data.
20+
"""
21+
try:
22+
urlretrieve(url, filename)
23+
except Exception as e:
24+
warn(f"Unable to retrieve from {url} because of:\n {e}", stacklevel=2)
25+
826

927
if __name__ == "__main__":
1028
parser = ArgumentParser(
@@ -14,7 +32,13 @@
1432
args = parser.parse_args()
1533

1634
args.path.mkdir(parents=True, exist_ok=True)
17-
urlretrieve(
35+
36+
try_retrieve(
1837
"https://zenodo.org/records/16980200/files/NequIP-MP-L-0.1.nequip.zip",
19-
filename=args.path / "NequIP-MP-L-0.1.nequip.zip",
38+
args.path / "NequIP-MP-L-0.1.nequip.zip",
39+
)
40+
41+
try_retrieve(
42+
"https://github.com/MDIL-SNU/SevenNet/raw/dff008ac9c53d368b5bee30a27fa4bdfd73f19b2/sevenn/pretrained_potentials/SevenNet_l3i5/checkpoint_l3i5.pth",
43+
args.path / "SevenNet_l3i5.pth",
2044
)

0 commit comments

Comments
 (0)