Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
512eeb6
feat(pt_expt): multi-task training support
Apr 15, 2026
9f4d232
fix(dpmodel): wrap fparam/aparam reshape with descriptive ValueError
Apr 16, 2026
9f1f1d8
fix: address CodeQL findings in PR #5397
Apr 16, 2026
f3f5474
fix(pt_expt): access unwrapped module in _compile_model for DDP compat
Apr 16, 2026
665b85a
test(pt_expt): add DDP + torch.compile training tests
Apr 16, 2026
aabb710
feat(pt_expt): use inductor+dynamic for training compile
Apr 16, 2026
f774cd2
test(pt_expt): port silut activation + repformers accessors from #5393
Apr 16, 2026
0b5468e
test(pt_expt): assert virial in compile correctness tests
Apr 16, 2026
9bf006b
test(pt_expt): port silu compile and varying-natoms tests from #5393
Apr 16, 2026
7722f52
test(pt_expt): compare compiled vs uncompiled with varying natoms
Apr 16, 2026
be14ac2
test(pt_expt): cover DPA2/DPA3 in varying-natoms compile correctness
Apr 16, 2026
4c0b8ec
test(pt_expt): exercise DPA2 three-body branch in compile correctness
Apr 16, 2026
80c714c
fix(dpmodel): restore nf in reshapes to fix zero-atom and add silu_ba…
Apr 17, 2026
6158d9c
fix: address CodeQL findings in PR #5397
Apr 17, 2026
c2efbf1
fix(pt): wrap fparam/aparam reshape with descriptive ValueError
Apr 17, 2026
1e694a3
feat(pt_expt): reject DPA1/se_atten_v2 with attention at compile time
Apr 18, 2026
6d39ddf
fix(pt_expt): remove false DPA1 attention compile guard
Apr 18, 2026
23eb6dd
refactor(dpmodel): remove unused get_numb_attn_layer API
Apr 18, 2026
bacd312
fix(test): use real path for PT water data, remove unused API
Apr 18, 2026
f834202
fix(pt_expt): rebuild FX graph after detach node removal to avoid seg…
Apr 18, 2026
447a572
fix(pt_expt): tune inductor options for compile training
Apr 18, 2026
fb25ccb
fix(pt_expt): disable DDPOptimizer to prevent compiled graph splitting
Apr 18, 2026
479900d
fix(test): add .cpu() before .numpy() for GPU-compatible activation t…
Apr 18, 2026
b67a181
fix(pt_expt): revert inductor options that cause numerical divergence
Apr 18, 2026
7ce7352
fix(test): make DDP tests device-adaptive instead of hardcoding CPU
Apr 18, 2026
975db17
fix(test): correct freeze test docstrings to match dpa3 guard
Apr 18, 2026
64dc703
fix(pt_expt): move optimize_ddp into _compile_model, resolve test sym…
Apr 18, 2026
28fbcac
fix(test): backup/restore fparam.npy in TestFparam instead of deleting
Apr 18, 2026
fbb361a
fix(test): skip DDP tests when NCCL is selected with fewer than 2 GPUs
Apr 18, 2026
7739fad
perf(pt2): optimize .pt2 C++ inference path
Apr 20, 2026
19272c2
Merge upstream/master into perf-pt-expt-pt2-cpp
Apr 20, 2026
b7509db
feat(pt2): make nlist nnei dimension dynamic in .pt2 export
Apr 20, 2026
eec2528
fix(pt2): pad nlist in Python eval path for dynamic nnei
Apr 20, 2026
217a587
fix(pt2): move atomic virial check before run_model and reject unsupp…
Apr 20, 2026
8a9fe63
fix(pt2): move nlist padding inside traced fn and strip shape assertions
Apr 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions deepmd/entrypoints/convert_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def convert_backend(
*, # Enforce keyword-only arguments
INPUT: str,
OUTPUT: str,
atomic_virial: bool = False,
**kwargs: Any,
) -> None:
"""Convert a model file from one backend to another.
Expand All @@ -20,12 +21,26 @@ def convert_backend(
----------
INPUT : str
The input model file.
INPUT : str
OUTPUT : str
The output model file.
atomic_virial : bool
If True, export .pt2/.pte models with per-atom virial correction.
This adds ~2.5x inference cost. Default False.
"""
inp_backend: Backend = Backend.detect_backend_by_model(INPUT)()
out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)()
inp_hook = inp_backend.serialize_hook
out_hook = out_backend.deserialize_hook
data = inp_hook(INPUT)
out_hook(OUTPUT, data)
# Forward atomic_virial to pt_expt deserialize_to_file if applicable
import inspect

sig = inspect.signature(out_hook)
if "do_atomic_virial" in sig.parameters:
out_hook(OUTPUT, data, do_atomic_virial=atomic_virial)
else:
if atomic_virial:
raise ValueError(
"--atomic-virial is only supported for pt_expt .pt2/.pte outputs"
)
out_hook(OUTPUT, data)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
8 changes: 8 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,14 @@ def main_parser() -> argparse.ArgumentParser:
)
parser_convert_backend.add_argument("INPUT", help="The input model file.")
parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
parser_convert_backend.add_argument(
"--atomic-virial",
action="store_true",
default=False,
help="Export .pt2/.pte models with per-atom virial correction. "
"This adds ~2.5x inference cost but is required for "
"LAMMPS compute/atom virial output.",
)

# * show model ******************************************************************
parser_show = subparsers.add_parser(
Expand Down
37 changes: 29 additions & 8 deletions deepmd/pt_expt/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,21 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
# Pad nlist with one extra -1 column inside the traced function.
# This ensures n_nnei > sum(sel), forcing the sort branch in
# _format_nlist. The padding becomes part of the compiled graph,
# so callers never need to pad externally.
nlist = torch.cat(
[
nlist,
-torch.ones(
(*nlist.shape[:2], 1),
dtype=nlist.dtype,
device=nlist.device,
),
],
dim=-1,
)
return model.forward_common_lower(
extended_coord,
extended_atype,
Expand All @@ -356,13 +371,19 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
# Force format_nlist to always use the sort branch during tracing.
model.need_sorted_nlist_for_lower = lambda: True
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
finally:
del model.need_sorted_nlist_for_lower
return traced

return CM
37 changes: 28 additions & 9 deletions deepmd/pt_expt/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
# Pad nlist inside traced function (see make_model.py for rationale).
nlist = torch.cat(
[
nlist,
-torch.ones(
(*nlist.shape[:2], 1),
dtype=nlist.dtype,
device=nlist.device,
),
],
dim=-1,
)
return model.forward_common_lower(
extended_coord,
extended_atype,
Expand All @@ -107,15 +119,22 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord,
extended_atype,
extended_spin,
nlist,
mapping,
fparam,
aparam,
)
# Force format_nlist to always use the sort branch during tracing.
backbone = model.backbone_model
backbone.need_sorted_nlist_for_lower = lambda: True
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord,
extended_atype,
extended_spin,
nlist,
mapping,
fparam,
aparam,
)
finally:
del backbone.need_sorted_nlist_for_lower
return traced

def forward_common_lower(
self, *args: Any, **kwargs: Any
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,6 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
n_attn,
task_key,
)

inp, _ = self.get_data(is_train=True, task_key=task_key)
coord = inp["coord"].detach()
atype = inp["atype"].detach()
Expand Down
Loading
Loading