Skip to content

Commit 27a18b6

Browse files
perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks (deepmodeling#5457)
make dataset embedding and energy bias as input not buffer for compile, this allows multitask training share compiled model thus resolve OOM and NCCL timeout issue. Since the empty_cache and del are removed, no GC complaints. Regression Test <img width="3600" height="2100" alt="lcurve" src="https://github.com/user-attachments/assets/c043bf6c-53bb-441f-ac98-0d021b68ec1b" /> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Multi-task training groups models by structure and caches/reuses compiled computation graphs, with per-task buffer handling to support shared fitting nets. * **Bug Fixes** * Checkpoint loading now skips extraneous per-task buffer entries so only original model parameters are restored. * Training aggregation coerces tensor-like loss/metric values to floats for accurate reporting. * **Tests** * Added regression test ensuring compiled and eager outputs match per task for shared-fitting, different-descriptor setups. <!-- review_stack_entry_start --> [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/deepmodeling/deepmd-kit/pull/5457?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1de6de9 commit 27a18b6

5 files changed

Lines changed: 578 additions & 64 deletions

File tree

deepmd/pt_expt/descriptor/dpa2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def share_params(
6969
"g1_shape_tranform"
7070
]
7171
self._modules["repformers"] = base_class._modules["repformers"]
72+
if "tebd_transform" in base_class._modules:
73+
self._modules["tebd_transform"] = base_class._modules["tebd_transform"]
7274
elif shared_level == 1:
7375
self._modules["type_embedding"] = base_class._modules["type_embedding"]
7476
else:

deepmd/pt_expt/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def share_params(
4040
if not resume:
4141
merge_env_stat(base_class.repflows, self.repflows, model_prob)
4242
self._modules["repflows"] = base_class._modules["repflows"]
43+
if self.add_chg_spin_ebd:
44+
for key in ("chg_embedding", "spin_embedding", "mix_cs_mlp"):
45+
if key in base_class._modules:
46+
self._modules[key] = base_class._modules[key]
4347
elif shared_level == 1:
4448
self._modules["type_embedding"] = base_class._modules["type_embedding"]
4549
else:

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,15 @@ def _load_pt(self, model_file: str, head: str | None = None) -> None:
368368
# eager inference). Drop the latter and unwrap the former.
369369
cleaned: dict[str, Any] = {}
370370
compiled_marker = ".compiled_forward_lower."
371+
# Per-task buffer copies registered on _CompiledModel (bias_atom_e,
372+
# case_embd) — real values live on the original model's fitting net.
373+
task_buf_marker = "._task_"
371374
wrapper_infix = ".original_model."
372375
for key, value in state_dict.items():
373376
if compiled_marker in key:
374377
continue
378+
if task_buf_marker in key:
379+
continue
375380
if wrapper_infix in key:
376381
key = key.replace(wrapper_infix, ".", 1)
377382
cleaned[key] = value

0 commit comments

Comments
 (0)