Skip to content

perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457

Open
anyangml wants to merge 7 commits into
deepmodeling:masterfrom
anyangml:fix/compile-multitask
Open

perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457
anyangml wants to merge 7 commits into
deepmodeling:masterfrom
anyangml:fix/compile-multitask

Conversation

@anyangml
Copy link
Copy Markdown
Collaborator

@anyangml anyangml commented May 26, 2026

make dataset embedding and energy bias as input not buffer for compile, this allows multitask training share compiled model thus resolve OOM and NCCL timeout issue. Since the empty_cache and del are removed, no GC complaints.

Regression Test
lcurve

Summary by CodeRabbit

  • New Features

    • Multi-task training now groups models by structure and caches/reuses compiled computation graphs, with per-task buffer handling to support shared fitting nets.
  • Bug Fixes

    • Checkpoint loading now skips extraneous per-task buffer entries so only original model parameters are restored.
    • Training aggregation coerces tensor-like loss/metric values to floats for accurate reporting.
  • Tests

    • Added a regression test ensuring compiled vs. eager outputs match per task for shared-fitting, different-descriptor setups.

Review Change Stack

Copilot AI review requested due to automatic review settings May 26, 2026 02:03
@dosubot dosubot Bot added the bug label May 26, 2026
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 26, 2026

📝 Walkthrough

Walkthrough

This PR introduces task-structure-aware torch.compile caching for multi-task models. It extracts per-task fitting-net buffers, computes a shared-structure identity key, promotes those buffers into explicit FX symbolic inputs for graph reuse, updates checkpoint loading to skip task-buffer remnants, and converts training loss values to floats for display aggregation.

Changes

Task-structure-aware torch.compile optimization

Layer / File(s) Summary
Task buffer detection & structure key
deepmd/pt_expt/train/training.py
Adds buffer-name constants, _detect_task_buffers, and updates _get_model_structure_key to return a tuple capturing descriptor identity and fitting-net child identity.
_trace_and_compile buffer promotion
deepmd/pt_expt/train/training.py
Extends _trace_and_compile to accept task_buffers, temporarily patch fitting-net/atomic-model _buffers so promoted buffers become FX placeholders, include them as extra symbolic inputs to make_fx, and return (compiled_module, task_buf_order).
_CompiledModel runtime ordering & forward
deepmd/pt_expt/train/training.py
_CompiledModel.__init__ accepts task_buf_order; forward() fetches current-task buffer tensors (atomic-model am/ prefixed and fitting-net buffers) and passes them as variadic args into the compiled forward_lower.
Structure-key compilation cache in _compile_model
deepmd/pt_expt/train/training.py
Pre-pass groups tasks by structure_key, detects per-task promoted buffers per group, reuses cached (compiled_lower, task_buf_order) when available or compiles per-structure and caches; constructs _CompiledModel(..., task_buf_order).
Training aggregation float coercion
deepmd/pt_expt/train/training.py
Adds _to_float() and coerces tensor-like loss/metric entries to Python floats during single-task and multi-task train/validation aggregation (excluding l2_ metrics).
Skip task buffer entries during checkpoint cleanup
deepmd/pt_expt/infer/deep_eval.py
DeepEval._load_pt now filters out ._task_ marked keys (in addition to compiled-forward-lower keys) when preparing a checkpoint state_dict for loading.
Regression test: shared fitting, different descriptor
source/tests/pt_expt/test_training.py
Adds TestCompiledSharedFittingDifferentDescriptor to exercise multi-task compilation reuse, structure-key differences, eager→compiled weight sync, and compiled-vs-eager numeric equivalence per task.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5423: Both PRs modify deepmd/pt_expt/infer/deep_eval.py—specifically DeepEval._load_pt's state-dict cleanup/filtered-key loading logic—so the main PR's new "._task_" key omission is directly related.
  • deepmodeling/deepmd-kit#5397: Related multi-task torch.compile/FX compilation work that this PR extends with per-task buffer promotion and structure-key reuse.

Suggested labels

enhancement

