 #include <cstdio>
 
 #include <array>
+#include <atomic>
 #include <filesystem>
 #include <fstream>
 #include <mutex>
@@ -80,6 +81,7 @@ namespace {
 constexpr char kSkipCopyOutputToCpuForMethod[] =
     "skip_copy_output_to_cpu_for_method";
 constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
+constexpr char kWeightSharingAcrossMethods[] = "weight_sharing_across_methods";
 } // anonymous namespace
 
 class ET_EXPERIMENTAL CudaBackend final
@@ -173,6 +175,16 @@ class ET_EXPERIMENTAL CudaBackend final
     return shared_cuda_stream_ != nullptr;
   }
 
+  // Enable cross-method per-FQN weight caching, set via the runtime
+  // kWeightSharingAcrossMethods backend option. Unsafe under FQN collisions.
+  void set_weight_sharing_across_methods(bool enabled) {
+    weight_sharing_across_methods_.store(enabled, std::memory_order_relaxed);
+  }
+
+  bool is_weight_sharing_across_methods_enabled() const {
+    return weight_sharing_across_methods_.load(std::memory_order_relaxed);
+  }
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
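Reviewer note on the ordering choice in the new accessors: relaxed suffices because the flag is a standalone boolean and nothing else is published through it (my inference; the PR does not state this). For contrast, a self-contained sketch of the pattern that would need release/acquire instead, with hypothetical names throughout:

#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Hypothetical contrast: if enabling the flag also published companion
// state, the writer would need a release store and the reader an acquire
// load so the state becomes visible together with the flag.
std::vector<int> g_cache;           // companion state (hypothetical)
std::atomic<bool> g_enabled{false}; // stands in for the backend's flag

int main() {
  std::thread writer([] {
    g_cache = {1, 2, 3};            // build the state first
    g_enabled.store(true, std::memory_order_release);
  });
  while (!g_enabled.load(std::memory_order_acquire)) {
    // spin until the writer publishes
  }
  assert(g_cache.size() == 3);      // visibility guaranteed by the pair
  writer.join();
  return 0;
}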
@@ -264,6 +276,16 @@ class ET_EXPERIMENTAL CudaBackend final
         ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
         return Error::InvalidArgument;
       }
+    } else if (std::strcmp(option.key, kWeightSharingAcrossMethods) == 0) {
+      if (auto* val = std::get_if<bool>(&option.value)) {
+        set_weight_sharing_across_methods(*val);
+      } else {
+        ET_LOG(
+            Error,
+            "Option %s must be a boolean.",
+            kWeightSharingAcrossMethods);
+        return Error::InvalidArgument;
+      }
     }
   }
   return Error::Ok;
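To make the new branch concrete: it consumes an option carrying a key string plus a variant value. Below is a self-contained sketch of that shape; the struct is an illustrative stand-in, not the real executorch::runtime::BackendOption, and only the key string comes from this PR.

#include <cstring>
#include <variant>

// Illustrative stand-in for the runtime's option type (hypothetical layout).
struct BackendOption {
  const char* key;
  std::variant<bool, int, const char*> value;
};

int main() {
  // Key string registered by this PR; the value must hold a bool, exactly
  // as the std::get_if<bool> check in the branch above enforces.
  BackendOption opt{"weight_sharing_across_methods", true};
  if (std::strcmp(opt.key, "weight_sharing_across_methods") == 0) {
    if (auto* val = std::get_if<bool>(&opt.value)) {
      // Here CudaBackend::set_option() calls
      // set_weight_sharing_across_methods(*val).
      return *val ? 0 : 1;
    }
    return 2; // wrong type: the backend returns Error::InvalidArgument
  }
  return 3;
}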
@@ -362,11 +384,20 @@ class ET_EXPERIMENTAL CudaBackend final
 
   handle->container_handle = container_handle;
 
-  // Load constants with per-weight caching.
-  // This replaces the old update_constants_from_blob + cross-method sharing
-  // with a unified approach that avoids duplicate GPU allocations.
-  ET_CHECK_OK_OR_RETURN_ERROR(
-      load_constants_with_cache(handle, named_data_map, method_name));
+  // Load constants. When weight_sharing_across_methods is enabled (opt-in
+  // via the kWeightSharingAcrossMethods runtime backend option set by the
+  // runner), use the per-weight FQN cache so methods that share weights
+  // (e.g. prefill/decode) avoid duplicate GPU allocations. Otherwise fall
+  // back to the legacy per-method blob load, which is required for models
+  // whose methods are independent sub-graphs that may have FQN collisions
+  // (e.g. parakeet).
+  if (is_weight_sharing_across_methods_enabled()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(
+        load_constants_with_cache(handle, named_data_map, method_name));
+  } else {
+    ET_CHECK_OK_OR_RETURN_ERROR(
+        load_constants_legacy(handle, named_data_map, method_name));
+  }
 
   // Use shared CUDA stream if enabled via options, otherwise create one.
   // A shared stream ensures proper ordering across multiple methods
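load_constants_with_cache() itself is outside this hunk; conceptually, per the comment above, it amounts to a mutex-guarded map from weight FQN to a shared device buffer. A minimal sketch with hypothetical names (not the PR's implementation):

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical per-FQN weight cache. Methods that share a weight resolve
// it by fully qualified name and reuse one device buffer; the shared_ptr
// keeps the GPU allocation alive while any method still references it.
class WeightCache {
 public:
  template <typename UploadFn>
  std::shared_ptr<void> get_or_upload(const std::string& fqn, UploadFn upload) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = cache_.find(fqn);
    if (it != cache_.end()) {
      return it->second;                   // hit: decode reuses prefill's copy
    }
    std::shared_ptr<void> buf = upload();  // miss: single GPU upload
    cache_.emplace(fqn, buf);
    return buf;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<void>> cache_;
};

This also shows why the legacy path must remain available: if two independent sub-graphs reuse an FQN for different weights, get_or_upload() would silently alias them to one buffer, which is exactly the collision case the comment calls out.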
@@ -630,6 +661,11 @@ class ET_EXPERIMENTAL CudaBackend final
   mutable std::mutex cuda_stream_mutex_;
   std::shared_ptr<cudaStream_t> shared_cuda_stream_ = nullptr;
 
+  // Whether to enable cross-method per-FQN weight caching at init time.
+  // Toggled by the kWeightSharingAcrossMethods runtime backend option.
+  // Off by default; see set_weight_sharing_across_methods() for constraints.
+  std::atomic<bool> weight_sharing_across_methods_{false};
+
   // Cached output tensors for skip-copy optimization.
   // When skip-copy is enabled, output SlimTensors are cached here to keep
   // the underlying GPU memory alive while the caller processes the results.