Skip to content

Commit 59611c8

Browse files
committed
feat(jax): add local training entrypoint
Port the JAX training entrypoint from the parallel branch onto current master, but keep it local-only by removing distributed, sharding, and Hessian hooks. Use the current dpmodel compute_or_load_stat data-stat path and add regression coverage for the cleanup constraints. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
1 parent f39a081 commit 59611c8

7 files changed

Lines changed: 921 additions & 1 deletion

File tree

deepmd/backend/jax.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
6262
Callable[[Namespace], None]
6363
The entry point hook of the backend.
6464
"""
65-
raise NotImplementedError
65+
from deepmd.jax.entrypoints.main import (
66+
main,
67+
)
68+
69+
return main
6670

6771
@property
6872
def deep_eval(self) -> type["DeepEvalBackend"]:

deepmd/jax/entrypoints/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

deepmd/jax/entrypoints/main.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""DeePMD-Kit entry point module."""
3+
4+
import argparse
5+
from pathlib import (
6+
Path,
7+
)
8+
from typing import (
9+
Optional,
10+
Union,
11+
)
12+
13+
from deepmd.backend.suffix import (
14+
format_model_suffix,
15+
)
16+
from deepmd.jax.entrypoints.freeze import (
17+
freeze,
18+
)
19+
from deepmd.jax.entrypoints.train import (
20+
train,
21+
)
22+
from deepmd.loggers.loggers import (
23+
set_log_handles,
24+
)
25+
from deepmd.main import (
26+
parse_args,
27+
)
28+
29+
__all__ = ["main"]
30+
31+
32+
def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
    """Dispatch a DeePMD-kit command to the JAX backend.

    Parameters
    ----------
    args : list[str] or argparse.Namespace, optional
        Command line arguments. Anything that is not already an
        ``argparse.Namespace`` is handed to ``parse_args``; a Namespace is
        used directly without re-parsing.

    Raises
    ------
    RuntimeError
        If the parsed command is neither ``"train"``, ``"freeze"`` nor None.
    """
    namespace = args if isinstance(args, argparse.Namespace) else parse_args(args=args)

    # ``vars`` exposes the namespace's own attribute dict, so the rewrite of
    # ``output`` in the freeze branch is visible on the namespace as well.
    options = vars(namespace)
    log_file = Path(namespace.log_path) if namespace.log_path else None
    set_log_handles(namespace.log_level, log_file, mpi_log=None)

    command = namespace.command
    if command is None:
        # No sub-command given: nothing to run.
        return
    if command == "train":
        train(**options)
        return
    if command == "freeze":
        options["output"] = format_model_suffix(
            options["output"], preferred_backend=namespace.backend, strict_prefer=True
        )
        freeze(**options)
        return
    raise RuntimeError(f"unknown command {command}")

deepmd/jax/entrypoints/train.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""DeePMD training entrypoint script.
3+
4+
Can handle local training.
5+
"""
6+
7+
import json
8+
import logging
9+
import time
10+
from typing import (
11+
Any,
12+
Optional,
13+
)
14+
15+
16+
from deepmd.common import (
17+
j_loader,
18+
)
19+
from deepmd.jax.env import (
20+
jax,
21+
jax_export,
22+
)
23+
from deepmd.jax.train.trainer import (
24+
DPTrainer,
25+
)
26+
from deepmd.utils import random as dp_random
27+
from deepmd.utils.argcheck import (
28+
normalize,
29+
)
30+
from deepmd.utils.compat import (
31+
update_deepmd_input,
32+
)
33+
from deepmd.utils.data_system import (
34+
get_data,
35+
)
36+
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter
37+
38+
__all__ = ["train"]
39+
40+
log = logging.getLogger(__name__)
41+
42+
43+
class SummaryPrinter(BaseSummaryPrinter):
    """Summary printer that reports JAX build and device information."""

    def is_built_with_cuda(self) -> bool:
        """Return True when JAX's default export platform is ``"cuda"``."""
        platform = jax_export.default_export_platform()
        return platform == "cuda"

    def is_built_with_rocm(self) -> bool:
        """Return True when JAX's default export platform is ``"rocm"``."""
        platform = jax_export.default_export_platform()
        return platform == "rocm"

    def get_compute_device(self) -> str:
        """Return the name of JAX's default backend."""
        backend_name = jax.default_backend()
        return backend_name

    def get_ngpus(self) -> int:
        """Return the device count reported by ``jax.device_count()``."""
        n_devices = jax.device_count()
        return n_devices

    def get_backend_info(self) -> dict:
        """Return the backend name and the installed JAX version."""
        info = {"Backend": "JAX"}
        info["JAX ver"] = jax.__version__
        return info

    def get_device_name(self) -> str:
        """Return the kind of the first JAX device, or ``"Unknown"`` if none."""
        available = jax.devices()
        return available[0].device_kind if available else "Unknown"
