[https://nvbugs/6150288][fix] Use persistent per-stream workspace in cublas_mm for CUDA-graph safety #15534
📝 Walkthrough
Changes: Per-stream cuBLAS workspace cache
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/cublasScaledMM.cpp (1)
206-213: 🩺 Stability & Availability | 🔵 Trivial | ⚖️ Poor tradeoff
Workspace cache is never evicted: unbounded growth with transient streams.
`workspace_cache` retains a `CUBLAS_WORKSPACE_SIZE` (32 MiB) tensor per stream handle for the thread's lifetime, with no eviction. Entries for destroyed/short-lived streams are never reclaimed, and a stale stream handle can be reused by a new stream, returning a workspace that may still be referenced by a graph captured on the old stream. In typical TRT-LLM usage streams are bounded, so impact is limited, but if many transient streams are created this grows unboundedly.
Consider bounding the cache or clearing stale entries when a stream is destroyed.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp` around lines 206 - 213, The workspace_cache thread-local unordered_map retains Tensor entries for each cudaStream_t indefinitely, causing unbounded memory growth when transient streams are created since destroyed stream handles are never evicted and can be reused by new streams. Implement cache management by either limiting the cache size with a bounded container (e.g., LRU cache) or detecting and clearing stale stream entries before allocating new workspaces. Check if the cached stream handle is still valid before reusing its workspace, or implement periodic cleanup of inactive stream entries in the workspace_cache.
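One way to picture the "bounded container (e.g., LRU cache)" option is the sketch below. It is a generic, hypothetical `LruCache` written for illustration only; it is not code from TensorRT-LLM or PyTorch, and the key and value types are left as template parameters rather than `cudaStream_t` and `at::Tensor`. Note the tradeoff it does not solve: evicting a workspace that a captured CUDA graph still references would bring back the use-after-free this PR fixes, so any real eviction policy would need to know which workspaces are still pinned by graphs.

```cpp
#include <cassert>
#include <cstddef>
#include <list>
#include <unordered_map>
#include <utility>

// Hypothetical bounded cache, for illustration only.
template <typename Key, typename Value>
class LruCache
{
public:
    explicit LruCache(std::size_t capacity)
        : mCapacity(capacity)
    {
        assert(capacity > 0 && "capacity must be at least 1");
    }

    // Returns the cached value for `key`, creating it with `make()` on a miss.
    // On a miss at capacity, the least recently used entry is evicted first.
    template <typename Factory>
    Value& getOrCreate(Key const& key, Factory&& make)
    {
        auto it = mIndex.find(key);
        if (it != mIndex.end())
        {
            // Hit: move the entry to the front (most recently used).
            mEntries.splice(mEntries.begin(), mEntries, it->second);
            return it->second->second;
        }
        if (mEntries.size() >= mCapacity)
        {
            // Miss at capacity: drop the entry at the back (least recently used).
            mIndex.erase(mEntries.back().first);
            mEntries.pop_back();
        }
        mEntries.emplace_front(key, make());
        mIndex[key] = mEntries.begin();
        return mEntries.front().second;
    }

    std::size_t size() const
    {
        return mEntries.size();
    }

    bool contains(Key const& key) const
    {
        return mIndex.count(key) != 0;
    }

private:
    using Entry = std::pair<Key, Value>;

    std::size_t mCapacity;
    std::list<Entry> mEntries; // front = most recently used
    std::unordered_map<Key, typename std::list<Entry>::iterator> mIndex;
};
```

With a capacity of 2, touching keys 1, 2, 1, 3 in that order leaves 1 and 3 cached and evicts 2, because the second access to 1 makes 2 the least recently used entry.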
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/thop/cublasScaledMM.cpp`:
- Around line 195-213: Locate the copyright header at the beginning of the
cublasScaledMM.cpp file and find the malformed placeholder text that reads
"Copyright (out)". Replace this entire copyright line with the correct standard
format: "Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights
reserved." to match the project's copyright convention.
- Around line 203-213: The workspace_cache in the cublasScaledMM function is
keyed only by stream pointer, but it should be keyed by (device, stream) pair to
prevent cache collisions when the same thread uses the default stream on
multiple devices. Change the thread_local declaration from unordered_map keyed
by cudaStream_t to std::map keyed by a pair of device and stream (you can
construct the key by combining a.device() and stream_ptr). Update the workspace
lookup to use this pair as the key instead of just stream_ptr. This ensures each
device-stream combination has its own cache entry and prevents the reallocation
and freeing of one device's workspace when switching to another device in the
same thread.
    // Persistent per-stream cublasLt workspace. The cublasLt kernel records a
    // fixed pointer to this workspace; if we destruct the workspace storage at
    // function return (the previous behavior), the CUDA caching allocator may
    // hand the same block out to a later allocation while a captured CUDA
    // graph still references the workspace pointer -> use-after-free that
    // surfaces as free-block-tree corruption on the next allocator operation
    // (e.g. FusedMoeRunner::getWorkspaceInfo's first torch::empty destructor).
    //
    // Keyed by (device, stream) so that concurrent GEMMs on different streams
    // don't race on the same scratch bytes. Mirrors PyTorch's per-stream
    // cublasLt workspace cache in at::cuda.
    thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
    auto stream_ptr = stream.stream();
    auto& workspace = workspace_cache[stream_ptr];
    if (!workspace.defined() || workspace.device() != a.device())
    {
        auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
        workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
    }
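The lifetime problem described in the comment above can be illustrated without CUDA. The sketch below is a deliberately simplified model under stated assumptions: `ToyCachingAllocator`, `capturePerCall`, and `PersistentWorkspace` are hypothetical names, the allocator simply reuses the most recently freed block, and a raw pointer stands in for the address a captured kernel records. The real CUDA caching allocator is far more involved; only the aliasing pattern is the point here.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Toy stand-in for a caching allocator: freed blocks are kept in a free list
// and handed out again on the next allocation.
class ToyCachingAllocator
{
public:
    explicit ToyCachingAllocator(std::size_t blockBytes)
        : mBlockBytes(blockBytes)
    {
        assert(blockBytes > 0);
    }

    unsigned char* allocate()
    {
        if (!mFree.empty())
        {
            unsigned char* block = mFree.back();
            mFree.pop_back();
            return block; // reuse a previously released block
        }
        mBlocks.push_back(std::make_unique<unsigned char[]>(mBlockBytes));
        return mBlocks.back().get();
    }

    void release(unsigned char* block)
    {
        mFree.push_back(block);
    }

private:
    std::size_t mBlockBytes;
    std::vector<std::unique_ptr<unsigned char[]>> mBlocks; // owns all memory
    std::vector<unsigned char*> mFree;                     // blocks ready for reuse
};

// Per-call workspace (the previous behavior, modeled): the pointer recorded by
// the "graph" outlives the workspace, which is released at function return.
unsigned char* capturePerCall(ToyCachingAllocator& alloc)
{
    unsigned char* workspace = alloc.allocate();
    unsigned char* recorded = workspace; // what a captured kernel would keep
    alloc.release(workspace);
    return recorded;
}

// Persistent workspace (the fix, modeled): allocated once, never released
// while the owner is alive, so the recorded pointer stays exclusive.
struct PersistentWorkspace
{
    unsigned char* ptr = nullptr;

    unsigned char* get(ToyCachingAllocator& alloc)
    {
        if (ptr == nullptr)
        {
            ptr = alloc.allocate();
        }
        return ptr;
    }
};
```

In this model, the next `allocate()` after `capturePerCall` returns the same address the "graph" recorded, so two owners now share one block. With `PersistentWorkspace`, a later allocation gets a different block and the recorded pointer remains valid.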
📐 Maintainability & Code Quality | 🔴 Critical
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
sed -n '1,15p' cpp/tensorrt_llm/thop/cublasScaledMM.cpp
Repository: NVIDIA/TensorRT-LLM
Length of output: 860
Fix malformed copyright header placeholder text.
The copyright header contains a malformed placeholder Copyright (out) that must be corrected. While the year 2026 is already present and correct, the (out) text appears to be a substitution error and should be removed or properly filled in. Update the copyright line to match the standard format: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Source: Coding guidelines
    // Keyed by (device, stream) so that concurrent GEMMs on different streams
    // don't race on the same scratch bytes. Mirrors PyTorch's per-stream
    // cublasLt workspace cache in at::cuda.
    thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
    auto stream_ptr = stream.stream();
    auto& workspace = workspace_cache[stream_ptr];
    if (!workspace.defined() || workspace.device() != a.device())
    {
        auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
        workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
    }
🩺 Stability & Availability | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
rg -nP --type=cpp -C3 '\bcublas_(mm_out|scaled_mm_out|gemm_caller)\s*\('
rg -nP --type=cpp -C3 'getCurrentCUDAStream'
Repository: NVIDIA/TensorRT-LLM
Length of output: 157
🏁 Script executed:
#!/bin/bash
# First, check if the file exists
find . -name "cublasScaledMM.cpp" -type f | head -5
Repository: NVIDIA/TensorRT-LLM
Length of output: 200
🏁 Script executed:
#!/bin/bash
# If file exists, read the specific lines mentioned (203-213) and surrounding context
if [ -f "cpp/tensorrt_llm/thop/cublasScaledMM.cpp" ]; then
    echo "=== Lines 200-220 ==="
    sed -n '200,220p' cpp/tensorrt_llm/thop/cublasScaledMM.cpp | cat -n
    echo ""
    echo "=== Check for includes and see if we have unordered_map ==="
    head -50 cpp/tensorrt_llm/thop/cublasScaledMM.cpp | cat -n
fi
Repository: NVIDIA/TensorRT-LLM
Length of output: 3489
🏁 Script executed:
#!/bin/bash
# Search for function definitions containing "cublas" and workspace_cache
rg -n "workspace_cache|cublas_gemm_caller" cpp/tensorrt_llm/thop/ -A 5 -B 2
Repository: NVIDIA/TensorRT-LLM
Length of output: 3312
🏁 Script executed:
#!/bin/bash
# Look for callers of functions in cublasScaledMM or usage patterns
rg -n "cublasScaledMM|cublas.*MM" --type=cpp -B 2 -A 2 | head -100
Repository: NVIDIA/TensorRT-LLM
Length of output: 157
🏁 Script executed:
#!/bin/bash
# Find all functions that call cublas_gemm_caller
rg -n "cublas_gemm_caller|cublas_scaled_mm|cublas_mm" cpp/tensorrt_llm/ --type=cpp -B 3 -A 1 | grep -E "(def|^.*\.cpp.*:.*\()" | head -40
Repository: NVIDIA/TensorRT-LLM
Length of output: 2906
🏁 Script executed:
#!/bin/bash
# Look for Python bindings or module initialization that exposes these functions
rg -n "PYBIND|cublas_scaled_mm|cublas_mm" --type=cpp --type=py -B 2 -A 2 | head -80
Repository: NVIDIA/TensorRT-LLM
Length of output: 157
🏁 Script executed:
#!/bin/bash
# Check if there's context about how the stream is obtained and if multiple devices can be used
sed -n '160,230p' cpp/tensorrt_llm/thop/cublasScaledMM.cpp | cat -n
Repository: NVIDIA/TensorRT-LLM
Length of output: 3632
🏁 Script executed:
#!/bin/bash
# Look at the broader context of the function and how streams/devices are handled
rg -n "device|stream|cuda" cpp/tensorrt_llm/thop/cublasScaledMM.cpp | head -50
Repository: NVIDIA/TensorRT-LLM
Length of output: 2242
Cache key should be (device, stream), not stream alone — risks re-introducing UAF on multi-device default streams.
The comment states the cache is "Keyed by (device, stream)", but workspace_cache is keyed only by cudaStream_t. For the default/per-thread stream, stream.stream() returns the same sentinel handle (e.g. cudaStreamPerThread) on every device. If a single thread issues GEMMs on two devices through the default stream, both collide on one map entry. The workspace.device() != a.device() guard then reallocates on each device switch, which not only defeats the cache but reassigns (and frees) the other device's cached workspace—exactly the use-after-free this PR aims to prevent if a captured graph on that device still references the freed block.
Key the map on the (device, stream) pair so distinct devices never share an entry:
🛡️ Proposed fix: include device in the cache key
- thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
- auto stream_ptr = stream.stream();
- auto& workspace = workspace_cache[stream_ptr];
- if (!workspace.defined() || workspace.device() != a.device())
+ thread_local std::map<std::pair<int, cudaStream_t>, at::Tensor> workspace_cache;
+ auto stream_ptr = stream.stream();
+ auto& workspace = workspace_cache[{a.get_device(), stream_ptr}];
+ if (!workspace.defined())
{
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
}
(`std::map` avoids needing a custom hash for `std::pair`; include `<map>` instead of/in addition to `<unordered_map>`.)
The function is exposed via Python PYBIND (cublas_scaled_mm, cublas_mm) and can be called from user code with tensors on different devices in the same thread. The default stream scenario you identified is realistic and represents a latent UAF bug.
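If keeping `std::unordered_map` is preferred over switching to `std::map`, the same (device, stream) keying can be done with a small custom hash. The following is a minimal sketch under stated assumptions, not the change made in this PR: `StreamHandle` is a `void*` stand-in for `cudaStream_t`, `Workspace` is a byte vector standing in for `at::Tensor`, and `WorkspaceKeyHash` and `WorkspaceCache` are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

using StreamHandle = void*;                        // stand-in for cudaStream_t
using WorkspaceKey = std::pair<int, StreamHandle>; // (device index, stream)
using Workspace = std::vector<unsigned char>;      // stand-in for at::Tensor

// std::pair has no std::hash specialization, so unordered_map needs one.
struct WorkspaceKeyHash
{
    std::size_t operator()(WorkspaceKey const& key) const noexcept
    {
        std::size_t const h1 = std::hash<int>{}(key.first);
        std::size_t const h2 = std::hash<StreamHandle>{}(key.second);
        // Mix the two hashes (the widely used hash_combine pattern).
        return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
    }
};

class WorkspaceCache
{
public:
    explicit WorkspaceCache(std::size_t workspaceBytes)
        : mWorkspaceBytes(workspaceBytes)
    {
    }

    // Returns the workspace for this (device, stream) pair, allocating it on
    // first use. Entries for other devices are never touched or replaced.
    Workspace& get(int device, StreamHandle stream)
    {
        assert(device >= 0);
        WorkspaceKey const key{device, stream};
        auto it = mCache.find(key);
        if (it == mCache.end())
        {
            it = mCache.emplace(key, Workspace(mWorkspaceBytes)).first;
        }
        return it->second;
    }

    std::size_t size() const
    {
        return mCache.size();
    }

private:
    std::size_t mWorkspaceBytes;
    std::unordered_map<WorkspaceKey, Workspace, WorkspaceKeyHash> mCache;
};
```

With this keying, the default stream (modeled here as `nullptr`) on device 0 and on device 1 maps to two separate entries, so alternating between devices in one thread neither reallocates nor frees the other device's workspace.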
/bot run
PR_Github #55157 [ run ] triggered by Bot. Commit:
PR_Github #55157 [ run ] completed with state
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
df1d37c to 9f03772
/bot run --disable-fail-fast
PR_Github #55263 [ run ] triggered by Bot. Commit:
Description
`torch.ops.trtllm.cublas_mm` (which Nemotron-H reroutes Linear `apply()` to on SM121 since #13160) calls `cublas_gemm_caller`, which allocated a fresh `torch::empty(CUBLAS_WORKSPACE_SIZE /* 32 MiB */)` workspace on every invocation and let it destruct at function return. Inside a CUDA-graph capture region the captured cublasLt kernel records a fixed pointer into the workspace storage; freeing the storage back to the CUDA caching allocator while the captured graph still references the pointer is a use-after-free. It surfaces later as free-block-tree corruption on the next allocator operation, manifesting as the `FusedMoeRunner::getWorkspaceInfo` segfault reported on the second generation-only CUDA-graph warmup with mamba-hybrid NVFP4 MoE.
Repro: Nemotron-3-Nano-Omni-30B-A3B-Reasoning-NVFP4 on DGX-Spark with the YAML in the bug. Exposure is Spark-specific because #13160 only enables `use_custom_cublas_mm` when `sm_version() == 121`, but the underlying graph-unsafety is not.
Fix
Replace the per-call allocation with a per-stream workspace cache, mirroring PyTorch's own per-stream cublasLt workspace cache in `at::cuda`. Keyed by stream so concurrent GEMMs on different streams don't race on the same scratch bytes.
Validation
Rebuilt rc14 wheel at the first-bad commit `43e3070de4` + the QA repro YAML (CUDA graphs ON, `enable_padding: true`):
- `Application startup complete` reached, no `Bus error / Failing at address: 0x1` on aarch64.
- `cuda_graph_config: null` baseline:
  - `"What is the capital of France?"` → `Paris`
  - `"Write the integers from 1 to 10..."` → `1, 2, 3, 4, 5, 6, 7, 8, 9, 10`
  - `"Translate 'Hello, world' to French."` → `Bonjour, monde`
NVBUG: https://nvbugs/6150288
Test plan
- `cuda_graph_config: null` baseline
- `/bot run`
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.