Skip to content

Commit 2528142

Browse files
committed
feat(jax): add local training entrypoint
Port the JAX training entrypoint from the parallel branch onto current master, but keep it local-only by removing distributed, sharding, and Hessian hooks. Use the current dpmodel compute_or_load_stat data-stat path and add regression coverage for the cleanup constraints. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
1 parent f39a081 commit 2528142

7 files changed

Lines changed: 915 additions & 1 deletion

File tree

deepmd/backend/jax.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
6262
Callable[[Namespace], None]
6363
The entry point hook of the backend.
6464
"""
65-
raise NotImplementedError
65+
from deepmd.jax.entrypoints.main import (
66+
main,
67+
)
68+
69+
return main
6670

6771
@property
6872
def deep_eval(self) -> type["DeepEvalBackend"]:

deepmd/jax/entrypoints/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

deepmd/jax/entrypoints/main.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""DeePMD-Kit entry point module."""
3+
4+
import argparse
5+
from pathlib import (
6+
Path,
7+
)
8+
9+
from deepmd.backend.suffix import (
10+
format_model_suffix,
11+
)
12+
from deepmd.jax.entrypoints.freeze import (
13+
freeze,
14+
)
15+
from deepmd.jax.entrypoints.train import (
16+
train,
17+
)
18+
from deepmd.loggers.loggers import (
19+
set_log_handles,
20+
)
21+
from deepmd.main import (
22+
parse_args,
23+
)
24+
25+
__all__ = ["main"]
26+
27+
28+
def main(args: list[str] | argparse.Namespace | None = None) -> None:
29+
"""DeePMD-Kit entry point.
30+
31+
Parameters
32+
----------
33+
args : list[str] or argparse.Namespace, optional
34+
list of command line arguments, used to avoid calling from the subprocess,
35+
as it is quite slow to import tensorflow; if Namespace is given, it will
36+
be used directly
37+
38+
Raises
39+
------
40+
RuntimeError
41+
if no command was input
42+
"""
43+
if not isinstance(args, argparse.Namespace):
44+
args = parse_args(args=args)
45+
46+
dict_args = vars(args)
47+
set_log_handles(
48+
args.log_level,
49+
Path(args.log_path) if args.log_path else None,
50+
mpi_log=None,
51+
)
52+
53+
if args.command == "train":
54+
train(**dict_args)
55+
elif args.command == "freeze":
56+
dict_args["output"] = format_model_suffix(
57+
dict_args["output"], preferred_backend=args.backend, strict_prefer=True
58+
)
59+
freeze(**dict_args)
60+
elif args.command is None:
61+
pass
62+
else:
63+
raise RuntimeError(f"unknown command {args.command}")