76+
77+
78+
def train(
    *,
    INPUT: str,
    init_model: Optional[str],
    restart: Optional[str],
    output: str,
    init_frz_model: str,
    mpi_log: str,
    log_level: int,
    log_path: Optional[str],
    skip_neighbor_stat: bool = False,
    finetune: Optional[str] = None,
    use_pretrain_script: bool = False,
    **kwargs: Any,
) -> None:
    """Run local DeePMD model training with the JAX backend.

    The input script is loaded, upgraded to the current input format,
    normalized, dumped to ``output``, and then used to build a ``DPTrainer``
    that is trained on the configured training (and optional validation) data.

    Parameters
    ----------
    INPUT : str
        json/yaml control file
    init_model : Optional[str]
        path prefix of checkpoint files or None; forwarded to ``DPTrainer``
    restart : Optional[str]
        path prefix of checkpoint files or None; forwarded to ``DPTrainer``
    output : str
        path the normalized input script is dumped to (as JSON)
    init_frz_model : str
        path to frozen model or None; currently not used by this function
    mpi_log : str
        mpi logging mode; currently not used by this function
    log_level : int
        logging level defined by int 0-3; currently not used by this function
    log_path : Optional[str]
        logging file path or None; currently not used by this function
    skip_neighbor_stat : bool, default=False
        if True, the neighbor-statistics step (``update_sel``) is skipped
    finetune : Optional[str]
        path to pretrained model or None; currently not used by this function
    use_pretrain_script : bool
        currently not used by this function
    **kwargs
        additional arguments, ignored

    Raises
    ------
    AssertionError
        if the normalized input script has no ``"training"`` section.
    """
    # load json database
    jdata = j_loader(INPUT)

    jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

    jdata = normalize(jdata)
    # Honor --skip-neighbor-stat: update_sel itself advertises this flag in its
    # log message, so it must not run (and load the data) when it is set.
    if not skip_neighbor_stat:
        jdata = update_sel(jdata)

    with open(output, "w") as fp:
        json.dump(jdata, fp, indent=4)
    SummaryPrinter()()

    # make necessary checks
    assert "training" in jdata

    # init the model
    model = DPTrainer(
        jdata,
        init_model=init_model,
        restart=restart,
    )
    rcut = model.model.get_rcut()
    type_map = model.model.get_type_map()
    # An empty model type map is passed to the data loader as None.
    ipt_type_map = type_map if len(type_map) != 0 else None

    # init random seed of data systems
    seed = jdata["training"].get("seed", None)
    if seed is not None:
        # Offset by the JAX process index and wrap into the 32-bit range.
        seed += jax.process_index()
        seed = seed % (2**32)
    dp_random.seed(seed)

    # init data
    train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None)
    train_data.add_data_requirements(model.data_requirements)
    train_data.print_summary("training")
    if jdata["training"].get("validation_data", None) is not None:
        # Validation data reuses the type map resolved for the training data.
        valid_data = get_data(
            jdata["training"]["validation_data"],
            rcut,
            train_data.type_map,
            None,
        )
        valid_data.add_data_requirements(model.data_requirements)
        valid_data.print_summary("validation")
    else:
        valid_data = None

    # train the model with the provided systems in a cyclic way
    start_time = time.time()
    model.train(train_data, valid_data)
    end_time = time.time()
    log.info("finished training")
    log.info(f"wall time: {(end_time - start_time):.3f} s")
198+
199+
200+
def update_sel(jdata: dict) -> dict:
    """Return a shallow copy of ``jdata`` after loading the training data.

    The neighbor-statistics based update of the descriptor selection is
    currently disabled (see the TODO below), so the returned copy carries the
    same model section as the input.

    Parameters
    ----------
    jdata : dict
        Input script; ``jdata["model"]`` and
        ``jdata["training"]["training_data"]`` are read.

    Returns
    -------
    dict
        Shallow copy of ``jdata``.
    """
    log.info(
        "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
    )
    result = jdata.copy()
    type_map = jdata["model"].get("type_map")
    # The second (0) and fourth (None) arguments are placeholders that are
    # not used here.
    train_data = get_data(jdata["training"]["training_data"], 0, type_map, None)
    # TODO: OOM, need debug
    # result["model"], min_nbor_dist = BaseModel.update_sel(
    #     train_data, type_map, jdata["model"]
    # )
    return result

deepmd/jax/train/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

0 commit comments

Comments
 (0)