|
26 | 26 | #include "userbuffersTensor.h" |
27 | 27 | #include <cublasLt.h> |
28 | 28 | #include <torch/extension.h> |
| 29 | +#include <unordered_map> |
29 | 30 |
|
30 | 31 | using torch::Tensor; |
31 | 32 |
|
@@ -189,11 +190,28 @@ void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tenso |
189 | 190 | cudaDataType_t scaleType = CUDA_R_32F; |
190 | 191 | cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType); |
191 | 192 |
|
192 | | - auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); |
193 | | - auto workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options); |
194 | | - |
195 | 193 | auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); |
196 | 194 |
|
| 195 | + // Persistent per-stream cublasLt workspace. The cublasLt kernel records a |
| 196 | + // fixed pointer to this workspace; if we destruct the workspace storage at |
| 197 | + // function return (the previous behavior), the CUDA caching allocator may |
| 198 | + // hand the same block out to a later allocation while a captured CUDA |
| 199 | + // graph still references the workspace pointer -> use-after-free that |
| 200 | + // surfaces as free-block-tree corruption on the next allocator operation |
| 201 | + // (e.g. FusedMoeRunner::getWorkspaceInfo's first torch::empty destructor). |
| 202 | + // |
| 203 | + // Keyed by (device, stream) so that concurrent GEMMs on different streams |
| 204 | + // don't race on the same scratch bytes. Mirrors PyTorch's per-stream |
| 205 | + // cublasLt workspace cache in at::cuda. |
| 206 | + thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache; |
| 207 | + auto stream_ptr = stream.stream(); |
| 208 | + auto& workspace = workspace_cache[stream_ptr]; |
| 209 | + if (!workspace.defined() || workspace.device() != a.device()) |
| 210 | + { |
| 211 | + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); |
| 212 | + workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options); |
| 213 | + } |
| 214 | + |
197 | 215 | auto* a_ptr = static_cast<void*>(a.data_ptr()); |
198 | 216 | auto* b_ptr = static_cast<void*>(b.data_ptr()); |
199 | 217 | auto* out_ptr = static_cast<void*>(out.data_ptr()); |
|
0 commit comments