Suggested reviewers

  • njzjz
  • wanghan-iapcm
  • OutisLi
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 57.14% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly describes the main change: sharing compiled forward_lower across multi-task shared-fitting tasks for performance optimization.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Comment thread deepmd/pt_expt/train/training.py Fixed
Comment thread deepmd/pt_expt/train/training.py Fixed
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adjusts the pt_expt torch.compile path for multi-task training to reduce redundant compiled graphs (and associated memory/oom issues) by promoting per-task fitting-net buffers to explicit compiled-graph inputs and reusing compiled graphs across tasks when the model structure is shared.

Changes:

  • Promote task-specific fitting-net buffers (bias_atom_e, case_embd) into FX placeholders so one compiled graph can be reused with different per-task buffer values.
  • Add per-structure caching in the compile pipeline to avoid recompiling the same shared structure for each task.
  • Make training-time logging robust by converting tensor scalars to Python floats before formatting/aggregation.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
deepmd/pt_expt/train/training.py Adds task-buffer promotion + compiled-graph reuse caching for multi-task compile; adjusts logging scalar handling.
deepmd/pt_expt/infer/deep_eval.py Updates .pt checkpoint loading to ignore newly introduced _CompiledModel per-task buffer copies.
Comments suppressed due to low confidence (1)

