|
1 | 1 | /* |
2 | | - * SPDX-FileCopyrightText: Copyright (out) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | + * Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
3 | 3 | * SPDX-License-Identifier: Apache-2.0 |
4 | 4 | * |
5 | 5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
@@ -159,6 +159,33 @@ bool find_special_algo_deprecated(cublasLtMatmulAlgo_t& algo, std::shared_ptr<Cu |
159 | 159 | return true; |
160 | 160 | } |
161 | 161 |
|
| 162 | +// Helper function: get or create the workspace tensor for the given (device, stream). |
| 163 | +// The tensor is cached in thread_local storage and reused across GEMM calls so the |
| 164 | +// pointer captured by the cublasLt kernel remains valid for CUDA-graph capture/replay. |
| 165 | +// Keyed by (device, stream) so GEMMs on different streams don't share scratch bytes. |
| 166 | +inline at::Tensor const& getWorkspaceTensor(c10::Device device, cudaStream_t stream) |
| 167 | +{ |
| 168 | +    using Key = std::pair<int, cudaStream_t>; |
| 169 | +    struct KeyHash |
| 170 | +    { |
| 171 | +        std::size_t operator()(Key const& key) const noexcept |
| 172 | +        { |
| 173 | +            return std::hash<int>()(key.first) ^ (std::hash<cudaStream_t>()(key.second) << 1); |
| 174 | +        } |
| 175 | +    }; |
| 176 | + |
| 177 | +    thread_local std::unordered_map<Key, at::Tensor, KeyHash> workspace_tensors; |
| 178 | +    Key const key{static_cast<int>(device.index()), stream}; |
| 179 | + |
| 180 | +    auto it = workspace_tensors.find(key); // single lookup on the cache-hit path |
| 181 | +    if (it == workspace_tensors.end()) |
| 182 | +    { |
| 183 | +        auto const options = torch::TensorOptions().dtype(torch::kUInt8).device(device); |
| 184 | +        it = workspace_tensors.emplace(key, torch::empty(CUBLAS_WORKSPACE_SIZE, options)).first; |
| 185 | +    } |
| 186 | +    return it->second; |
| 187 | +} |
| 188 | + |
162 | 189 | void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, |
163 | 190 | std::optional<at::Tensor> const& scale_a, std::optional<at::Tensor> const& scale_b, |
164 | 191 | std::optional<at::Tensor> const& bias, bool fast_acc = false) |
@@ -191,26 +218,7 @@ void cublas_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tenso |
191 | 218 | cublasWrapper->setGemmConfig(aType, bType, outType, /*computeType=*/scaleType); |
192 | 219 |
|
193 | 220 | auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); |
194 | | - |
195 | | - // Persistent per-stream cublasLt workspace. The cublasLt kernel records a |
196 | | - // fixed pointer to this workspace; if we destruct the workspace storage at |
197 | | - // function return (the previous behavior), the CUDA caching allocator may |
198 | | - // hand the same block out to a later allocation while a captured CUDA |
199 | | - // graph still references the workspace pointer -> use-after-free that |
200 | | - // surfaces as free-block-tree corruption on the next allocator operation |
201 | | - // (e.g. FusedMoeRunner::getWorkspaceInfo's first torch::empty destructor). |
202 | | - // |
203 | | - // Keyed by (device, stream) so that concurrent GEMMs on different streams |
204 | | - // don't race on the same scratch bytes. Mirrors PyTorch's per-stream |
205 | | - // cublasLt workspace cache in at::cuda. |
206 | | - thread_local std::unordered_map<cudaStream_t, at::Tensor> workspace_cache; |
207 | | - auto stream_ptr = stream.stream(); |
208 | | - auto& workspace = workspace_cache[stream_ptr]; |
209 | | - if (!workspace.defined() || workspace.device() != a.device()) |
210 | | - { |
211 | | - auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); |
212 | | - workspace = torch::empty(CUBLAS_WORKSPACE_SIZE, workspace_options); |
213 | | - } |
| 221 | + auto const& workspace = getWorkspaceTensor(a.device(), stream.stream()); |
214 | 222 |
|
215 | 223 | auto* a_ptr = static_cast<void*>(a.data_ptr()); |
216 | 224 | auto* b_ptr = static_cast<void*>(b.data_ptr()); |
|
0 commit comments