deepmd/jax/entrypoints/train.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""DeePMD training entrypoint script.
3+
4+
Can handle local training.
5+
"""
6+
7+
import json
8+
import logging
9+
import time
10+
from typing import (
11+
Any,
12+
)
13+
14+
from deepmd.common import (
15+
j_loader,
16+
)
17+
from deepmd.jax.env import (
18+
jax,
19+
jax_export,
20+
)
21+
from deepmd.jax.train.trainer import (
22+
DPTrainer,
23+
)
24+
from deepmd.utils import random as dp_random
25+
from deepmd.utils.argcheck import (
26+
normalize,
27+
)
28+
from deepmd.utils.compat import (
29+
update_deepmd_input,
30+
)
31+
from deepmd.utils.data_system import (
32+
get_data,
33+
)
34+
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter
35+
36+
__all__ = ["train"]
37+
38+
log = logging.getLogger(__name__)
39+
40+
41+
class SummaryPrinter(BaseSummaryPrinter):
42+
"""Summary printer for JAX."""
43+
44+
def is_built_with_cuda(self) -> bool:
45+
"""Check if the backend is built with CUDA."""
46+
return jax_export.default_export_platform() == "cuda"
47+
48+
def is_built_with_rocm(self) -> bool:
49+
"""Check if the backend is built with ROCm."""
50+
return jax_export.default_export_platform() == "rocm"
51+
52+
def get_compute_device(self) -> str:
53+
"""Get Compute device."""
54+
return jax.default_backend()
55+
56+
def get_ngpus(self) -> int:
57+
"""Get the number of GPUs."""
58+
return jax.device_count()
59+
60+
def get_backend_info(self) -> dict:
61+
"""Get backend information."""
62+
return {
63+
"Backend": "JAX",
64+
"JAX ver": jax.__version__,
65+
}
66+
67+
def get_device_name(self) -> str:
68+
"""Get the name of the device."""
69+
devices = jax.devices()
70+
if devices:
71+
return devices[0].device_kind
72+
else:
73+
return "Unknown"
74+
75+
76+
def train(
77+
*,
78+
INPUT: str,
79+
init_model: str | None,
80+
restart: str | None,
81+
output: str,
82+
init_frz_model: str,
83+
mpi_log: str,
84+
log_level: int,
85+
log_path: str | None,
86+
skip_neighbor_stat: bool = False,
87+
finetune: str | None = None,
88+
use_pretrain_script: bool = False,
89+
**kwargs: Any,
90+
) -> None:
91+
"""Run DeePMD model training.
92+
93+
Parameters
94+
----------
95+
INPUT : str
96+
json/yaml control file
97+
init_model : Optional[str]
98+
path prefix of checkpoint files or None
99+
restart : Optional[str]
100+
path prefix of checkpoint files or None
101+
output : str
102+
path for dump file with arguments
103+
init_frz_model : str
104+
path to frozen model or None
105+
mpi_log : str
106+
mpi logging mode
107+
log_level : int
108+
logging level defined by int 0-3
109+
log_path : Optional[str]
110+
logging file path or None if logs are to be output only to stdout
111+
skip_neighbor_stat : bool, default=False
112+
skip checking neighbor statistics
113+
finetune : Optional[str]
114+
path to pretrained model or None
115+
use_pretrain_script : bool
116+
Whether to use model script in pretrained model when doing init-model or init-frz-model.
117+
Note that this option is true and unchangeable for fine-tuning.
118+
**kwargs
119+
additional arguments
120+
121+
Raises
122+
------
123+
RuntimeError
124+
if the training command fails.
125+
"""
126+
# load json database
127+
jdata = j_loader(INPUT)
128+
129+
origin_type_map = None
130+
131+
jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
132+
133+
jdata = normalize(jdata)
134+
jdata = update_sel(jdata)
135+
136+
with open(output, "w") as fp:
137+
json.dump(jdata, fp, indent=4)
138+
SummaryPrinter()()
139+
140+
# make necessary checks
141+
assert "training" in jdata
142+
143+
# init the model
144+
145+
model = DPTrainer(
146+
jdata,
147+
init_model=init_model,
148+
restart=restart,
149+
)
150+
rcut = model.model.get_rcut()
151+
type_map = model.model.get_type_map()
152+
if len(type_map) == 0:
153+
ipt_type_map = None
154+
else:
155+
ipt_type_map = type_map
156+
157+
# init random seed of data systems
158+
seed = jdata["training"].get("seed", None)
159+
if seed is not None:
160+
seed += jax.process_index()
161+
seed = seed % (2**32)
162+
dp_random.seed(seed)
163+
164+
# init data
165+
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None)
166+
train_data.add_data_requirements(model.data_requirements)
167+
train_data.print_summary("training")
168+
if jdata["training"].get("validation_data", None) is not None:
169+
valid_data = get_data(
170+
jdata["training"]["validation_data"],
171+
rcut,
172+
train_data.type_map,
173+
None,
174+
)
175+
valid_data.add_data_requirements(model.data_requirements)
176+
valid_data.print_summary("validation")
177+
else:
178+
valid_data = None
179+
180+
# get training info
181+
stop_batch = jdata["training"]["numb_steps"]
182+
origin_type_map = jdata["model"].get("origin_type_map", None)
183+
if (
184+
origin_type_map is not None and not origin_type_map
185+
): # get the type_map from data if not provided
186+
origin_type_map = get_data(
187+
jdata["training"]["training_data"], rcut, None, None
188+
).get_type_map()
189+
190+
# train the model with the provided systems in a cyclic way
191+
start_time = time.time()
192+
model.train(train_data, valid_data)
193+
end_time = time.time()
194+
log.info("finished training")
195+
log.info(f"wall time: {(end_time - start_time):.3f} s")
196+
197+
198+
def update_sel(jdata: dict) -> dict:
199+
"""Update descriptor selections from neighbor statistics when available."""
200+
log.info(
201+
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
202+
)
203+
jdata_cpy = jdata.copy()
204+
type_map = jdata["model"].get("type_map")
205+
train_data = get_data(
206+
jdata["training"]["training_data"],
207+
0, # not used
208+
type_map,
209+
None, # not used
210+
)
211+
# TODO: OOM, need debug
212+
# jdata_cpy["model"], min_nbor_dist = BaseModel.update_sel(
213+
# train_data, type_map, jdata["model"]
214+
# )
215+
return jdata_cpy

deepmd/jax/train/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

0 commit comments

Comments
 (0)