deepmd/pt_expt/train/training.py:1072

  • There are existing pt_expt tests covering multi-task + torch.compile, but the new compiled-graph reuse path should be covered by a test that exercises a config where only some components are shared (e.g., fitting_net shared via shared_dict, descriptor not shared). That case would validate the structure-key logic and prevent accidental graph reuse across non-identical forward_lower graphs.
            descriptor = model.get_descriptor()
            if isinstance(descriptor, DescrptDPA1DP):
                n_attn = descriptor.get_numb_attn_layer()
                if n_attn > 0:
                    log.warning(
                        "Compiling DPA1/se_atten_v2 with %d attention "
                        "layer(s) (task=%s): the compiled forces/grads "
                        "are slightly hardware-sensitive (multi-thread "
                        "reduction order), and may not match the eager "
                        "path bit-for-bit.  Use 'enable_compile: false' "
                        "or 'attn_layer: 0' for fully reproducible runs.",

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread deepmd/pt_expt/train/training.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
deepmd/pt_expt/train/training.py (3)

319-319: ⚡ Quick win

Add explicit strict=True to zip call.

The zip() on line 319 iterates over task_buf_order and task_buf_vals, which are guaranteed to have the same length by construction (lines 290-293). Adding strict=True documents this invariant and provides a runtime assertion if the construction logic ever changes.

-            for name, val in zip(task_buf_order, task_buf_vals):
+            for name, val in zip(task_buf_order, task_buf_vals, strict=True):

As per coding guidelines, run ruff check . before committing to catch linting issues.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` at line 319, The zip over task_buf_order
and task_buf_vals in the loop using "for name, val in zip(task_buf_order,
task_buf_vals):" should assert the equal-length invariant by adding strict=True;
update that zip call to zip(task_buf_order, task_buf_vals, strict=True) so a
runtime error surfaces if lengths diverge, then run ruff check . before
committing to ensure linting passes.

314-334: ⚡ Quick win

Potential issue with buffer restoration logic.

Lines 320 and 334 save and restore buffer entries, but if originals[name] is None (buffer didn't exist), line 334 sets _fitting._buffers[name] = None instead of deleting the entry. This could leave None entries in the buffer registry that weren't present before patching.

Consider using conditional restoration:

for name, orig in originals.items():
    if orig is not None:
        _fitting._buffers[name] = orig
    else:
        _fitting._buffers.pop(name, None)

However, if the buffers are guaranteed to exist (since _get_task_buffers only extracts existing buffers), this may not be an issue in practice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 314 - 334, The restoration
currently writes None back into _fitting._buffers for entries that did not exist
before, so change the finally-block that iterates originals (the dict populated
from task_buf_order/task_buf_vals) to restore by reassigning when orig is not
None and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.

92-108: ⚡ Quick win

Clarify the child name check logic.

Line 103 compares child module names against _TASK_SPECIFIC_BUFFER_NAMES, which contains buffer names ("bias_atom_e", "case_embd"). Child modules from named_children() typically have names like "nets", "layers", etc., not buffer names. This check will almost always be True, making it effectively a no-op.

If the intent is to skip the first child when it's task-specific, the logic may need adjustment. Otherwise, consider removing the check or adding a comment explaining why it's safe to use the first child's id() directly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 92 - 108, The code in
_get_model_structure_key uses named_children() names compared against
_TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names like "bias_atom_e"), but
child module names come from named_children() and won't match buffer names, so
the filter is effectively a no-op; fix by computing the set of task-specific
buffer names from the fitting net (e.g., buffers = {n for n,_ in
fitting.named_buffers()}) and then skip any child whose name appears in that
buffer set (replace the current name check), or if the original intent was to
just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 319: The zip over task_buf_order and task_buf_vals in the loop using "for
name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length
invariant by adding strict=True; update that zip call to zip(task_buf_order,
task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then
run ruff check . before committing to ensure linting passes.
- Around line 314-334: The restoration currently writes None back into
_fitting._buffers for entries that did not exist before, so change the
finally-block that iterates originals (the dict populated from
task_buf_order/task_buf_vals) to restore by reassigning when orig is not None
and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.
- Around line 92-108: The code in _get_model_structure_key uses named_children()
names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names
like "bias_atom_e"), but child module names come from named_children() and won't
match buffer names, so the filter is effectively a no-op; fix by computing the
set of task-specific buffer names from the fitting net (e.g., buffers = {n for
n,_ in fitting.named_buffers()}) and then skip any child whose name appears in
that buffer set (replace the current name check), or if the original intent was
to just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bac28a64-1300-4bb0-9d04-cf61734721ce

📥 Commits

Reviewing files that changed from the base of the PR and between f39a081 and 4cee0bf.

📒 Files selected for processing (2)
  • deepmd/pt_expt/infer/deep_eval.py
  • deepmd/pt_expt/train/training.py

@codecov
Copy link
Copy Markdown

codecov Bot commented May 26, 2026

Codecov Report

❌ Patch coverage is 75.18248% with 34 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.44%. Comparing base (f39a081) to head (8c25a22).
⚠️ Report is 2 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/pt_expt/train/training.py 75.37% 33 Missing ⚠️
deepmd/pt_expt/infer/deep_eval.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5457      +/-   ##
==========================================
- Coverage   82.46%   82.44%   -0.02%     
==========================================
  Files         829      829              
  Lines       88763    88876     +113     
  Branches     4225     4225              
==========================================
+ Hits        73197    73278      +81     
- Misses      14274    14307      +33     
+ Partials     1292     1291       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 491-500: If self._task_buf_order is set but accessing buffers
fails, don't silently set task_buf_vals = (); instead, in the except
AttributeError branch for original_model.get_fitting_net()/getattr(...) raise a
clear RuntimeError (or ValueError) that mentions the missing fitting net or
buffer names and refers to self._task_buf_order so callers know why
compiled_forward_lower would fail; keep the existing else path that sets
task_buf_vals = () only when _task_buf_order is empty, and ensure the raised
message names the expected buffer attributes and the method
original_model.get_fitting_net to help locate the root cause.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: add25e51-694f-4a58-8e05-9d3e0d5909e9

📥 Commits

Reviewing files that changed from the base of the PR and between 4cee0bf and f3e29fe.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

Comment on lines +491 to +500
if self._task_buf_order:
try:
_fitting = self.original_model.get_fitting_net()
task_buf_vals: tuple = tuple(
getattr(_fitting, name) for name in self._task_buf_order
)
except AttributeError:
task_buf_vals = ()
else:
task_buf_vals = ()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Silent fallback may cause confusing downstream errors.

If _task_buf_order is non-empty (compiled graph expects buffer arguments) but the AttributeError catch triggers, the empty task_buf_vals passed to compiled_forward_lower will cause an argument count mismatch—a confusing error that hides the real issue.

Consider raising an informative error when buffers are expected but unavailable:

Proposed fix
         if self._task_buf_order:
             try:
                 _fitting = self.original_model.get_fitting_net()
                 task_buf_vals: tuple = tuple(
                     getattr(_fitting, name) for name in self._task_buf_order
                 )
             except AttributeError:
-                task_buf_vals = ()
+                raise RuntimeError(
+                    f"Compiled graph expects task buffers {self._task_buf_order} "
+                    "but fitting net or buffer attributes are unavailable."
+                )
         else:
             task_buf_vals = ()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 491 - 500, If
self._task_buf_order is set but accessing buffers fails, don't silently set
task_buf_vals = (); instead, in the except AttributeError branch for
original_model.get_fitting_net()/getattr(...) raise a clear RuntimeError (or
ValueError) that mentions the missing fitting net or buffer names and refers to
self._task_buf_order so callers know why compiled_forward_lower would fail; keep
the existing else path that sets task_buf_vals = () only when _task_buf_order is
empty, and ensure the raised message names the expected buffer attributes and
the method original_model.get_fitting_net to help locate the root cause.

Comment thread deepmd/pt_expt/train/training.py Fixed
Comment thread deepmd/pt_expt/train/training.py Fixed
@anyangml anyangml requested review from njzjz and wanghan-iapcm May 26, 2026 09:27
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 343-345: The zip call iterating over task_buf_order and
task_buf_vals should be made strict to satisfy Ruff B905: change
zip(task_buf_order, task_buf_vals) to zip(task_buf_order, task_buf_vals,
strict=True) in the block that checks name.startswith(_AM_PREFIX). Also remove
or rename the unused variable model_pred (found as model_pred in this file) to
_model_pred (or delete it) to resolve RUF059 so no unused binding remains.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 62caa6e7-1434-49b6-8f30-bfab6dd88904

📥 Commits

Reviewing files that changed from the base of the PR and between f3e29fe and 9ce8d3e.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

Comment on lines +343 to +345
if task_buf_order:
for name, val in zip(task_buf_order, task_buf_vals):
if name.startswith(_AM_PREFIX):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
ruff check .

Repository: deepmodeling/deepmd-kit

Length of output: 50381


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "## ruff check (target file only)"
ruff check deepmd/pt_expt/train/training.py

echo "## ruff format --check (target file only)"
ruff format --check deepmd/pt_expt/train/training.py

Repository: deepmodeling/deepmd-kit

Length of output: 1469


🏁 Script executed:

ruff check .

Repository: deepmodeling/deepmd-kit

Length of output: 50381


🏁 Script executed:

ruff check .
ruff format .

Repository: deepmodeling/deepmd-kit

Length of output: 50381


Fix Ruff B905 (zip strict) and RUF059 (unused model_pred) in deepmd/pt_expt/train/training.py.

  • B905 (line 344): zip(task_buf_order, task_buf_vals) needs explicit strict= to keep both sequences aligned.
  • RUF059 (line 1354): model_pred is never used—remove it or rename to _model_pred.
Suggested fix
-            for name, val in zip(task_buf_order, task_buf_vals):
+            for name, val in zip(task_buf_order, task_buf_vals, strict=True):
-            model_pred, loss, more_loss = self.wrapper(
+            _model_pred, loss, more_loss = self.wrapper(
🧰 Tools
🪛 Ruff (0.15.14)

[warning] 344-344: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 343 - 345, The zip call
iterating over task_buf_order and task_buf_vals should be made strict to satisfy
Ruff B905: change zip(task_buf_order, task_buf_vals) to zip(task_buf_order,
task_buf_vals, strict=True) in the block that checks
name.startswith(_AM_PREFIX). Also remove or rename the unused variable
model_pred (found as model_pred in this file) to _model_pred (or delete it) to
resolve RUF059 so no unused binding remains.

Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is still a correctness issue in the compiled graph cache key for multitask models.

_get_model_structure_key() currently uses only the id of the first non-task-specific child under fitting_net as the structure key. That can incorrectly treat two task models as sharing the same compiled forward_lower graph when they share only the fitting net but have different/non-shared descriptors or other atomic-model state.

forward_lower is not just the fitting net path: it also includes descriptor/atomic-model computation. When the compiled graph traced for task/model 1 is reused for task/model 2, _CompiledModel.forward() only supplies the extra task buffers (bias, case_embd, out_bias, out_std, etc.). It does not pass descriptor or other atomic-model parameters as dynamic inputs. So if the descriptor is not actually shared, or has different parameters/outputs, task/model 2 may run the graph captured from task/model 1, producing incorrect predictions and gradients.

I would suggest making the cache key include the identity/structure of all shared components that participate in forward_lower — at least the descriptor/atomic-model path in addition to fitting — or, more conservatively, only reusing a compiled graph when both descriptor and fitting components are confirmed to be the same shared objects. Otherwise each task should compile separately.

A regression test would be helpful: construct two tasks that share fitting_net but use distinguishable non-shared descriptors, enable compile, and verify both tasks match the uncompiled results/gradients.

Minor follow-up: _CompiledModel.forward() currently swallows AttributeError while collecting task buffers and then passes an empty tuple, which will likely fail later with a less clear compiled-argument mismatch. It would be better to raise a direct RuntimeError with the missing task-buffer name/context.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

@anyangml
Copy link
Copy Markdown
Collaborator Author

I think there is still a correctness issue in the compiled graph cache key for multitask models.

_get_model_structure_key() currently uses only the id of the first non-task-specific child under fitting_net as the structure key. That can incorrectly treat two task models as sharing the same compiled forward_lower graph when they share only the fitting net but have different/non-shared descriptors or other atomic-model state.

forward_lower is not just the fitting net path: it also includes descriptor/atomic-model computation. When the compiled graph traced for task/model 1 is reused for task/model 2, _CompiledModel.forward() only supplies the extra task buffers (bias, case_embd, out_bias, out_std, etc.). It does not pass descriptor or other atomic-model parameters as dynamic inputs. So if the descriptor is not actually shared, or has different parameters/outputs, task/model 2 may run the graph captured from task/model 1, producing incorrect predictions and gradients.

I would suggest making the cache key include the identity/structure of all shared components that participate in forward_lower — at least the descriptor/atomic-model path in addition to fitting — or, more conservatively, only reusing a compiled graph when both descriptor and fitting components are confirmed to be the same shared objects. Otherwise each task should compile separately.

A regression test would be helpful: construct two tasks that share fitting_net but use distinguishable non-shared descriptors, enable compile, and verify both tasks match the uncompiled results/gradients.

Minor follow-up: _CompiledModel.forward() currently swallows AttributeError while collecting task buffers and then passes an empty tuple, which will likely fail later with a less clear compiled-argument mismatch. It would be better to raise a direct RuntimeError with the missing task-buffer name/context.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

Is that a valid use case?

@njzjz-bot
Copy link
Copy Markdown
Contributor

Yes, I think it is a valid use case unless the model configuration or share_params logic explicitly forbids it.

My concern is not limited to the current test setup. The compiled callable is forward_lower, so the traced graph covers the descriptor/atomic-model path before the fitting net as well. If two task models share only fitting_net but keep different descriptors (or any other non-shared state participating in forward_lower), a cache key based only on a fitting-net child can collide. In that case task 2 would reuse the graph captured from task 1, while _CompiledModel.forward() only varies the promoted task buffers; it does not make descriptor parameters/buffers dynamic inputs.

So I would either:

  1. make the reuse condition conservative: reuse only when the whole forward_lower-participating model structure is shared, e.g. descriptor/atomic model and fitting are the same shared objects; or
  2. include the identities of all relevant parameters/buffers/modules in the cache key, excluding only the intentionally promoted per-task buffers.

If partial sharing of fitting without descriptor is impossible by construction, then an assertion/check documenting that invariant would also resolve this concern. Otherwise I think compiling separately is safer than silently reusing a graph across non-equivalent task models.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

Copy link
Copy Markdown
Collaborator

@wanghan-iapcm wanghan-iapcm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code review

Found 2 issues:

  1. Cache key under-specifies what is baked into the compiled graph. _get_model_structure_key returns id() of the fitting net's first non-task-specific child, but only bias_atom_e/case_embd/out_bias/out_std are promoted to FX placeholders — descriptor parameters/buffers (attention weights, type-embedding davg/dstd, exclude-mask, etc.) remain baked-in constants in the traced graph. Two tasks that share fitting_net but have different descriptors (or differ in ntypes / dim_case_embd / sel / rcut) produce the same structure key and silently reuse task 0's compiled graph, yielding wrong predictions and gradients. Same concern previously raised by njzjz-bot and Copilot; unresolved on the current head.

def _get_model_structure_key(model: torch.nn.Module) -> int:
"""Return an id that is identical for all tasks that share a fitting net.
After ``share_params``, the fitting net's child sub-modules are literally
the same Python objects across tasks. The first non-task-specific child's
``id()`` is therefore the same for all shared tasks and unique across
unrelated models.
"""
try:
fitting = model.get_fitting_net()
for name, child in fitting.named_children():
if name not in _TASK_SPECIFIC_BUFFER_NAMES:
return id(child)
except AttributeError:
pass
return id(model)
# ---------------------------------------------------------------------------
# Helper: loss factory (reused from pt)

structure_key = _get_model_structure_key(model)
task_bufs = _get_task_buffers(model)
if structure_key in _compiled_by_structure:
# Shared structure: reuse the already-compiled graph.
compiled_lower, task_buf_order = _compiled_by_structure[structure_key]
log.info(
"Reusing compiled graph for task=%s (shared model structure).",
task_key,
)
else:
inp, _ = self.get_data(is_train=True, task_key=task_key)
coord = inp["coord"].detach()
atype = inp["atype"].detach()
box = inp.get("box")
if box is not None:
box = box.detach()
nframes, nloc = atype.shape[:2]
coord_3d = coord.reshape(nframes, nloc, 3)
box_flat = box.reshape(nframes, 9) if box is not None else None
if box_flat is not None:
coord_norm = normalize_coord(
coord_3d, box_flat.reshape(nframes, 3, 3)
)
else:
coord_norm = coord_3d
ext_coord, ext_atype, mapping = extend_coord_with_ghosts(
coord_norm, atype, box_flat, model.get_rcut()
)
nlist_t = build_neighbor_list(
ext_coord,
ext_atype,
nloc,
model.get_rcut(),
model.get_sel(),
distinguish_types=False,
)
ext_coord = ext_coord.reshape(nframes, -1, 3)
fparam = inp.get("fparam")
aparam = inp.get("aparam")
charge_spin = inp.get("charge_spin")
compiled_lower, task_buf_order = _trace_and_compile(
model,
ext_coord,
ext_atype,
nlist_t,
mapping,
fparam,
aparam,
charge_spin=charge_spin,
task_buffers=task_bufs if task_bufs else None,
compile_opts=compile_opts,
)
_compiled_by_structure[structure_key] = (compiled_lower, task_buf_order)
wrapper_mod.model[task_key] = _CompiledModel(
model, compiled_lower, task_buf_order, task_bufs

Suggested fix: include the descriptor (and any other non-fitting components participating in forward_lower) in the key, e.g. (id(model.get_descriptor()), id(fitting_first_child)), or build the key from tuple(id(p) for _, p in model.named_parameters()) + tuple(id(b) for n, b in model.named_buffers() if n.rsplit(".", 1)[-1] not in _TASK_SPECIFIC_BUFFER_NAMES + _ATOMIC_MODEL_TASK_BUFFER_NAMES). A regression test where two tasks share fitting but have distinguishable descriptors and assert compiled outputs match eager would catch this.

  1. The promoted-buffer set is incomplete. Only bias_atom_e, case_embd, out_bias, out_std are promoted to FX placeholders, but fitting nets carry other per-task statistics buffers that are silently baked in as task-0 constants — notably fparam_avg/fparam_inv_std and aparam_avg/aparam_inv_std (set per task by make_stat_input from each task's data distribution), and the descriptor's EnvMat davg/dstd if descriptor stats also vary per task. For multi-task configs that use fparam/aparam with shared fitting but task-local stats, all tasks would end up running with task 0's normalization, producing incorrect inputs to the fitting MLP.

# Buffer names in the fitting net that differ per task after share_params;
# everything else in the fitting net is the same Python object across tasks.
_TASK_SPECIFIC_BUFFER_NAMES: tuple[str, ...] = ("bias_atom_e", "case_embd")
# Buffer names in atomic_model that are per-task (energy/output statistics).
# These live one level above the fitting net and are not reached by
# fitting-net share_params, so they must also be promoted to FX placeholders.
_ATOMIC_MODEL_TASK_BUFFER_NAMES: tuple[str, ...] = ("out_bias", "out_std")
# Prefix used in task_buf_order keys to distinguish atomic_model buffers
# from fitting-net buffers.
_AM_PREFIX = "am/"

Suggested fix: either (a) expand _TASK_SPECIFIC_BUFFER_NAMES to include fparam_avg, fparam_inv_std, aparam_avg, aparam_inv_std and have _get_task_buffers enumerate them, or (b) auto-detect per-task buffers by diffing buffer identities across tasks in _compile_model rather than maintaining a hardcoded allow-list that will drift as descriptors and fittings evolve.

@wanghan-iapcm wanghan-iapcm changed the title Fix: compile multitask perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks May 27, 2026
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 1145: The for-loop uses an unused loop variable `sk` causing Ruff B007;
change the iteration to loop over the group values directly by replacing `for
sk, group_keys in _groups.items():` with iterating `_groups.values()` (i.e. `for
group_keys in _groups.values():`) so `sk` is removed and the loop body continues
to use `group_keys` unchanged.

In `@source/tests/pt_expt/test_training.py`:
- Around line 1301-1527: The test
TestCompiledSharedFittingDifferentDescriptor.test_compiled_matches_eager_per_task
is unbounded and needs the repo-standard 60s timeout; add a 60s test timeout
(e.g. annotate the test method or the test class with pytest.mark.timeout(60)
and import pytest) so the compiled/eager training regression cannot hang CI
longer than allowed, ensuring the decorator is applied to
test_compiled_matches_eager_per_task (or the containing
TestCompiledSharedFittingDifferentDescriptor class).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 33c2d228-ada9-4027-b0f6-2a1f979e9ce9

📥 Commits

Reviewing files that changed from the base of the PR and between 9ce8d3e and 8c25a22.

📒 Files selected for processing (2)
  • deepmd/pt_expt/train/training.py
  • source/tests/pt_expt/test_training.py

Comment thread deepmd/pt_expt/train/training.py
Comment thread source/tests/pt_expt/test_training.py
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-reviewed the current head (8c25a2289). The main correctness concern from the previous review looks addressed:

  • _get_model_structure_key() now includes the descriptor identity, so shared fitting with different descriptors no longer collides.
  • fitting-net task buffers are auto-detected by buffer identity within each reusable structure group, and out_bias/out_std are promoted from atomic_model.
  • the new regression test covers the shared-fitting / different-descriptor case.
  • required checks are green.

I do not see a remaining correctness blocker from the original cache-key issue. The only things I would still clean up before merge are non-blocking maintenance items already pointed out by CodeRabbit: the unused sk loop variable, adding the timeout mark to the new compile regression test, and making the missing promoted-buffer path in _CompiledModel.forward() fail with a direct RuntimeError rather than falling through to an argument mismatch.

— OpenClaw 2026.5.12 (f066dd2) (model: custom-chat-jinzhezeng-group/gpt-5.5)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants