Skip to content

Commit fdeff2b

Browse files
wanghan-iapcm and Han Wang authored
feat(pt_expt): add training infrastructure (#5270)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * New "PyTorch-Exportable" backend with end-to-end training CLI, model factory, loss wrapper, model wrapper, checkpointing, neighbor/stat utilities, and optional torch.compile acceleration with automatic re-tracing.
* **Bug Fixes**
  * Added runtime validation for required parameter fields (fparam/aparam) and ensured loss/reporting arrays are device-consistent.
* **Refactor**
  * Reworked stat input, batch normalization, and model-stat collection for clearer, backend-agnostic data flow.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 7e648f0 commit fdeff2b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+3035
-89
lines changed

deepmd/backend/pt_expt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class PyTorchExportableBackend(Backend):
3333
"""PyTorch exportable backend."""
3434

35-
name = "PyTorch Exportable"
35+
name = "PyTorch-Exportable"
3636
"""The formal name of the backend."""
3737
features: ClassVar[Backend.Feature] = (
3838
Backend.Feature.ENTRY_POINT
@@ -63,7 +63,7 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
6363
Callable[[Namespace], None]
6464
The entry point hook of the backend.
6565
"""
66-
from deepmd.pt.entrypoints.main import main as deepmd_main
66+
from deepmd.pt_expt.entrypoints.main import main as deepmd_main
6767

6868
return deepmd_main
6969

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -394,19 +394,19 @@ def wrapped_sampler() -> list[dict]:
394394
atom_exclude_types = self.atom_excl.get_exclude_types()
395395
for sample in sampled:
396396
sample["atom_exclude_types"] = list(atom_exclude_types)
397-
if (
398-
"find_fparam" not in sampled[0]
399-
and "fparam" not in sampled[0]
400-
and self.has_default_fparam()
401-
):
397+
# For systems where fparam is missing (find_fparam == 0),
398+
# fill with default fparam if available and mark as found.
399+
if self.has_default_fparam():
402400
default_fparam = self.get_default_fparam()
403401
if default_fparam is not None:
404402
default_fparam_np = np.array(default_fparam)
405403
for sample in sampled:
406-
nframe = sample["atype"].shape[0]
407-
sample["fparam"] = np.tile(
408-
default_fparam_np.reshape(1, -1), (nframe, 1)
409-
)
404+
if "find_fparam" in sample and not sample["find_fparam"]:
405+
nframe = sample["atype"].shape[0]
406+
sample["fparam"] = np.tile(
407+
default_fparam_np.reshape(1, -1), (nframe, 1)
408+
)
409+
sample["find_fparam"] = np.bool_(True)
410410
return sampled
411411

412412
return wrapped_sampler

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
242242
arXiv preprint arXiv:2208.08236.
243243
"""
244244

245+
_update_sel_cls = UpdateSel
246+
245247
def __init__(
246248
self,
247249
rcut: float,
@@ -662,7 +664,7 @@ def update_sel(
662664
The minimum distance between two atoms
663665
"""
664666
local_jdata_cpy = local_jdata.copy()
665-
min_nbor_dist, sel = UpdateSel().update_one_sel(
667+
min_nbor_dist, sel = cls._update_sel_cls().update_one_sel(
666668
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
667669
)
668670
local_jdata_cpy["sel"] = sel[0]

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ class DescrptDPA2(NativeOP, BaseDescriptor):
441441
Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2
442442
"""
443443

444+
_update_sel_cls = UpdateSel
445+
444446
def __init__(
445447
self,
446448
ntypes: int,
@@ -1114,7 +1116,7 @@ def update_sel(
11141116
The minimum distance between two atoms
11151117
"""
11161118
local_jdata_cpy = local_jdata.copy()
1117-
update_sel = UpdateSel()
1119+
update_sel = cls._update_sel_cls()
11181120
min_nbor_dist, repinit_sel = update_sel.update_one_sel(
11191121
train_data,
11201122
type_map,

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
337337
arXiv preprint arXiv:2506.01686 (2025).
338338
"""
339339

340+
_update_sel_cls = UpdateSel
341+
340342
def __init__(
341343
self,
342344
ntypes: int,
@@ -729,7 +731,7 @@ def update_sel(
729731
The minimum distance between two atoms
730732
"""
731733
local_jdata_cpy = local_jdata.copy()
732-
update_sel = UpdateSel()
734+
update_sel = cls._update_sel_cls()
733735
min_nbor_dist, repflow_e_sel = update_sel.update_one_sel(
734736
train_data,
735737
type_map,

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class BD(ABC, PluginVariant, make_plugin_registry("descriptor")):
5151
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
5252
if cls is BD:
5353
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
54-
return super().__new__(cls)
54+
return object.__new__(cls)
5555

5656
@abstractmethod
5757
def get_rcut(self) -> float:

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ class DescrptSeA(NativeOP, BaseDescriptor):
149149
Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451.
150150
"""
151151

152+
_update_sel_cls = UpdateSel
153+
152154
def __init__(
153155
self,
154156
rcut: float,
@@ -582,7 +584,7 @@ def update_sel(
582584
The minimum distance between two atoms
583585
"""
584586
local_jdata_cpy = local_jdata.copy()
585-
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
587+
min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel(
586588
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
587589
)
588590
return local_jdata_cpy, min_nbor_dist

deepmd/dpmodel/descriptor/se_r.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ class DescrptSeR(NativeOP, BaseDescriptor):
128128
Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451.
129129
"""
130130

131+
_update_sel_cls = UpdateSel
132+
131133
def __init__(
132134
self,
133135
rcut: float,
@@ -505,7 +507,7 @@ def update_sel(
505507
The minimum distance between two atoms
506508
"""
507509
local_jdata_cpy = local_jdata.copy()
508-
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
510+
min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel(
509511
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
510512
)
511513
return local_jdata_cpy, min_nbor_dist

deepmd/dpmodel/descriptor/se_t.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ class DescrptSeT(NativeOP, BaseDescriptor):
116116
Not used in this descriptor, only to be compat with input.
117117
"""
118118

119+
_update_sel_cls = UpdateSel
120+
119121
def __init__(
120122
self,
121123
rcut: float,
@@ -505,7 +507,7 @@ def update_sel(
505507
The minimum distance between two atoms
506508
"""
507509
local_jdata_cpy = local_jdata.copy()
508-
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
510+
min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel(
509511
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
510512
)
511513
return local_jdata_cpy, min_nbor_dist

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ class DescrptSeTTebd(NativeOP, BaseDescriptor):
141141
142142
"""
143143

144+
_update_sel_cls = UpdateSel
145+
144146
def __init__(
145147
self,
146148
rcut: float,
@@ -500,7 +502,7 @@ def update_sel(
500502
The minimum distance between two atoms
501503
"""
502504
local_jdata_cpy = local_jdata.copy()
503-
min_nbor_dist, sel = UpdateSel().update_one_sel(
505+
min_nbor_dist, sel = cls._update_sel_cls().update_one_sel(
504506
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
505507
)
506508
local_jdata_cpy["sel"] = sel[0]

0 commit comments

Comments
 (0)