Skip to content

Commit 4eb8db7

Browse files
fix cublasMM workspace
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
1 parent 93feb57 commit 4eb8db7

1 file changed

Lines changed: 21 additions & 3 deletions

File tree

cpp/tensorrt_llm/thop/cublasScaledMM.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "userbuffersTensor.h"
2727
#include <cublasLt.h>
2828
#include <torch/extension.h>
29+
#include <unordered_map>
2930

3031
using torch::Tensor;
3132

@@ -189,11 +190,28 @@ void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tenso
189190
cudaDataType_t scaleType = CUDA_R_32F;
190191
cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType);
191192

192-
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
193-
auto workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
194-
195193
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
196194

195+
// Persistent per-stream cublasLt workspace. The cublasLt kernel records a
196+
// fixed pointer to this workspace; if we destruct the workspace storage at
197+
// function return (the previous behavior), the CUDA caching allocator may
198+
// hand the same block out to a later allocation while a captured CUDA
199+
// graph still references the workspace pointer -> use-after-free that
200+
// surfaces as free-block-tree corruption on the next allocator operation
201+
// (e.g. FusedMoeRunner::getWorkspaceInfo's first torch::empty destructor).
202+
//
203+
// Keyed by (device, stream) so that concurrent GEMMs on different streams
204+
// don't race on the same scratch bytes. Mirrors PyTorch's per-stream
205+
// cublasLt workspace cache in at::cuda.
206+
thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache;
207+
auto stream_ptr = stream.stream();
208+
auto& workspace = workspace_cache[stream_ptr];
209+
if (!workspace.defined() || workspace.device() != a.device())
210+
{
211+
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
212+
workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options);
213+
}
214+
197215
auto* a_ptr = static_cast<void*>(a.data_ptr());
198216
auto* b_ptr = static_cast<void*>(b.data_ptr());
199217
auto* out_ptr = static_cast<void*>(out.data_ptr());

0 commit comments

Comments
 (0)