Skip to content

Commit 53f9813

Browse files
authored
feat(jax): add training (#5460)
## Summary This PR ports the JAX training entrypoint from the `parallel` branch onto the current `deepmodeling/deepmd-kit` master as a local-only training path. The change keeps the useful JAX trainer/CLI pieces while deliberately removing the parallel/distributed parts requested for cleanup: - add a JAX `train` entrypoint and wire it into the JAX backend command path - add local JAX trainer infrastructure for model initialization, data statistics, loss setup, training, validation, checkpointing, and model export - use the current dpmodel `compute_or_load_stat` data-stat practice from master - remove parallel/sharding-specific behavior from the training path - remove Hessian-specific behavior from the training path - map the lower-interface model outputs into the keys expected by `EnergyLoss` - use `communicate_extended_output` so extended/ghost atom force contributions are scattered back to local atoms correctly - add regression coverage for the local JAX training entrypoint and cleanup constraints ## Tests - `/tmp/deepmd-jax-venv/bin/python -m pytest -q source/tests/jax/test_training.py` - GitHub Actions `Test Python`: https://github.com/njzjz-bothub/deepmd-kit/actions/runs/26464854510 Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a CLI entrypoint to run JAX train/freeze commands. * Backend hook now returns the JAX entrypoint, enabling invocation as the JAX backend. * New JAX trainer providing training/validation flow, checkpointing, learning-rate scheduling, input preparation and data conversion. * Enhanced runtime summary and logging (device/backend info, GPU counts, JAX version). * **Tests** * Added an end-to-end test that runs a single-step JAX training workflow and verifies produced artifacts. <!-- review_stack_entry_start --> [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/deepmodeling/deepmd-kit/pull/5460?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent a7c597d commit 53f9813

8 files changed

Lines changed: 1044 additions & 1 deletion

File tree

deepmd/backend/jax.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
6262
Callable[[Namespace], None]
6363
The entry point hook of the backend.
6464
"""
65-
raise NotImplementedError
65+
from deepmd.jax.entrypoints.main import (
66+
main,
67+
)
68+
69+
return main
6670

6771
@property
6872
def deep_eval(self) -> type["DeepEvalBackend"]:

deepmd/jax/entrypoints/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

deepmd/jax/entrypoints/freeze.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Freeze utilities for the JAX backend."""
3+
4+
from pathlib import (
5+
Path,
6+
)
7+
8+
from deepmd.backend.suffix import (
9+
format_model_suffix,
10+
)
11+
from deepmd.jax.utils.serialization import (
12+
deserialize_to_file,
13+
serialize_from_file,
14+
)
15+
16+
17+
def freeze(
18+
*,
19+
checkpoint_folder: str,
20+
output: str,
21+
**kwargs: object,
22+
) -> None:
23+
"""Freeze a JAX checkpoint into a serialized model file.
24+
25+
Parameters
26+
----------
27+
checkpoint_folder : str
28+
Location of either the checkpoint directory or a folder containing the
29+
stable ``checkpoint`` pointer.
30+
output : str
31+
Output model filename or prefix. The JAX model suffix is added when the
32+
filename has no supported backend suffix.
33+
**kwargs
34+
Other CLI arguments accepted for backend entry-point compatibility.
35+
"""
36+
del kwargs
37+
38+
checkpoint_path = Path(checkpoint_folder)
39+
if (checkpoint_path / "checkpoint").is_file():
40+
checkpoint_pointer = (checkpoint_path / "checkpoint").read_text().strip()
41+
checkpoint_folder = str(checkpoint_path / checkpoint_pointer)
42+
43+
output = format_model_suffix(
44+
output,
45+
preferred_backend="jax",
46+
strict_prefer=True,
47+
)
48+
data = serialize_from_file(checkpoint_folder)
49+
deserialize_to_file(output, data)

deepmd/jax/entrypoints/main.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""DeePMD-Kit entry point module."""
3+
4+
import argparse
5+
from pathlib import (
6+
Path,
7+
)
8+
9+
from deepmd.jax.entrypoints.freeze import (
10+
freeze,
11+
)
12+
from deepmd.jax.entrypoints.train import (
13+
train,
14+
)
15+
from deepmd.loggers.loggers import (
16+
set_log_handles,
17+
)
18+
from deepmd.main import (
19+
parse_args,
20+
)
21+
22+
__all__ = ["main"]
23+
24+
25+
def main(args: list[str] | argparse.Namespace | None = None) -> None:
26+
"""DeePMD-Kit entry point.
27+
28+
Parameters
29+
----------
30+
args : list[str] or argparse.Namespace, optional
31+
list of command line arguments, used to avoid calling from the subprocess,
32+
as it is quite slow to import tensorflow; if Namespace is given, it will
33+
be used directly
34+
35+
Raises
36+
------
37+
RuntimeError
38+
if no command was input
39+
"""
40+
if not isinstance(args, argparse.Namespace):
41+
args = parse_args(args=args)
42+
43+
dict_args = vars(args)
44+
set_log_handles(
45+
args.log_level,
46+
Path(args.log_path) if args.log_path else None,
47+
mpi_log=None,
48+
)
49+
50+
if args.command == "train":
51+
train(**dict_args)
52+
elif args.command == "freeze":
53+
freeze(**dict_args)
54+
elif args.command is None:
55+
pass
56+
else:
57+
raise RuntimeError(f"unknown command {args.command}")

deepmd/jax/entrypoints/train.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""DeePMD training entrypoint script.
3+
4+
Can handle local training.
5+
"""
6+
7+
import json
8+
import logging
9+
import time
10+
from typing import (
11+
Any,
12+
)
13+
14+
from deepmd.common import (
15+
j_loader,
16+
)
17+
from deepmd.jax.env import (
18+
jax,
19+
jax_export,
20+
)
21+
from deepmd.jax.train.trainer import (
22+
DPTrainer,
23+
)
24+
from deepmd.utils import random as dp_random
25+
from deepmd.utils.argcheck import (
26+
normalize,
27+
)
28+
from deepmd.utils.compat import (
29+
update_deepmd_input,
30+
)
31+
from deepmd.utils.data_system import (
32+
get_data,
33+
)
34+
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter
35+
36+
__all__ = ["train"]
37+
38+
log = logging.getLogger(__name__)
39+
40+
41+
class SummaryPrinter(BaseSummaryPrinter):
42+
"""Summary printer for JAX."""
43+
44+
def is_built_with_cuda(self) -> bool:
45+
"""Check if the backend is built with CUDA."""
46+
return jax_export.default_export_platform() == "cuda"
47+
48+
def is_built_with_rocm(self) -> bool:
49+
"""Check if the backend is built with ROCm."""
50+
return jax_export.default_export_platform() == "rocm"
51+
52+
def get_compute_device(self) -> str:
53+
"""Get Compute device."""
54+
return jax.default_backend()
55+
56+
def get_ngpus(self) -> int:
57+
"""Get the number of GPUs."""
58+
return jax.device_count()
59+
60+
def get_backend_info(self) -> dict:
61+
"""Get backend information."""
62+
return {
63+
"Backend": "JAX",
64+
"JAX ver": jax.__version__,
65+
}
66+
67+
def get_device_name(self) -> str:
68+
"""Get the name of the device."""
69+
devices = jax.devices()
70+
if devices:
71+
return devices[0].device_kind
72+
else:
73+
return "Unknown"
74+
75+
76+
def train(
77+
*,
78+
INPUT: str,
79+
init_model: str | None,
80+
restart: str | None,
81+
output: str,
82+
init_frz_model: str | None,
83+
mpi_log: str,
84+
log_level: int,
85+
log_path: str | None,
86+
skip_neighbor_stat: bool = False,
87+
finetune: str | None = None,
88+
use_pretrain_script: bool = False,
89+
**kwargs: Any,
90+
) -> None:
91+
"""Run DeePMD model training.
92+
93+
Parameters
94+
----------
95+
INPUT : str
96+
json/yaml control file
97+
init_model : Optional[str]
98+
path prefix of checkpoint files or None
99+
restart : Optional[str]
100+
path prefix of checkpoint files or None
101+
output : str
102+
path for dump file with arguments
103+
init_frz_model : str | None
104+
path to frozen model, or None if no frozen model is used
105+
mpi_log : str
106+
mpi logging mode
107+
log_level : int
108+
logging level defined by int 0-3
109+
log_path : Optional[str]
110+
logging file path or None if logs are to be output only to stdout
111+
skip_neighbor_stat : bool, default=False
112+
skip checking neighbor statistics
113+
finetune : Optional[str]
114+
path to pretrained model or None
115+
use_pretrain_script : bool
116+
Whether to use model script in pretrained model when doing init-model or init-frz-model.
117+
Note that this option is true and unchangeable for fine-tuning.
118+
**kwargs
119+
additional arguments
120+
121+
Raises
122+
------
123+
RuntimeError
124+
if the training command fails.
125+
"""
126+
# load json database
127+
jdata = j_loader(INPUT)
128+
129+
if init_frz_model:
130+
raise NotImplementedError("JAX training does not support init_frz_model yet")
131+
if finetune:
132+
raise NotImplementedError("JAX training does not support finetune yet")
133+
if use_pretrain_script:
134+
raise NotImplementedError(
135+
"JAX training does not support use_pretrain_script yet"
136+
)
137+
138+
jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
139+
140+
jdata = normalize(jdata)
141+
if not skip_neighbor_stat:
142+
jdata = update_sel(jdata)
143+
144+
with open(output, "w") as fp:
145+
json.dump(jdata, fp, indent=4)
146+
SummaryPrinter()()
147+
148+
# make necessary checks
149+
assert "training" in jdata
150+
151+
# init the model
152+
153+
model = DPTrainer(
154+
jdata,
155+
init_model=init_model,
156+
restart=restart,
157+
)
158+
rcut = model.model.get_rcut()
159+
type_map = model.model.get_type_map()
160+
if len(type_map) == 0:
161+
ipt_type_map = None
162+
else:
163+
ipt_type_map = type_map
164+
165+
# init random seed of data systems
166+
seed = jdata["training"].get("seed", None)
167+
if seed is not None:
168+
seed += jax.process_index()
169+
seed = seed % (2**32)
170+
dp_random.seed(seed)
171+
172+
# init data
173+
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None)
174+
train_data.add_data_requirements(model.data_requirements)
175+
train_data.print_summary("training")
176+
if jdata["training"].get("validation_data", None) is not None:
177+
valid_data = get_data(
178+
jdata["training"]["validation_data"],
179+
rcut,
180+
train_data.type_map,
181+
None,
182+
)
183+
valid_data.add_data_requirements(model.data_requirements)
184+
valid_data.print_summary("validation")
185+
else:
186+
valid_data = None
187+
188+
# train the model with the provided systems in a cyclic way
189+
start_time = time.time()
190+
model.train(train_data, valid_data)
191+
end_time = time.time()
192+
log.info("finished training")
193+
log.info(f"wall time: {(end_time - start_time):.3f} s")
194+
195+
196+
def update_sel(jdata: dict) -> dict:
197+
"""Update descriptor selections from neighbor statistics when available."""
198+
log.info(
199+
"Skip neighbor statistics update for JAX training; "
200+
"BaseModel.update_sel currently needs more memory than expected."
201+
)
202+
# TODO: Restore BaseModel.update_sel once the JAX data path avoids OOM.
203+
return jdata.copy()

deepmd/jax/train/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

0 commit comments

Comments
 (0)