Skip to content

Commit 9f03772

Browse files
address ai comment
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
1 parent 4eb8db7 commit 9f03772

1 file changed

Lines changed: 29 additions & 21 deletions

File tree

cpp/tensorrt_llm/thop/cublasScaledMM.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (out) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -159,6 +159,33 @@ bool find_special_algo_deprecated(cublasLtMatmulAlgo_t& algo, std::shared_ptr<Cu
159159
return true;
160160
}
161161

162+
// Helper function: Get or create a workspace tensor for the given (device, stream).
163+
// Workspace is reused across multiple GEMM calls so the pointer captured by the
164+
// cublasLt kernel remains valid for CUDA-graph capture/replay. Keyed by
165+
// (device, stream) so concurrent GEMMs on different streams of the same device
166+
// don't race on the same scratch bytes.
167+
inline at::Tensor const& getWorkspaceTensor(c10::Device device, cudaStream_t stream)
168+
{
169+
struct KeyHash
170+
{
171+
std::size_t operator()(std::pair<int, cudaStream_t> const& key) const noexcept
172+
{
173+
return std::hash<int>()(key.first) ^ (std::hash<cudaStream_t>()(key.second) << 1);
174+
}
175+
};
176+
177+
thread_local std::unordered_map<std::pair<int, cudaStream_t>, at::Tensor, KeyHash> workspace_tensors;
178+
auto key = std::make_pair(device.index(), stream);
179+
180+
if (workspace_tensors.find(key) == workspace_tensors.end())
181+
{
182+
workspace_tensors[key]
183+
= torch::empty(CUBLAS_WORKSPACE_SIZE, torch::TensorOptions().dtype(torch::kUInt8).device(device));
184+
}
185+
186+
return workspace_tensors[key];
187+
}
188+
162189
void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,
163190
std::optional<at::Tensor> const& scale_a, std::optional<at::Tensor> const& scale_b,
164191
std::optional<at::Tensor> const& bias, bool fast_acc = false)
@@ -191,26 +218,7 @@ void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tenso
191218
cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType);
192219

193220
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
194-
195-
// Persistent per-stream cublasLt workspace. The cublasLt kernel records a
196-
// fixed pointer to this workspace; if we destruct the workspace storage at
197-
// function return (the previous behavior), the CUDA caching allocator may
198-
// hand the same block out to a later allocation while a captured CUDA
199-
// graph still references the workspace pointer -> use-after-free that
200-
// surfaces as free-block-tree corruption on the next allocator operation
201-
// (e.g. FusedMoeRunner::getWorkspaceInfo's first torch::empty destructor).
202-
//
203-
// Keyed by (device, stream) so that concurrent GEMMs on different streams
204-
// don't race on the same scratch bytes. Mirrors PyTorch's per-stream
205-
// cublasLt workspace cache in at::cuda.
206-
thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
207-
auto stream_ptr = stream.stream();
208-
auto& workspace = workspace_cache[stream_ptr];
209-
if (!workspace.defined() || workspace.device() != a.device())
210-
{
211-
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
212-
workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
213-
}
221+
auto const& workspace = getWorkspaceTensor(a.device(), stream.stream());
214222

215223
auto* a_ptr = static_cast<void*>(a.data_ptr());
216224
auto* b_ptr = static_cast<void*>(b.data_ptr());

0 commit comments

Comments
 (0)