
fix(pt): recognize AOTInductor-wrapped CUDA OOM in AutoBatchSize#5418

Merged
njzjz merged 3 commits into deepmodeling:master from OutisLi:pr/bs
Apr 26, 2026
Conversation

@OutisLi
Collaborator

@OutisLi OutisLi commented Apr 24, 2026

When running `dp --pt-expt test` (or any path that goes through `deepmd.pt_expt.infer.deep_eval`) against a `.pt2` AOTInductor package, `AutoBatchSize` doubles the batch on every success. For models with a large `sel`, the exploration eventually saturates GPU memory, and the CUDA caching allocator raises the usual `CUDA out of memory` from inside the AOTInductor runtime. AOTInductor then rewraps that error as a generic

    RuntimeError: run_func_(...) API call failed at
        .../aoti_runner/model_container_runner.cpp, line 144

The original "CUDA out of memory" text is printed only to stderr, so the old `is_oom_error` -- which keyed on a short list of substrings in `e.args[0]` -- never matched. `execute()` therefore did not shrink the batch; the exception propagated, and the run crashed on a GPU that was otherwise completely idle (confirmed by monitoring `nvidia-smi --query-compute-apps`, which showed dp itself as the sole consumer holding tens of GiB just before the failure).

Widen `is_oom_error` to:

  • walk the exception chain via __cause__ / __context__, so that a future PyTorch preserving the original OOM text is handled for free;
  • keep matching the four plain CUDA OOM markers on every message in the chain;
  • additionally treat the AOTInductor wrapper signature (run_func_( plus model_container_runner) as an OOM candidate. If the AOTInductor wrapper ever hides a non-OOM failure, the batch shrinker will halve down to 1 and then raise OutOfMemoryError, so the fallback is bounded -- non-OOM bugs still surface with a clear terminal error rather than being silently retried forever.
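The widened check described above can be sketched in plain Python. This is an illustrative sketch, not the deepmd implementation: the marker strings, function names, and the traversal limit are assumptions, and the real code in `deepmd/pt/utils/auto_batch_size.py` additionally calls `torch.cuda.empty_cache()` on a positive match.

```python
def iter_chain_messages(exc, limit=32):
    """Yield str(e) for exc and everything reachable via __cause__/__context__,
    guarding against cycles and unbounded chains."""
    seen = set()
    stack = [exc]
    while stack and limit > 0:
        e = stack.pop()
        if e is None or id(e) in seen:
            continue
        seen.add(id(e))
        limit -= 1
        yield str(e)
        stack.append(e.__cause__)
        stack.append(e.__context__)


# Illustrative marker list only; the actual list lives in
# deepmd/pt/utils/auto_batch_size.py and may differ.
OOM_MARKERS = (
    "CUDA out of memory",
    "CUDA driver error: out of memory",
    "cusolver error",
    "CUBLAS_STATUS_ALLOC_FAILED",
)


def looks_like_oom(exc):
    """Sketch of the widened check: RuntimeError only, scan the whole chain."""
    if not isinstance(exc, RuntimeError):
        return False
    msgs = list(iter_chain_messages(exc))
    if any(marker in msg for msg in msgs for marker in OOM_MARKERS):
        return True  # the real code also clears the CUDA cache here
    # AOTInductor wrapper signature: the underlying OOM text goes only to
    # stderr, so match on the wrapper message instead.
    return any(
        "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
    )
```

Because the wrapper-signature branch is a heuristic, the bounded fallback in `execute()` (halve to batch size 1, then raise `OutOfMemoryError`) is what keeps a non-OOM wrapper failure from being retried forever.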

Summary by CodeRabbit

  • Bug Fixes

    • Improved out-of-memory detection to catch more CUDA memory exhaustion scenarios, including wrapped/instrumented failures; now reliably clears CUDA cached memory when OOM conditions are identified to reduce cascading failures.
  • Tests

    • Added unit tests validating OOM detection across varied exception shapes and confirming CUDA cache clearing is invoked for detected OOMs.

Copilot AI review requested due to automatic review settings April 24, 2026 02:25
@dosubot dosubot Bot added the bug label Apr 24, 2026
@OutisLi OutisLi requested a review from njzjz April 24, 2026 02:28
@coderabbitai
Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

Refactors AutoBatchSize.is_oom_error to detect CUDA OOMs and AOTInductor/PT2-wrapped failures by handling torch.cuda.OutOfMemoryError, validating RuntimeError types, traversing exception cause/context chains to aggregate messages, and clearing CUDA cache when an OOM is detected.

Changes

Cohort / File(s) Summary
OOM detection logic
deepmd/pt/utils/auto_batch_size.py
Refactored is_oom_error to handle torch.cuda.OutOfMemoryError, reject non-RuntimeError exceptions, traverse __cause__/__context__ chains to collect message strings, detect CUDA/cuSolver OOM markers and AOTInductor/PT2 wrapper signatures, call torch.cuda.empty_cache() on detection, and return a boolean.
Unit tests
source/tests/pt/test_auto_batch_size.py
Added tests for direct CUDA OOM text, OOM in a RuntimeError.__cause__, and detection of AOTInductor/PT2 wrapper messages; mocks torch.cuda.empty_cache to assert it is called when OOM is detected.
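The exception shapes those tests exercise can be fabricated without a GPU. The sketch below uses only stdlib exceptions to build the AOTInductor-style wrapper with the original OOM preserved on `__cause__`; the helper name is hypothetical, and the real tests in `source/tests/pt/test_auto_batch_size.py` target `AutoBatchSize.is_oom_error` directly.

```python
import unittest


def make_aoti_wrapped_oom():
    """Build a RuntimeError shaped like AOTInductor's wrapper, with the
    original CUDA OOM preserved on __cause__ (the shape the chain-walking
    tests rely on)."""
    try:
        try:
            raise RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB")
        except RuntimeError as inner:
            raise RuntimeError(
                "run_func_(...) API call failed at "
                ".../aoti_runner/model_container_runner.cpp, line 144"
            ) from inner
    except RuntimeError as outer:
        return outer


class TestExceptionShapes(unittest.TestCase):
    def test_cause_keeps_original_oom_text(self):
        e = make_aoti_wrapped_oom()
        self.assertIn("CUDA out of memory", str(e.__cause__))

    def test_wrapper_signature_on_outer_message(self):
        e = make_aoti_wrapped_oom()
        self.assertIn("run_func_(", str(e))
        self.assertIn("model_container_runner", str(e))


if __name__ == "__main__":
    unittest.main()
```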

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 20.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: improving AutoBatchSize to recognize CUDA OOM errors wrapped by AOTInductor, which is the core fix addressed in this PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.


Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
deepmd/pt/utils/auto_batch_size.py (1)

79-96: Optional: consolidate the three empty_cache() + return True branches.

Minor readability nit — the three OOM-positive branches repeat the same side effect. You can fold them into a single exit point without changing behavior.

♻️ Suggested consolidation
-        if any(m in msg for msg in msgs for m in plain_oom_markers):
-            torch.cuda.empty_cache()
-            return True
-
-        # AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic
-        # ``run_func_(...) API call failed at .../model_container_runner.cpp``.
-        # ...
-        aoti_wrapped = any(
-            "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
-        )
-        if aoti_wrapped:
-            torch.cuda.empty_cache()
-            return True
-
-        return False
+        plain_oom = any(m in msg for msg in msgs for m in plain_oom_markers)
+        # AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic
+        # ``run_func_(...) API call failed at .../model_container_runner.cpp``.
+        # The original "CUDA out of memory" text is printed to stderr only, so
+        # we match on the wrapper signature. If the root cause is not OOM,
+        # ``execute()`` will shrink to batch size 1 and raise ``OutOfMemoryError``.
+        aoti_wrapped = any(
+            "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
+        )
+        if plain_oom or aoti_wrapped:
+            torch.cuda.empty_cache()
+            return True
+        return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/utils/auto_batch_size.py` around lines 79 - 96, The three places
that call torch.cuda.empty_cache() and return True (the plain OOM marker check
using plain_oom_markers, the earlier OOM detection loop over msgs, and the
AOTInductor wrapper check that sets aoti_wrapped) should be consolidated to a
single exit point: compute the boolean condition variables (e.g., plain_match =
any(m in msg for msg in msgs for m in plain_oom_markers), other_match = ... ,
aoti_wrapped = any("run_func_(" in msg and "model_container_runner" in msg for
msg in msgs)), combine them (e.g., if plain_match or other_match or
aoti_wrapped) then call torch.cuda.empty_cache() once and return True; update
the surrounding function (the OOM detection logic in auto_batch_size) to use
these names so the duplicated side effects are removed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: d8e9df32-3cac-49f1-ae75-1e812e9778fb

📥 Commits

Reviewing files that changed from the base of the PR and between 54f42d9 and 0dcff08.

📒 Files selected for processing (1)
  • deepmd/pt/utils/auto_batch_size.py

Contributor

Copilot AI left a comment


Pull request overview

Improves PyTorch AutoBatchSize OOM detection for .pt2 AOTInductor-packaged models by recognizing AOTInductor’s wrapped CUDA OOM failures, allowing batch size to shrink instead of crashing.

Changes:

  • Expand OOM detection to scan the exception chain (__cause__ / __context__) for known CUDA/cusolver OOM markers.
  • Add detection for AOTInductor wrapper error signatures (run_func_( + model_container_runner) and treat them as OOM.
  • Keep GPU cache cleanup (torch.cuda.empty_cache()) when an OOM is detected.
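The cache-cleanup assertion pattern used by the new tests can be shown with a plain `Mock`. Here `clear_cache_if_oom` is a toy stand-in for the detection side effect, not deepmd code; the real tests patch `deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache` instead.

```python
from unittest import mock


def clear_cache_if_oom(exc, empty_cache):
    """Toy stand-in: on a detected OOM, clear the cache and report True."""
    if isinstance(exc, RuntimeError) and "CUDA out of memory" in str(exc):
        empty_cache()
        return True
    return False


# Positive case: the cache hook fires exactly once.
fake_empty_cache = mock.Mock()
assert clear_cache_if_oom(RuntimeError("CUDA out of memory"), fake_empty_cache)
fake_empty_cache.assert_called_once()

# Negative case: unrelated errors must not trigger a cache clear.
fake_empty_cache.reset_mock()
assert not clear_cache_if_oom(RuntimeError("shape mismatch"), fake_empty_cache)
fake_empty_cache.assert_not_called()
```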


Comment thread deepmd/pt/utils/auto_batch_size.py
Comment thread deepmd/pt/utils/auto_batch_size.py Outdated
@codecov

codecov Bot commented Apr 24, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.36%. Comparing base (5c22e17) to head (33a3bd8).
⚠️ Report is 5 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5418      +/-   ##
==========================================
+ Coverage   80.46%   82.36%   +1.90%     
==========================================
  Files         823      824       +1     
  Lines       86625    87128     +503     
  Branches     4139     4197      +58     
==========================================
+ Hits        69701    71762    +2061     
+ Misses      15651    14091    -1560     
- Partials     1273     1275       +2     


Contributor

@njzjz-bot njzjz-bot left a comment


Nice fix. I walked through the exception-chain handling and the AOTInductor wrapper detection in AutoBatchSize, and the fallback behavior still looks bounded/reasonable. The current CI matrix is green, so I'm happy with this as-is.

— OpenClaw 2026.4.22 (model: gpt-5.4)

Comment thread deepmd/pt/utils/auto_batch_size.py
OutisLi added 2 commits April 25, 2026 10:36
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/pt/test_auto_batch_size.py (1)

14-43: Consider adding a negative case and a direct OutOfMemoryError case.

The three new tests cover the chain-traversal and AOTI-wrapper paths well, but two coverage gaps remain that would harden against future regressions:

  1. The direct torch.cuda.OutOfMemoryError branch (early return at the top of is_oom_error) isn't exercised — instantiating one is a bit awkward (it requires CUDA), but a mock.patch on torch.cuda.OutOfMemoryError or constructing via RuntimeError-like surrogate plus isinstance patching can do it; alternatively a simple skip-if-no-CUDA path also works.
  2. There's no negative test asserting that an unrelated RuntimeError (e.g. RuntimeError("shape mismatch")) returns False and that empty_cache is not called. This guards against the marker list/AOTI heuristic accidentally widening into false positives, which would silently clear caches and shrink batch sizes on non-OOM bugs.
♻️ Suggested addition
+    @mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache")
+    def test_is_oom_error_non_oom_runtime_error(self, empty_cache) -> None:
+        auto_batch_size = AutoBatchSize(256, 2.0)
+        self.assertFalse(
+            auto_batch_size.is_oom_error(RuntimeError("shape mismatch"))
+        )
+        empty_cache.assert_not_called()
+
+    @mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache")
+    def test_is_oom_error_non_runtime_error(self, empty_cache) -> None:
+        auto_batch_size = AutoBatchSize(256, 2.0)
+        self.assertFalse(auto_batch_size.is_oom_error(ValueError("nope")))
+        empty_cache.assert_not_called()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/test_auto_batch_size.py` around lines 14 - 43, Add two tests
to TestAutoBatchSize covering (1) direct torch.cuda.OutOfMemoryError and (2) a
negative unrelated RuntimeError: create a test that patches
torch.cuda.OutOfMemoryError (or mocks isinstance checks) and asserts
AutoBatchSize.is_oom_error returns True and torch.cuda.empty_cache is called,
and add another test that calls AutoBatchSize.is_oom_error(RuntimeError("shape
mismatch")) asserting it returns False and that the patched
torch.cuda.empty_cache was NOT called; reference the AutoBatchSize.is_oom_error
method and the existing tests that patch
"deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache" to mirror setup.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 04fc2d82-1a39-480a-957e-57976382f46c

📥 Commits

Reviewing files that changed from the base of the PR and between 0dcff08 and a00b10f.

📒 Files selected for processing (2)
  • deepmd/pt/utils/auto_batch_size.py
  • source/tests/pt/test_auto_batch_size.py

Comment thread deepmd/pt/utils/auto_batch_size.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
deepmd/pt/utils/auto_batch_size.py (1)

44-102: Robust OOM detection looks good.

Exception-chain traversal with id-based cycle protection, the dedicated torch.cuda.OutOfMemoryError fast-path, and the AOTInductor wrapper signature match cleanly cover the regression described in the PR. The #4594 reference is now correctly placed under the plain CUDA OOM markers (addressing the prior bot comment), and the safety net described in the comment block (shrinking to batch size 1 → OutOfMemoryError) means non-OOM wrapper failures still surface eventually rather than being retried indefinitely.

One small optional cleanup: the two if ...: torch.cuda.empty_cache(); return True branches could be folded into a single check, but the current form is more readable alongside the explanatory comments — feel free to ignore.

Optional consolidation (only if you prefer DRY over inline narration)
-        if any(m in msg for msg in msgs for m in plain_oom_markers):
-            torch.cuda.empty_cache()
-            return True
-
-        # AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic
-        # ``run_func_(...) API call failed at .../model_container_runner.cpp``.
-        # The original "CUDA out of memory" text is printed to stderr only and
-        # is absent from the Python-level RuntimeError, so we match on the
-        # wrapper signature.  If the root cause turns out to be something
-        # other than OOM, ``execute()`` will keep shrinking the batch and
-        # eventually raise ``OutOfMemoryError`` at batch size 1, which is a
-        # clean failure rather than an uncaught exception.
-        aoti_wrapped = any(
-            "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
-        )
-        if aoti_wrapped:
-            torch.cuda.empty_cache()
-            return True
-
-        return False
+        plain_hit = any(m in msg for msg in msgs for m in plain_oom_markers)
+        # AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic
+        # ``run_func_(...) API call failed at .../model_container_runner.cpp``.
+        # The original "CUDA out of memory" text is only on stderr, so we
+        # match on the wrapper signature; non-OOM wrapper failures will
+        # still surface as ``OutOfMemoryError`` once batch size hits 1.
+        aoti_hit = any(
+            "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
+        )
+        if plain_hit or aoti_hit:
+            torch.cuda.empty_cache()
+            return True
+        return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/utils/auto_batch_size.py` around lines 44 - 102, The two identical
branches in is_oom_error that call torch.cuda.empty_cache() and return True can
be consolidated: compute the plain OOM match with plain_oom_markers (currently
using the any(m in msg ...) loop) and the AOTInductor match into a single
boolean (e.g., plain_oom_match or aoti_wrapped), then if that combined condition
is true call torch.cuda.empty_cache() once and return True; update the logic
around plain_oom_markers and the aoti_wrapped variable accordingly so behavior
is unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 62b49561-974e-42e1-8fe2-88bafaecdb50

📥 Commits

Reviewing files that changed from the base of the PR and between a00b10f and 33a3bd8.

📒 Files selected for processing (1)
  • deepmd/pt/utils/auto_batch_size.py

Contributor

@njzjz-bot njzjz-bot left a comment


I checked the current diff and review thread again. The #4594 reference is now back under the plain CUDA OOM markers instead of the AOTInductor wrapper-specific comment, so the URL-placement issue looks fixed. The current CI/checks are green as well, so this looks good to approve.

— OpenClaw 2026.4.22 (model: gpt-5.4)

@njzjz njzjz added this pull request to the merge queue Apr 26, 2026
Merged via the queue into deepmodeling:master with commit c6ca671 Apr 26, 2026
70 checks passed
@OutisLi OutisLi deleted the pr/bs branch April 27, 2026 05:49