Skip to content

[https://nvbugs/6150288][fix] Use persistent per-stream workspace in cublas_mm for CUDA-graph safety#15534

Open
pamelap-nvidia wants to merge 2 commits into
NVIDIA:mainfrom
pamelap-nvidia:pamelap/fix-cublas-workspace-graph-safety
Open

[https://nvbugs/6150288][fix] Use persistent per-stream workspace in cublas_mm for CUDA-graph safety#15534
pamelap-nvidia wants to merge 2 commits into
NVIDIA:mainfrom
pamelap-nvidia:pamelap/fix-cublas-workspace-graph-safety

Conversation

@pamelap-nvidia

@pamelap-nvidia pamelap-nvidia commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • Bug Fixes
    • Fixed memory management for CUDA workspace allocation to prevent stale pointer references when using CUDA graphs with allocator caching, improving system stability and runtime performance.

Description

torch.ops.trtllm.cublas_mm (which Nemotron-H reroutes Linear apply() to on SM121 since #13160) calls cublas_gemm_caller, which allocated a fresh torch::empty(CUBLAS_WORKSPACE_SIZE /* 32 MiB */) workspace on every invocation and let it destruct at function-return. Inside a CUDA-graph capture region the captured cublasLt kernel records a fixed pointer into the workspace storage; freeing the storage back to the CUDA caching allocator while the captured graph still references the pointer is a use-after-free. It surfaces later as free-block-tree corruption on the next allocator operation, manifesting as the FusedMoeRunner::getWorkspaceInfo segfault reported on the second generation-only CUDA-graph warmup with mamba-hybrid NVFP4 MoE.

Repro: Nemotron-3-Nano-Omni-30B-A3B-Reasoning-NVFP4 on DGX-Spark with the YAML in the bug. Exposure is Spark-specific because #13160 only enables use_custom_cublas_mm when sm_version() == 121, but the underlying graph-unsafety is not.

Fix

Replace the per-call allocation with a per-stream workspace cache, mirroring PyTorch's own per-stream cublasLt workspace cache in at::cuda. Keyed by stream so concurrent GEMMs on different streams don't race on the same scratch bytes.

Validation

Rebuilt rc14 wheel at the first-bad commit 43e3070de4 + the QA repro YAML (CUDA graphs ON, enable_padding: true):

  • ✅ Crash cleared — Application startup complete reached, no Bus error / Failing at address: 0x1 on aarch64.
  • ✅ Three deterministic prompts (temperature=0, greedy decode) produce correct outputs matching the cuda_graph_config: null baseline:
    • "What is the capital of France?"Paris
    • "Write the integers from 1 to 10..."1, 2, 3, 4, 5, 6, 7, 8, 9, 10
    • "Translate 'Hello, world' to French."Bonjour, monde
  • ✅ Peak GPU memory drops 91.65 GiB (stock rc14, crashed) → 67.58 GiB (patched) — the stock build was effectively leaking ~24 GiB through cached 32 MiB workspace blocks the caching allocator was holding but couldn't reclaim. KV cache budget for this config goes up from 23.06 GiB to 42.33 GiB.

NVBUG: https://nvbugs/6150288

Test plan

  • Crash repro on DGX-Spark cleared with this fix
  • Numerical correctness verified vs cuda_graph_config: null baseline
  • CI / /bot run

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@coderabbitai

coderabbitai Bot commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

In cublas_gemm_caller, the cuBLAS workspace buffer changes from a fresh torch::empty allocation on every call to a thread_local unordered_map<cudaStream_t, at::Tensor> that caches the workspace tensor per stream, re-allocating only when the entry is absent or on a different device.

Changes

Per-stream cuBLAS workspace cache

Layer / File(s) Summary
thread_local workspace cache in cublas_gemm_caller
cpp/tensorrt_llm/thop/cublasScaledMM.cpp
Adds <unordered_map> header and replaces the per-call torch::empty workspace allocation with a thread_local unordered_map<cudaStream_t, at::Tensor> cache; re-allocates only when the cached tensor is undefined or resides on a different device.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: using a persistent per-stream workspace cache in cublas_mm to fix CUDA-graph safety issues.
Description check ✅ Passed The PR description is comprehensive and well-structured, covering the problem, solution, validation, and test plan with detailed technical context and evidence of testing.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/cublasScaledMM.cpp (1)

206-213: 🩺 Stability & Availability | 🔵 Trivial | ⚖️ Poor tradeoff

Workspace cache is never evicted — unbounded growth with transient streams.

workspace_cache retains a CUBLAS_WORKSPACE_SIZE (32 MiB) tensor per stream handle for the thread's lifetime, with no eviction. Entries for destroyed/short-lived streams are never reclaimed, and a stale stream handle can be reused by a new stream, returning a workspace that may still be referenced by a graph captured on the old stream. In typical TRT-LLM usage streams are bounded, so impact is limited, but if many transient streams are created this grows unboundedly.

Consider bounding the cache or clearing stale entries when a stream is destroyed.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp` around lines 206 - 213, The
workspace_cache thread-local unordered_map retains Tensor entries for each
cudaStream_t indefinitely, causing unbounded memory growth when transient
streams are created since destroyed stream handles are never evicted and can be
reused by new streams. Implement cache management by either limiting the cache
size with a bounded container (e.g., LRU cache) or detecting and clearing stale
stream entries before allocating new workspaces. Check if the cached stream
handle is still valid before reusing its workspace, or implement periodic
cleanup of inactive stream entries in the workspace_cache.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp`:
- Around line 195-213: Locate the copyright header at the beginning of the
cublasScaledMM.cpp file and find the malformed placeholder text that reads
"Copyright (out)". Replace this entire copyright line with the correct standard
format: "Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights
reserved." to match the project's copyright convention.
- Around line 203-213: The workspace_cache in the cublasScaledMM function is
keyed only by stream pointer, but it should be keyed by (device, stream) pair to
prevent cache collisions when the same thread uses the default stream on
multiple devices. Change the thread_local declaration from unordered_map keyed
by cudaStream_t to std::map keyed by a pair of device and stream (you can
construct the key by combining a.device() and stream_ptr). Update the workspace
lookup to use this pair as the key instead of just stream_ptr. This ensures each
device-stream combination has its own cache entry and prevents the reallocation
and freeing of one device's workspace when switching to another device in the
same thread.

---

Nitpick comments:
In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp`:
- Around line 206-213: The workspace_cache thread-local unordered_map retains
Tensor entries for each cudaStream_t indefinitely, causing unbounded memory
growth when transient streams are created since destroyed stream handles are
never evicted and can be reused by new streams. Implement cache management by
either limiting the cache size with a bounded container (e.g., LRU cache) or
detecting and clearing stale stream entries before allocating new workspaces.
Check if the cached stream handle is still valid before reusing its workspace,
or implement periodic cleanup of inactive stream entries in the workspace_cache.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 8196186e-5b43-4d80-a1af-0aa89a5bb819

📥 Commits

Reviewing files that changed from the base of the PR and between 3448424 and 6e90a79.

📒 Files selected for processing (1)
  • cpp/tensorrt_llm/thop/cublasScaledMM.cpp

Comment on lines +195 to +213
// Persistent per-stream cublasLt workspace. The cublasLt kernel records a
// fixed pointer to this workspace; if we destruct the workspace storage at
// function return (the previous behavior), the CUDA caching allocator may
// hand the same block out to a later allocation while a captured CUDA
// graph still references the workspace pointer -> use-after-free that
// surfaces as free-block-tree corruption on the next allocator operation
// (e.g. FusedMoeRunner::getWorkspaceInfo's first torch::empty destructor).
//
// Keyed by (device, stream) so that concurrent GEMMs on different streams
// don't race on the same scratch bytes. Mirrors PyTorch's per-stream
// cublasLt workspace cache in at::cuda.
thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
auto stream_ptr = stream.stream();
auto& workspace = workspace_cache[stream_ptr];
if (!workspace.defined() || workspace.device() != a.device())
{
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
sed -n '1,15p' cpp/tensorrt_llm/thop/cublasScaledMM.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 860


Fix malformed copyright header placeholder text.

The copyright header contains a malformed placeholder Copyright (out) that must be corrected. While the year 2026 is already present and correct, the (out) text appears to be a substitution error and should be removed or properly filled in. Update the copyright line to match the standard format: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp` around lines 195 - 213, Locate the
copyright header at the beginning of the cublasScaledMM.cpp file and find the
malformed placeholder text that reads "Copyright (out)". Replace this entire
copyright line with the correct standard format: "Copyright (c) 2026, NVIDIA
CORPORATION & AFFILIATES. All rights reserved." to match the project's copyright
convention.

Source: Coding guidelines

Comment on lines +203 to +213
// Keyed by (device, stream) so that concurrent GEMMs on different streams
// don't race on the same scratch bytes. Mirrors PyTorch's per-stream
// cublasLt workspace cache in at::cuda.
thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
auto stream_ptr = stream.stream();
auto& workspace = workspace_cache[stream_ptr];
if (!workspace.defined() || workspace.device() != a.device())
{
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -nP --type=cpp -C3 '\bcublas_(mm_out|scaled_mm_out|gemm_caller)\s*\('
rg -nP --type=cpp -C3 'getCurrentCUDAStream'

Repository: NVIDIA/TensorRT-LLM

Length of output: 157


🏁 Script executed:

#!/bin/bash
# First, check if the file exists
find . -name "cublasScaledMM.cpp" -type f | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 200


🏁 Script executed:

#!/bin/bash
# If file exists, read the specific lines mentioned (203-213) and surrounding context
if [ -f "cpp/tensorrt_llm/thop/cublasScaledMM.cpp" ]; then
  echo "=== Lines 200-220 ===" 
  sed -n '200,220p' cpp/tensorrt_llm/thop/cublasScaledMM.cpp | cat -n
  echo ""
  echo "=== Check for includes and see if we have unordered_map ===" 
  head -50 cpp/tensorrt_llm/thop/cublasScaledMM.cpp | cat -n
fi

Repository: NVIDIA/TensorRT-LLM

Length of output: 3489


🏁 Script executed:

#!/bin/bash
# Search for function definitions containing "cublas" and workspace_cache
rg -n "workspace_cache|cublas_gemm_caller" cpp/tensorrt_llm/thop/ -A 5 -B 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 3312


🏁 Script executed:

#!/bin/bash
# Look for callers of functions in cublasScaledMM or usage patterns
rg -n "cublasScaledMM|cublas.*MM" --type=cpp -B 2 -A 2 | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 157


🏁 Script executed:

#!/bin/bash
# Find all functions that call cublas_gemm_caller
rg -n "cublas_gemm_caller|cublas_scaled_mm|cublas_mm" cpp/tensorrt_llm/ --type=cpp -B 3 -A 1 | grep -E "(def|^.*\.cpp.*:.*\()" | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 2906


🏁 Script executed:

#!/bin/bash
# Look for Python bindings or module initialization that exposes these functions
rg -n "PYBIND|cublas_scaled_mm|cublas_mm" --type=cpp --type=py -B 2 -A 2 | head -80

Repository: NVIDIA/TensorRT-LLM

Length of output: 157


🏁 Script executed:

#!/bin/bash
# Check if there's context about how the stream is obtained and if multiple devices can be used
sed -n '160,230p' cpp/tensorrt_llm/thop/cublasScaledMM.cpp | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 3632


🏁 Script executed:

#!/bin/bash
# Look at the broader context of the function and how streams/devices are handled
rg -n "device|stream|cuda" cpp/tensorrt_llm/thop/cublasScaledMM.cpp | head -50

Repository: NVIDIA/TensorRT-LLM

Length of output: 2242


Cache key should be (device, stream), not stream alone — risks re-introducing UAF on multi-device default streams.

The comment states the cache is "Keyed by (device, stream)", but workspace_cache is keyed only by cudaStream_t. For the default/per-thread stream, stream.stream() returns the same sentinel handle (e.g. cudaStreamPerThread) on every device. If a single thread issues GEMMs on two devices through the default stream, both collide on one map entry. The workspace.device() != a.device() guard then reallocates on each device switch, which not only defeats the cache but reassigns (and frees) the other device's cached workspace—exactly the use-after-free this PR aims to prevent if a captured graph on that device still references the freed block.

Key the map on the (device, stream) pair so distinct devices never share an entry:

🛡️ Proposed fix: include device in the cache key
-    thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
-    auto stream_ptr = stream.stream();
-    auto& workspace = workspace_cache[stream_ptr];
-    if (!workspace.defined() || workspace.device() != a.device())
+    thread_local std::map<std::pair<int, cudaStream_t>, at::Tensor> workspace_cache;
+    auto stream_ptr = stream.stream();
+    auto& workspace = workspace_cache[{a.get_device(), stream_ptr}];
+    if (!workspace.defined())
     {
         auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
         workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
     }

(std::map avoids needing a custom hash for std::pair; include <map> instead of/in addition to <unordered_map>.)

The function is exposed via Python PYBIND (cublas_scaled_mm, cublas_mm) and can be called from user code with tensors on different devices in the same thread. The default stream scenario you identified is realistic and represents a latent UAF bug.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp` around lines 203 - 213, The
workspace_cache in the cublasScaledMM function is keyed only by stream pointer,
but it should be keyed by (device, stream) pair to prevent cache collisions when
the same thread uses the default stream on multiple devices. Change the
thread_local declaration from unordered_map keyed by cudaStream_t to std::map
keyed by a pair of device and stream (you can construct the key by combining
a.device() and stream_ptr). Update the workspace lookup to use this pair as the
key instead of just stream_ptr. This ensures each device-stream combination has
its own cache entry and prevents the reallocation and freeing of one device's
workspace when switching to another device in the same thread.

@pamelap-nvidia

Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55157 [ run ] triggered by Bot. Commit: df1d37c Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55157 [ run ] completed with state SUCCESS. Commit: df1d37c
/LLM/main/L0_MergeRequest_PR pipeline #44131 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
@pamelap-nvidia pamelap-nvidia force-pushed the pamelap/fix-cublas-workspace-graph-safety branch from df1d37c to 9f03772 Compare June 23, 2026 15:17
@pamelap-nvidia

Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55263 [ run ] triggered by Bot. Commit: 9f03772 Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants