perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457
perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457anyangml wants to merge 7 commits into
Conversation
for more information, see https://pre-commit.ci
📝 WalkthroughWalkthroughThis PR introduces task-structure-aware torch.compile caching for multi-task models. It extracts per-task fitting-net buffers, computes a shared-structure identity key, promotes those buffers into explicit FX symbolic inputs for graph reuse, updates checkpoint loading to skip task-buffer remnants, and converts training loss values to floats for display aggregation. ChangesTask-structure-aware torch.compile optimization
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This PR adjusts the pt_expt torch.compile path for multi-task training to reduce redundant compiled graphs (and associated memory/oom issues) by promoting per-task fitting-net buffers to explicit compiled-graph inputs and reusing compiled graphs across tasks when the model structure is shared.
Changes:
- Promote task-specific fitting-net buffers (
bias_atom_e,case_embd) into FX placeholders so one compiled graph can be reused with different per-task buffer values. - Add per-structure caching in the compile pipeline to avoid recompiling the same shared structure for each task.
- Make training-time logging robust by converting tensor scalars to Python floats before formatting/aggregation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
deepmd/pt_expt/train/training.py |
Adds task-buffer promotion + compiled-graph reuse caching for multi-task compile; adjusts logging scalar handling. |
deepmd/pt_expt/infer/deep_eval.py |
Updates .pt checkpoint loading to ignore newly introduced _CompiledModel per-task buffer copies. |
Comments suppressed due to low confidence (1)
deepmd/pt_expt/train/training.py:1072
- There are existing pt_expt tests covering multi-task + torch.compile, but the new compiled-graph reuse path should be covered by a test that exercises a config where only some components are shared (e.g., fitting_net shared via shared_dict, descriptor not shared). That case would validate the structure-key logic and prevent accidental graph reuse across non-identical forward_lower graphs.
descriptor = model.get_descriptor()
if isinstance(descriptor, DescrptDPA1DP):
n_attn = descriptor.get_numb_attn_layer()
if n_attn > 0:
log.warning(
"Compiling DPA1/se_atten_v2 with %d attention "
"layer(s) (task=%s): the compiled forces/grads "
"are slightly hardware-sensitive (multi-thread "
"reduction order), and may not match the eager "
"path bit-for-bit. Use 'enable_compile: false' "
"or 'attn_layer: 0' for fully reproducible runs.",
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
🧹 Nitpick comments (3)
deepmd/pt_expt/train/training.py (3)
319-319: ⚡ Quick winAdd explicit
strict=Trueto zip call.The
zip()on line 319 iterates overtask_buf_orderandtask_buf_vals, which are guaranteed to have the same length by construction (lines 290-293). Addingstrict=Truedocuments this invariant and provides a runtime assertion if the construction logic ever changes.- for name, val in zip(task_buf_order, task_buf_vals): + for name, val in zip(task_buf_order, task_buf_vals, strict=True):As per coding guidelines, run
ruff check .before committing to catch linting issues.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` at line 319, The zip over task_buf_order and task_buf_vals in the loop using "for name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length invariant by adding strict=True; update that zip call to zip(task_buf_order, task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then run ruff check . before committing to ensure linting passes.
314-334: ⚡ Quick winPotential issue with buffer restoration logic.
Lines 320 and 334 save and restore buffer entries, but if
originals[name]isNone(buffer didn't exist), line 334 sets_fitting._buffers[name] = Noneinstead of deleting the entry. This could leaveNoneentries in the buffer registry that weren't present before patching.Consider using conditional restoration:
for name, orig in originals.items(): if orig is not None: _fitting._buffers[name] = orig else: _fitting._buffers.pop(name, None)However, if the buffers are guaranteed to exist (since
_get_task_buffersonly extracts existing buffers), this may not be an issue in practice.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 314 - 334, The restoration currently writes None back into _fitting._buffers for entries that did not exist before, so change the finally-block that iterates originals (the dict populated from task_buf_order/task_buf_vals) to restore by reassigning when orig is not None and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the dictionary named originals and the finally block that resets _fitting._buffers and replace the unconditional assignment with a conditional restore/remove to avoid leaving None entries after model.forward_lower returns.
92-108: ⚡ Quick winClarify the child name check logic.
Line 103 compares child module names against
_TASK_SPECIFIC_BUFFER_NAMES, which contains buffer names ("bias_atom_e","case_embd"). Child modules fromnamed_children()typically have names like"nets","layers", etc., not buffer names. This check will almost always beTrue, making it effectively a no-op.If the intent is to skip the first child when it's task-specific, the logic may need adjustment. Otherwise, consider removing the check or adding a comment explaining why it's safe to use the first child's
id()directly.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 92 - 108, The code in _get_model_structure_key uses named_children() names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names like "bias_atom_e"), but child module names come from named_children() and won't match buffer names, so the filter is effectively a no-op; fix by computing the set of task-specific buffer names from the fitting net (e.g., buffers = {n for n,_ in fitting.named_buffers()}) and then skip any child whose name appears in that buffer set (replace the current name check), or if the original intent was to just take the first non-task-specific child drop the faulty comparison and simply return id of the first child from fitting.named_children(); update _get_model_structure_key accordingly and keep reference to fitting, named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 319: The zip over task_buf_order and task_buf_vals in the loop using "for
name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length
invariant by adding strict=True; update that zip call to zip(task_buf_order,
task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then
run ruff check . before committing to ensure linting passes.
- Around line 314-334: The restoration currently writes None back into
_fitting._buffers for entries that did not exist before, so change the
finally-block that iterates originals (the dict populated from
task_buf_order/task_buf_vals) to restore by reassigning when orig is not None
and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.
- Around line 92-108: The code in _get_model_structure_key uses named_children()
names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names
like "bias_atom_e"), but child module names come from named_children() and won't
match buffer names, so the filter is effectively a no-op; fix by computing the
set of task-specific buffer names from the fitting net (e.g., buffers = {n for
n,_ in fitting.named_buffers()}) and then skip any child whose name appears in
that buffer set (replace the current name check), or if the original intent was
to just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bac28a64-1300-4bb0-9d04-cf61734721ce
📒 Files selected for processing (2)
deepmd/pt_expt/infer/deep_eval.pydeepmd/pt_expt/train/training.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5457 +/- ##
==========================================
- Coverage 82.46% 82.44% -0.02%
==========================================
Files 829 829
Lines 88763 88876 +113
Branches 4225 4225
==========================================
+ Hits 73197 73278 +81
- Misses 14274 14307 +33
+ Partials 1292 1291 -1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 491-500: If self._task_buf_order is set but accessing buffers
fails, don't silently set task_buf_vals = (); instead, in the except
AttributeError branch for original_model.get_fitting_net()/getattr(...) raise a
clear RuntimeError (or ValueError) that mentions the missing fitting net or
buffer names and refers to self._task_buf_order so callers know why
compiled_forward_lower would fail; keep the existing else path that sets
task_buf_vals = () only when _task_buf_order is empty, and ensure the raised
message names the expected buffer attributes and the method
original_model.get_fitting_net to help locate the root cause.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: add25e51-694f-4a58-8e05-9d3e0d5909e9
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
| if self._task_buf_order: | ||
| try: | ||
| _fitting = self.original_model.get_fitting_net() | ||
| task_buf_vals: tuple = tuple( | ||
| getattr(_fitting, name) for name in self._task_buf_order | ||
| ) | ||
| except AttributeError: | ||
| task_buf_vals = () | ||
| else: | ||
| task_buf_vals = () |
There was a problem hiding this comment.
Silent fallback may cause confusing downstream errors.
If _task_buf_order is non-empty (compiled graph expects buffer arguments) but the AttributeError catch triggers, the empty task_buf_vals passed to compiled_forward_lower will cause an argument count mismatch—a confusing error that hides the real issue.
Consider raising an informative error when buffers are expected but unavailable:
Proposed fix
if self._task_buf_order:
try:
_fitting = self.original_model.get_fitting_net()
task_buf_vals: tuple = tuple(
getattr(_fitting, name) for name in self._task_buf_order
)
except AttributeError:
- task_buf_vals = ()
+ raise RuntimeError(
+ f"Compiled graph expects task buffers {self._task_buf_order} "
+ "but fitting net or buffer attributes are unavailable."
+ )
else:
task_buf_vals = ()🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt_expt/train/training.py` around lines 491 - 500, If
self._task_buf_order is set but accessing buffers fails, don't silently set
task_buf_vals = (); instead, in the except AttributeError branch for
original_model.get_fitting_net()/getattr(...) raise a clear RuntimeError (or
ValueError) that mentions the missing fitting net or buffer names and refers to
self._task_buf_order so callers know why compiled_forward_lower would fail; keep
the existing else path that sets task_buf_vals = () only when _task_buf_order is
empty, and ensure the raised message names the expected buffer attributes and
the method original_model.get_fitting_net to help locate the root cause.
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 343-345: The zip call iterating over task_buf_order and
task_buf_vals should be made strict to satisfy Ruff B905: change
zip(task_buf_order, task_buf_vals) to zip(task_buf_order, task_buf_vals,
strict=True) in the block that checks name.startswith(_AM_PREFIX). Also remove
or rename the unused variable model_pred (found as model_pred in this file) to
_model_pred (or delete it) to resolve RUF059 so no unused binding remains.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 62caa6e7-1434-49b6-8f30-bfab6dd88904
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
| if task_buf_order: | ||
| for name, val in zip(task_buf_order, task_buf_vals): | ||
| if name.startswith(_AM_PREFIX): |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
ruff check .Repository: deepmodeling/deepmd-kit
Length of output: 50381
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "## ruff check (target file only)"
ruff check deepmd/pt_expt/train/training.py
echo "## ruff format --check (target file only)"
ruff format --check deepmd/pt_expt/train/training.pyRepository: deepmodeling/deepmd-kit
Length of output: 1469
🏁 Script executed:
ruff check .Repository: deepmodeling/deepmd-kit
Length of output: 50381
🏁 Script executed:
ruff check .
ruff format .Repository: deepmodeling/deepmd-kit
Length of output: 50381
Fix Ruff B905 (zip strict) and RUF059 (unused model_pred) in deepmd/pt_expt/train/training.py.
- B905 (line 344):
zip(task_buf_order, task_buf_vals)needs explicitstrict=to keep both sequences aligned. - RUF059 (line 1354):
model_predis never used—remove it or rename to_model_pred.
Suggested fix
- for name, val in zip(task_buf_order, task_buf_vals):
+ for name, val in zip(task_buf_order, task_buf_vals, strict=True):- model_pred, loss, more_loss = self.wrapper(
+ _model_pred, loss, more_loss = self.wrapper(🧰 Tools
🪛 Ruff (0.15.14)
[warning] 344-344: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt_expt/train/training.py` around lines 343 - 345, The zip call
iterating over task_buf_order and task_buf_vals should be made strict to satisfy
Ruff B905: change zip(task_buf_order, task_buf_vals) to zip(task_buf_order,
task_buf_vals, strict=True) in the block that checks
name.startswith(_AM_PREFIX). Also remove or rename the unused variable
model_pred (found as model_pred in this file) to _model_pred (or delete it) to
resolve RUF059 so no unused binding remains.
njzjz-bot
left a comment
There was a problem hiding this comment.
I think there is still a correctness issue in the compiled graph cache key for multitask models.
_get_model_structure_key() currently uses only the id of the first non-task-specific child under fitting_net as the structure key. That can incorrectly treat two task models as sharing the same compiled forward_lower graph when they share only the fitting net but have different/non-shared descriptors or other atomic-model state.
forward_lower is not just the fitting net path: it also includes descriptor/atomic-model computation. When the compiled graph traced for task/model 1 is reused for task/model 2, _CompiledModel.forward() only supplies the extra task buffers (bias, case_embd, out_bias, out_std, etc.). It does not pass descriptor or other atomic-model parameters as dynamic inputs. So if the descriptor is not actually shared, or has different parameters/outputs, task/model 2 may run the graph captured from task/model 1, producing incorrect predictions and gradients.
I would suggest making the cache key include the identity/structure of all shared components that participate in forward_lower — at least the descriptor/atomic-model path in addition to fitting — or, more conservatively, only reusing a compiled graph when both descriptor and fitting components are confirmed to be the same shared objects. Otherwise each task should compile separately.
A regression test would be helpful: construct two tasks that share fitting_net but use distinguishable non-shared descriptors, enable compile, and verify both tasks match the uncompiled results/gradients.
Minor follow-up: _CompiledModel.forward() currently swallows AttributeError while collecting task buffers and then passes an empty tuple, which will likely fail later with a less clear compiled-argument mismatch. It would be better to raise a direct RuntimeError with the missing task-buffer name/context.
Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)
Is that a valid use case? |
|
Yes, I think it is a valid use case unless the model configuration or My concern is not limited to the current test setup. The compiled callable is So I would either:
If partial sharing of fitting without descriptor is impossible by construction, then an assertion/check documenting that invariant would also resolve this concern. Otherwise I think compiling separately is safer than silently reusing a graph across non-equivalent task models. Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5) |
wanghan-iapcm
left a comment
There was a problem hiding this comment.
Code review
Found 2 issues:
- Cache key under-specifies what is baked into the compiled graph.
_get_model_structure_keyreturnsid()of the fitting net's first non-task-specific child, but onlybias_atom_e/case_embd/out_bias/out_stdare promoted to FX placeholders — descriptor parameters/buffers (attention weights, type-embeddingdavg/dstd, exclude-mask, etc.) remain baked-in constants in the traced graph. Two tasks that sharefitting_netbut have different descriptors (or differ inntypes/dim_case_embd/sel/rcut) produce the same structure key and silently reuse task 0's compiled graph, yielding wrong predictions and gradients. Same concern previously raised by njzjz-bot and Copilot; unresolved on the current head.
deepmd-kit/deepmd/pt_expt/train/training.py
Lines 109 to 130 in 9ce8d3e
deepmd-kit/deepmd/pt_expt/train/training.py
Lines 1126 to 1188 in 9ce8d3e
Suggested fix: include the descriptor (and any other non-fitting components participating in forward_lower) in the key, e.g. (id(model.get_descriptor()), id(fitting_first_child)), or build the key from tuple(id(p) for _, p in model.named_parameters()) + tuple(id(b) for n, b in model.named_buffers() if n.rsplit(".", 1)[-1] not in _TASK_SPECIFIC_BUFFER_NAMES + _ATOMIC_MODEL_TASK_BUFFER_NAMES). A regression test where two tasks share fitting but have distinguishable descriptors and assert compiled outputs match eager would catch this.
- The promoted-buffer set is incomplete. Only
bias_atom_e,case_embd,out_bias,out_stdare promoted to FX placeholders, but fitting nets carry other per-task statistics buffers that are silently baked in as task-0 constants — notablyfparam_avg/fparam_inv_stdandaparam_avg/aparam_inv_std(set per task bymake_stat_inputfrom each task's data distribution), and the descriptor'sEnvMatdavg/dstdif descriptor stats also vary per task. For multi-task configs that usefparam/aparamwith shared fitting but task-local stats, all tasks would end up running with task 0's normalization, producing incorrect inputs to the fitting MLP.
deepmd-kit/deepmd/pt_expt/train/training.py
Lines 72 to 86 in 9ce8d3e
Suggested fix: either (a) expand _TASK_SPECIFIC_BUFFER_NAMES to include fparam_avg, fparam_inv_std, aparam_avg, aparam_inv_std and have _get_task_buffers enumerate them, or (b) auto-detect per-task buffers by diffing buffer identities across tasks in _compile_model rather than maintaining a hardcoded allow-list that will drift as descriptors and fittings evolve.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 1145: The for-loop uses an unused loop variable `sk` causing Ruff B007;
change the iteration to loop over the group values directly by replacing `for
sk, group_keys in _groups.items():` with iterating `_groups.values()` (i.e. `for
group_keys in _groups.values():`) so `sk` is removed and the loop body continues
to use `group_keys` unchanged.
In `@source/tests/pt_expt/test_training.py`:
- Around line 1301-1527: The test
TestCompiledSharedFittingDifferentDescriptor.test_compiled_matches_eager_per_task
is unbounded and needs the repo-standard 60s timeout; add a 60s test timeout
(e.g. annotate the test method or the test class with pytest.mark.timeout(60)
and import pytest) so the compiled/eager training regression cannot hang CI
longer than allowed, ensuring the decorator is applied to
test_compiled_matches_eager_per_task (or the containing
TestCompiledSharedFittingDifferentDescriptor class).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 33c2d228-ada9-4027-b0f6-2a1f979e9ce9
📒 Files selected for processing (2)
deepmd/pt_expt/train/training.pysource/tests/pt_expt/test_training.py
njzjz-bot
left a comment
There was a problem hiding this comment.
Re-reviewed the current head (8c25a2289). The main correctness concern from the previous review looks addressed:
_get_model_structure_key()now includes the descriptor identity, so shared fitting with different descriptors no longer collides.- fitting-net task buffers are auto-detected by buffer identity within each reusable structure group, and
out_bias/out_stdare promoted fromatomic_model. - the new regression test covers the shared-fitting / different-descriptor case.
- required checks are green.
I do not see a remaining correctness blocker from the original cache-key issue. The only things I would still clean up before merge are non-blocking maintenance items already pointed out by CodeRabbit: the unused sk loop variable, adding the timeout mark to the new compile regression test, and making the missing promoted-buffer path in _CompiledModel.forward() fail with a direct RuntimeError rather than falling through to an argument mismatch.
— OpenClaw 2026.5.12 (f066dd2) (model: custom-chat-jinzhezeng-group/gpt-5.5)
make dataset embedding and energy bias as input not buffer for compile, this allows multitask training share compiled model thus resolve OOM and NCCL timeout issue. Since the empty_cache and del are removed, no GC complaints.
Regression Test

Summary by CodeRabbit
New Features
Bug Fixes
Tests