-
Notifications
You must be signed in to change notification settings - Fork 954
Expand file tree
/
Copy pathcuda_backend.cpp
More file actions
810 lines (713 loc) · 29.5 KB
/
cuda_backend.cpp
File metadata and controls
810 lines (713 loc) · 29.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <c10/util/safe_numerics.h>
#include <cuda_runtime.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <array>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <memory>
#include <mutex>
#include <string>
#include <string_view>
#include <system_error>
#include <unordered_map>
#include <vector>
// Include SlimTensor headers for CUDA backend
#include <executorch/backends/aoti/slim/c10/core/Device.h>
#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
#include <executorch/backends/aoti/slim/core/slim_tensor.h>
#include <executorch/backends/aoti/slim/core/storage.h>
#include <executorch/backends/aoti/slim/factory/empty.h>
#include <executorch/backends/aoti/slim/factory/from_blob.h>
#include <executorch/backends/aoti/slim/factory/from_etensor.h>
#include <executorch/backends/aoti/slim/util/array_ref_util.h>
// Include our shim layer headers
#include <executorch/backends/aoti/aoti_delegate_handle.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/cuda_delegate_handle.h>
#include <executorch/backends/cuda/runtime/platform/platform.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>
namespace executorch::backends::cuda {
using namespace std;
using namespace aoti;
using executorch::aten::ScalarType;
using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendOption;
using executorch::runtime::BackendOptionContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::kMaxOptionValueLength;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;
// SlimTensor type aliases
using slim::CPU_DEVICE;
using slim::DEFAULT_CUDA_DEVICE;
using slim::DeviceTraits;
using slim::from_etensor;
using slim::SlimTensor;
using slim::c10::Device;
using slim::c10::DeviceType;
namespace {
// Backend-option key (string value): comma-separated list of method names
// whose outputs stay on the GPU instead of being copied back to the CPU
// ETensors supplied by the caller.
constexpr char kSkipCopyOutputToCpuForMethod[]{
    "skip_copy_output_to_cpu_for_method"};

// Backend-option key (bool value): when true, a single CUDA stream is created
// and handed to every method initialized afterwards.
constexpr char kUseSharedCudaStream[]{"use_shared_cuda_stream"};

// Compile-spec key: a non-zero first byte requests that constants (e.g. a KV
// cache) be shared between the AOTI containers of different methods.
constexpr char kShareKvCacheAcrossMethods[]{"share_kv_cache_across_methods"};
} // anonymous namespace
class ET_EXPERIMENTAL CudaBackend final
    : public ::executorch::runtime::BackendInterface {
 private:
  // Trim leading/trailing whitespace from a view of the string.
  static std::string_view trim(std::string_view s) {
    size_t start = 0;
    while (start < s.size() &&
           std::isspace(static_cast<unsigned char>(s[start]))) {
      ++start;
    }
    size_t end = s.size();
    while (end > start &&
           std::isspace(static_cast<unsigned char>(s[end - 1]))) {
      --end;
    }
    return s.substr(start, end - start);
  }

  // Return true if method_name equals one of the whitespace-trimmed,
  // non-empty entries of the comma-separated list csv.
  static bool method_in_csv(
      const std::string& method_name,
      const std::string& csv) {
    size_t pos = 0;
    while (pos <= csv.size()) {
      const size_t comma = csv.find(',', pos);
      const std::string_view token =
          trim(std::string_view(csv).substr(pos, comma - pos));
      if (!token.empty() && token == method_name) {
        return true;
      }
      if (comma == std::string::npos) {
        break;
      }
      pos = comma + 1;
    }
    return false;
  }

  // Store the skip-copy method list taken from a backend option value.
  // The option buffer is fixed-size and a value that fills it completely is
  // not NUL-terminated, so the copy is bounded by the array size.
  void set_skip_copy_method(
      const std::array<char, kMaxOptionValueLength>& raw) {
    const std::string_view bounded(raw.data(), raw.size());
    const std::string_view value = bounded.substr(0, bounded.find('\0'));
    std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
    skip_copy_method_.assign(value.data(), value.size());
  }

  // Return the current skip-copy method list as a NUL-terminated option
  // value (truncated by snprintf if it does not fit).
  std::array<char, kMaxOptionValueLength> get_skip_copy_method_as_option()
      const {
    std::array<char, kMaxOptionValueLength> out{};
    std::string value;
    {
      std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
      value = skip_copy_method_;
    }
    std::snprintf(out.data(), out.size(), "%s", value.c_str());
    return out;
  }

  // True if outputs of method_name should be left on the GPU.
  bool should_skip_copy_for_method(const std::string& method_name) const {
    if (method_name.empty()) {
      return false;
    }
    std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
    return method_in_csv(method_name, skip_copy_method_);
  }

  // Create the shared CUDA stream. Called when use_shared_cuda_stream option
  // is set to true. The presence of shared_cuda_stream_ indicates shared mode.
  void create_shared_cuda_stream() {
    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
    if (shared_cuda_stream_ != nullptr) {
      return; // Already created
    }
    shared_cuda_stream_ = cuda::create_cuda_stream();
    if (shared_cuda_stream_ == nullptr) {
      ET_LOG(Error, "Failed to create shared CUDA stream");
      return;
    }
    ET_LOG(Info, "Created shared CUDA stream: %p", *shared_cuda_stream_);
  }

  // Get the shared CUDA stream. Returns nullptr if not in shared mode.
  std::shared_ptr<cudaStream_t> get_shared_cuda_stream() const {
    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
    return shared_cuda_stream_;
  }

  // Best-effort removal of a temporary .so written by init() before a
  // delegate handle takes ownership of it. Failures are logged only.
  static void remove_temp_library(const std::filesystem::path& path) {
    std::error_code remove_error;
    std::filesystem::remove(path, remove_error);
    if (remove_error) {
      ET_LOG(
          Error,
          "Failed to remove temporary shared library %s: %s",
          path.string().c_str(),
          remove_error.message().c_str());
    }
  }

  // Resolve the AOTI entry points from so_handle into handle. The five core
  // symbols are mandatory (their lookup error is returned); the remaining
  // ones are optional and left as nullptr when absent.
  Error load_function_pointers_into_handle(
      void* so_handle,
      AOTIDelegateHandle* handle) const {
#define LOAD_SYMBOL(member, name)                                    \
  do {                                                               \
    auto symbol_res = get_function(so_handle, #name);                \
    if (!symbol_res.ok()) {                                          \
      return symbol_res.error();                                     \
    }                                                                \
    handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \
  } while (0)
    LOAD_SYMBOL(create_with_device, AOTInductorModelContainerCreateWithDevice);
    LOAD_SYMBOL(delete_container, AOTInductorModelContainerDelete);
    LOAD_SYMBOL(get_num_inputs, AOTInductorModelContainerGetNumInputs);
    LOAD_SYMBOL(get_num_outputs, AOTInductorModelContainerGetNumOutputs);
    LOAD_SYMBOL(run, AOTInductorModelContainerRun);
#undef LOAD_SYMBOL
    auto symbol_res =
        get_function(so_handle, "AOTInductorModelUpdateConstantsFromBlob");
    if (symbol_res.ok()) {
      handle->update_constants_from_blob =
          reinterpret_cast<AOTInductorModelUpdateConstantsFromBlobFunc>(
              symbol_res.get());
    } else {
      ET_LOG(
          Info,
          "Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)");
    }
    // Load constant management symbols (optional — needed for cross-method
    // buffer sharing). These are available in torch >= 2.6.
#define LOAD_OPTIONAL_SYMBOL(member, name)                              \
  do {                                                                  \
    auto res = get_function(so_handle, #name);                          \
    handle->member =                                                    \
        res.ok() ? reinterpret_cast<name##Func>(res.get()) : nullptr;   \
  } while (0)
    LOAD_OPTIONAL_SYMBOL(
        get_num_constants, AOTInductorModelContainerGetNumConstants);
    LOAD_OPTIONAL_SYMBOL(
        get_constant_name, AOTInductorModelContainerGetConstantName);
    LOAD_OPTIONAL_SYMBOL(
        get_constant_original_fqn,
        AOTInductorModelContainerGetConstantOriginalFQN);
    LOAD_OPTIONAL_SYMBOL(
        extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
    LOAD_OPTIONAL_SYMBOL(
        update_user_managed_constant_buffer_pairs,
        AOTInductorModelContainerUpdateUserManagedConstantBufferPairs);
#undef LOAD_OPTIONAL_SYMBOL
    return Error::Ok;
  }

 public:
  bool is_available() const override {
    return true;
  }

  // Apply backend options. Unknown keys are ignored; a known key with a value
  // of the wrong type yields InvalidArgument.
  Error set_option(
      ET_UNUSED BackendOptionContext& context,
      const executorch::runtime::Span<BackendOption>& backend_options)
      override {
    for (const auto& option : backend_options) {
      if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
        if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
                &option.value)) {
          set_skip_copy_method(*val);
        } else {
          ET_LOG(
              Error,
              "Option %s must be a method name string.",
              kSkipCopyOutputToCpuForMethod);
          return Error::InvalidArgument;
        }
      } else if (std::strcmp(option.key, kUseSharedCudaStream) == 0) {
        if (auto* val = std::get_if<bool>(&option.value)) {
          if (*val) {
            create_shared_cuda_stream();
          }
        } else {
          ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
          return Error::InvalidArgument;
        }
      }
    }
    return Error::Ok;
  }

  // Report backend options. Only the skip-copy method list is readable.
  Error get_option(
      ET_UNUSED BackendOptionContext& context,
      executorch::runtime::Span<BackendOption>& backend_options) override {
    for (auto& option : backend_options) {
      if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
        option.value = get_skip_copy_method_as_option();
      }
    }
    return Error::Ok;
  }

  // Once per loaded binary blob: write this method's AOTI .so to a temp file,
  // load it, create its model container, feed it the weights blob, choose a
  // CUDA stream and, if requested, share constants with earlier methods.
  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed, // This will be a empty buffer
      ArrayRef<CompileSpec> compile_specs // This will be my empty list
  ) const override {
    std::string method_name;
    bool share_kv_cache = false;
    for (const CompileSpec& spec : compile_specs) {
      if (std::strcmp(spec.key, "method_name") == 0) {
        method_name.assign(
            static_cast<const char*>(spec.value.buffer),
            spec.value.nbytes); // no nullptr guarantee, so pass size
      } else if (std::strcmp(spec.key, kShareKvCacheAcrossMethods) == 0) {
        if (spec.value.nbytes >= 1) {
          share_kv_cache =
              static_cast<const uint8_t*>(spec.value.buffer)[0] != 0;
        }
      }
    }

    std::string so_blob_key =
        method_name.empty() ? "so_blob" : method_name + "_so_blob";

    const NamedDataMap* named_data_map = context.get_named_data_map();
    ET_CHECK_OR_RETURN_ERROR(
        named_data_map != nullptr,
        InvalidArgument,
        "No named data map available to look up %s",
        so_blob_key.c_str());
    auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str());
    ET_CHECK_OR_RETURN_ERROR(
        aoti_dso_buffer.ok(),
        Internal,
        "Failed to get data for key %s: 0x%x",
        so_blob_key.c_str(),
        static_cast<uint32_t>(aoti_dso_buffer.error()));

    // Generate dynamic temporary file path. The error_code overload turns an
    // unusable temp directory into an Error instead of an exception.
    std::error_code temp_dir_error;
    const filesystem::path temp_dir =
        filesystem::temp_directory_path(temp_dir_error);
    ET_CHECK_OR_RETURN_ERROR(
        !temp_dir_error,
        AccessFailed,
        "Failed to locate temporary directory: %s",
        temp_dir_error.message().c_str());
    const filesystem::path so_path =
        temp_dir / (so_blob_key + to_string(get_process_id()) + ".so");

    // Write the ELF buffer to the temporary file
    ofstream outfile(so_path, ios::binary);
    ET_LOG(
        Info,
        "Writing %zu bytes to %s",
        aoti_dso_buffer->size(),
        so_path.c_str());
    outfile.write(
        static_cast<const char*>(aoti_dso_buffer->data()),
        aoti_dso_buffer->size());
    // close() flushes to disk, so check the stream state after it. A failed
    // open also leaves the stream in a failed state.
    outfile.close();
    if (!outfile) {
      ET_LOG(Error, "Failed to write to file %s", so_path.c_str());
      remove_temp_library(so_path);
      return Error::AccessFailed;
    }
    // Free the buffer immediately after writing to disk
    aoti_dso_buffer->Free();

    // Load the lib
    Result<void*> lib_handle_res = load_library(so_path);
    if (!lib_handle_res.ok()) {
      remove_temp_library(so_path);
      return lib_handle_res.error();
    }
    void* lib_handle = lib_handle_res.get();
    if (processed != nullptr) {
      processed->Free();
    }

    // From here on the handle owns the loaded library and the temp file.
    // Until init succeeds, every early return hands the handle to destroy(),
    // which drops the stream reference, closes the library, removes the temp
    // file and frees the handle.
    auto destroy_on_error = [this](cuda::CudaDelegateHandle* h) {
      destroy(h);
    };
    std::unique_ptr<cuda::CudaDelegateHandle, decltype(destroy_on_error)>
        handle(new cuda::CudaDelegateHandle(), destroy_on_error);
    handle->so_handle = lib_handle;
    handle->so_path = so_path.string();
    handle->method_name = method_name;

    // Load function pointers specific to this handle's shared library
    ET_CHECK_OK_OR_RETURN_ERROR(
        load_function_pointers_into_handle(lib_handle, handle.get()));

    AOTInductorModelContainerHandle container_handle = nullptr;
    ET_CHECK_OK_OR_RETURN_ERROR(
        handle->create_with_device(&container_handle, 1, "cuda", nullptr));
    ET_LOG(Info, "container_handle = %p", container_handle);
    handle->container_handle = container_handle;

    // Look into named data map for constant data
    std::string weights_blob_key =
        method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
    auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
    if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
      ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
      const void* weights_blob = buffer_res->data();
      // Feed the weights blob into the container. Under the hood it's copying
      // weights, so we should free the buffer immediately.
      auto update_err = handle->update_constants_from_blob(
          handle->container_handle, static_cast<const uint8_t*>(weights_blob));
      if (update_err != Error::Ok) {
        ET_LOG(Error, "update_constants_from_blob failed");
        return update_err;
      }
      // Ensure all weight transfers are complete before execution
      const cudaError_t sync_err = cudaDeviceSynchronize();
      ET_CHECK_OR_RETURN_ERROR(
          sync_err == cudaSuccess,
          Internal,
          "cudaDeviceSynchronize failed after loading weights for '%s': %s",
          method_name.c_str(),
          cudaGetErrorString(sync_err));
      buffer_res->Free();
    } else {
      ET_LOG(
          Info,
          "weights_blob '%s' not found or update fn is null",
          weights_blob_key.c_str());
    }

    // Use shared CUDA stream if enabled via options, otherwise create one.
    // A shared stream ensures proper ordering across multiple methods
    // (e.g., encoder, decoder, sampler) when using skip-copy optimization.
    // The shared stream is fetched once so the check and the use agree.
    std::shared_ptr<cudaStream_t> shared_stream = get_shared_cuda_stream();
    if (shared_stream != nullptr) {
      // Shared stream mode: all handles share the same stream.
      handle->cuda_stream = std::move(shared_stream);
      ET_LOG(
          Info,
          "Using shared CUDA stream %p for method %s",
          handle->get_cuda_stream(),
          method_name.c_str());
    } else {
      // Per-handle stream mode: each handle owns its own stream.
      handle->cuda_stream = cuda::create_cuda_stream();
      if (handle->cuda_stream == nullptr) {
        ET_LOG(
            Error,
            "Failed to create CUDA stream for method %s",
            method_name.c_str());
        return Error::Internal;
      }
      ET_LOG(
          Info,
          "Created new CUDA stream %p for method %s",
          handle->get_cuda_stream(),
          method_name.c_str());
    }

    // ---------------------------------------------------------------
    // Cross-method constant sharing (e.g., KV cache between prefill/decode).
    //
    // Only enabled when share_kv_cache_across_methods compile spec is set.
    // The first container to initialize extracts its constants (keyed by
    // original FQN) and stores the AtenTensorHandle's. Subsequent containers
    // with matching FQNs are updated to point to the same GPU tensors via
    // UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
    // the source container retains ownership).
    // ---------------------------------------------------------------
    if (share_kv_cache && handle->get_num_constants &&
        handle->get_constant_name && handle->get_constant_original_fqn &&
        handle->extract_constants_map &&
        handle->update_user_managed_constant_buffer_pairs) {
      size_t num_constants = 0;
      ET_CHECK_OR_RETURN_ERROR(
          handle->get_num_constants(
              handle->container_handle, &num_constants) == Error::Ok,
          Internal,
          "Failed to query constants of method '%s'",
          method_name.c_str());
      if (num_constants > 0) {
        // Build FQN → internal_name mapping for this container.
        std::unordered_map<std::string, std::string> fqn_to_name;
        for (size_t i = 0; i < num_constants; i++) {
          const char* name = nullptr;
          const char* fqn = nullptr;
          handle->get_constant_name(handle->container_handle, i, &name);
          handle->get_constant_original_fqn(handle->container_handle, i, &fqn);
          if (name && fqn && fqn[0] != '\0') {
            fqn_to_name[fqn] = name;
          }
        }
        std::lock_guard<std::mutex> guard(shared_constants_mutex_);
        if (!constants_extracted_) {
          // First container: extract its constants and store by FQN.
          std::unordered_map<std::string, AtenTensorHandle> extracted_map;
          auto extract_err = handle->extract_constants_map(
              handle->container_handle,
              reinterpret_cast<AOTInductorConstantMapHandle>(&extracted_map),
              /*use_inactive=*/false);
          if (extract_err != Error::Ok) {
            ET_LOG(
                Error,
                "Failed to extract constants from '%s'",
                method_name.c_str());
            return Error::Internal;
          }
          for (const auto& entry : fqn_to_name) {
            const std::string& fqn = entry.first;
            auto it = extracted_map.find(fqn);
            if (it != extracted_map.end()) {
              shared_constant_tensors_[fqn] = it->second;
            }
          }
          constants_extracted_ = true;
          ET_LOG(
              Info,
              "Extracted %zu shared constants from method '%s'",
              shared_constant_tensors_.size(),
              method_name.c_str());
        } else {
          // Subsequent container: share matching constants from the first.
          std::vector<AOTInductorConstantMapEntry> pairs;
          pairs.reserve(fqn_to_name.size());
          for (const auto& [fqn, internal_name] : fqn_to_name) {
            auto it = shared_constant_tensors_.find(fqn);
            if (it != shared_constant_tensors_.end()) {
              // UpdateUserManagedConstantBufferPairs matches against the
              // codegen constant name (underscored), not the original FQN.
              // internal_name outlives the call (fqn_to_name is in scope).
              pairs.push_back({internal_name.c_str(), it->second});
            }
          }
          if (!pairs.empty()) {
            auto update_err = handle->update_user_managed_constant_buffer_pairs(
                handle->container_handle,
                pairs.data(),
                pairs.size(),
                /*use_inactive=*/false,
                /*validate_full_update=*/false);
            if (update_err != Error::Ok) {
              ET_LOG(
                  Error,
                  "Failed to share constants into '%s'",
                  method_name.c_str());
              return Error::Internal;
            }
            ET_LOG(
                Info,
                "Shared %zu constants into method '%s'",
                pairs.size(),
                method_name.c_str());
          }
        }
      }
    } else if (share_kv_cache) {
      ET_LOG(
          Error,
          "share_kv_cache_across_methods requested but constant sharing APIs "
          "not available for method '%s'",
          method_name.c_str());
      return Error::Internal;
    } else {
      ET_LOG(
          Info,
          "Constant sharing not requested for method '%s'",
          method_name.c_str());
    }

    // Success: hand ownership of the handle to the runtime.
    return static_cast<DelegateHandle*>(handle.release());
  }

  // Once per execution: stage inputs on the GPU, run the AOTI container and
  // either copy outputs back into the caller's CPU ETensors or (skip-copy)
  // re-point those ETensors at cached GPU tensors.
  Error execute(
      BackendExecutionContext& context,
      DelegateHandle* handle_,
      Span<EValue*> args) const override {
    auto* handle = static_cast<cuda::CudaDelegateHandle*>(handle_);

    size_t n_inputs = 0;
    ET_CHECK_OK_OR_RETURN_ERROR(
        handle->get_num_inputs(handle->container_handle, &n_inputs));
    size_t n_outputs = 0;
    ET_CHECK_OK_OR_RETURN_ERROR(
        handle->get_num_outputs(handle->container_handle, &n_outputs));

    setCurrentCUDAStream(handle->get_cuda_stream(), 0);

    size_t n_io_sum = 0;
    ET_CHECK_OR_RETURN_ERROR(
        !c10::add_overflows(n_inputs, n_outputs, &n_io_sum) &&
            n_io_sum == args.size(),
        InvalidArgument,
        "number of user input %zu and output %zu generated from AOT Inductor does not match ET runner's %zu. Exit.",
        n_inputs,
        n_outputs,
        args.size());

    // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
    // optimization. We need to create GPU copies for CUDA kernel execution
    // using SlimTensor.
    std::vector<SlimTensor*> gpu_inputs(n_inputs);
    std::vector<SlimTensor*> gpu_outputs(n_outputs);

    // Process input tensors: convert ETensor (CPU) to SlimTensor (GPU)
    for (size_t i = 0; i < n_inputs; i++) {
      auto* cpu_tensor = &(args[i]->toTensor());
      // Check if input data is already on GPU (skip-copy optimization for
      // inputs) This can happen when the caller has pre-staged data on GPU
      cudaPointerAttributes attributes{};
      const void* data_ptr = cpu_tensor->const_data_ptr();
      if (data_ptr != nullptr) {
        cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
        if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
          // Data is already on GPU - wrap it directly without copy
          auto sizes = cpu_tensor->sizes();
          auto strides = cpu_tensor->strides();
          std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
          std::vector<int64_t> strides_vec(strides.begin(), strides.end());
          gpu_inputs[i] = new SlimTensor(slim::from_blob(
              const_cast<void*>(data_ptr),
              slim::makeArrayRef(sizes_vec),
              slim::makeArrayRef(strides_vec),
              static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
              DEFAULT_CUDA_DEVICE,
              0 // storage_offset
              ));
          continue;
        }
      }
      // Data is on CPU - use from_etensor to copy to GPU
      gpu_inputs[i] = new SlimTensor(
          from_etensor(*cpu_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE));
    }

    // Process output tensors: create GPU SlimTensors for kernel output.
    // Save pre-run handles to detect orphans after run().
    std::vector<SlimTensor*> pre_run_outputs(n_outputs, nullptr);
    for (size_t i = 0; i < n_outputs; i++) {
      auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
      auto sizes = cpu_output_tensor->sizes();
      auto strides = cpu_output_tensor->strides();
      auto scalar_type = cpu_output_tensor->scalar_type();
      std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
      std::vector<int64_t> strides_vec(strides.begin(), strides.end());
      gpu_outputs[i] = new SlimTensor(slim::empty_strided(
          slim::makeArrayRef(sizes_vec),
          slim::makeArrayRef(strides_vec),
          static_cast<slim::c10::ScalarType>(scalar_type),
          DEFAULT_CUDA_DEVICE));
      pre_run_outputs[i] = gpu_outputs[i];
    }

    bool run_called = false;
    // Scope guard: deletes any non-null gpu_outputs on exit. Normal paths
    // null entries as they take ownership, so the guard only fires on
    // early-return error paths. Also cleans up inputs if run() was never
    // called (run() steals them via internal RAII).
    executorch::backends::aoti::ScopeGuard cleanup([&]() noexcept {
      if (!run_called) {
        delete_slimtensor_vector(gpu_inputs);
      }
      for (size_t i = 0; i < gpu_outputs.size(); i++) {
        if (gpu_outputs[i]) {
          delete gpu_outputs[i];
        }
      }
    });

    // Run the AOTI container.
    // NOTE: run() steals input handles (RAII wraps them at the start of
    // run_impl) and may replace output handles with its own.
    // The Result must be checked before get() is called on it.
    Result<cudaStream_t> cuda_stream_ret = getCurrentCUDAStream(0);
    ET_CHECK_OK_OR_RETURN_ERROR(cuda_stream_ret.error());
    cudaStream_t cuda_stream = cuda_stream_ret.get();

    AOTIRuntimeError error = handle->run(
        handle->container_handle,
        reinterpret_cast<Tensor**>(gpu_inputs.data()),
        n_inputs,
        reinterpret_cast<Tensor**>(gpu_outputs.data()),
        n_outputs,
        static_cast<void*>(cuda_stream),
        nullptr);
    run_called = true;

    // Delete orphaned pre-created outputs that run() replaced.
    // Must happen before the error check — if run() fails after
    // replacing some outputs, the originals would otherwise leak.
    for (size_t i = 0; i < n_outputs; i++) {
      if (pre_run_outputs[i] != gpu_outputs[i]) {
        delete pre_run_outputs[i];
      }
    }

    ET_CHECK_OR_RETURN_ERROR(
        error == Error::Ok,
        Internal,
        "AOTInductorModelContainerRun failed with error code %d",
        static_cast<int>(error));

    const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);
    if (copy_outputs) {
      for (size_t i = 0; i < n_outputs; i++) {
        auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
        ET_CHECK_OK_OR_RETURN_ERROR(
            copy_slimtensor_to_etensor_async(
                gpu_outputs[i], cpu_output_tensor, cuda_stream),
            "Failed to copy GPU output %zu back to CPU ETensor",
            i);
      }
      // The copies were enqueued on cuda_stream. Wait for them before the GPU
      // sources are released and before the caller reads the CPU outputs.
      // (Cheap if the copy helper already synchronized.)
      const cudaError_t sync_err = cudaStreamSynchronize(cuda_stream);
      ET_CHECK_OR_RETURN_ERROR(
          sync_err == cudaSuccess,
          Internal,
          "Failed to synchronize CUDA stream after output copy: %s",
          cudaGetErrorString(sync_err));
      for (size_t i = 0; i < n_outputs; i++) {
        delete gpu_outputs[i];
        gpu_outputs[i] = nullptr;
      }
    } else {
      // Skip-copy optimization: point ETensor directly to GPU data.
      // Lifetime management: cache GPU tensors and delete previous round's.
      std::lock_guard<std::mutex> guard(cached_outputs_mutex_);
      auto& cached_outputs = cached_outputs_[handle];
      delete_slimtensor_vector(cached_outputs);
      // Drop the deleted pointers so they can never be freed twice, whether
      // or not the helper already cleared the vector.
      cached_outputs.clear();
      cached_outputs.reserve(n_outputs);
      for (size_t i = 0; i < n_outputs; i++) {
        cached_outputs.push_back(gpu_outputs[i]);
        gpu_outputs[i] = nullptr;
        auto* output_etensor = &(args[i + n_inputs]->toTensor());
        ET_CHECK_OK_OR_RETURN_ERROR(
            wrap_slimtensor_to_etensor(cached_outputs.back(), output_etensor),
            "Failed to wrap GPU output %zu into ETensor",
            i);
      }
    }
    return Error::Ok;
  }

  // Release everything owned by a handle created in init(): cached skip-copy
  // outputs, the stream reference, the loaded library and the temp .so.
  // Safe to call with nullptr.
  void destroy(DelegateHandle* handle_) const override {
    if (handle_ == nullptr) {
      return;
    }
    auto* handle = static_cast<cuda::CudaDelegateHandle*>(handle_);

    // Clean up cached output tensors for this handle
    {
      std::lock_guard<std::mutex> guard(cached_outputs_mutex_);
      auto it = cached_outputs_.find(handle);
      if (it != cached_outputs_.end()) {
        delete_slimtensor_vector(it->second);
        cached_outputs_.erase(it);
      }
    }

    // The CUDA stream is managed by shared_ptr in the handle.
    // It will be automatically destroyed when the last handle using it
    // is destroyed. Just reset our reference.
    handle->cuda_stream.reset();

    // NOTE: AOTInductorModelContainerDelete does not work correctly with
    // multiple .so files. Deleting one container frees shared resources,
    // which causes segmentation faults when attempting to delete other
    // containers. As a workaround, we skip explicit container deletion
    // and defer cleanup to the OS.
    // TODO(gasoonjia): Find a proper solution for safe container deletion.
    // AOTInductorModelContainerDelete(handle->container_handle);

    // Now close the shared library
    if (handle->so_handle != nullptr) {
      Error err = close_library(handle->so_handle);
      ET_CHECK_OR_LOG_ERROR(
          err == Error::Ok,
          "Failed to close shared library for %s",
          handle->so_path.c_str());
    }

    // Remove the temporary shared library file
    if (!handle->so_path.empty()) {
      std::error_code remove_error;
      std::filesystem::remove(handle->so_path, remove_error);
      ET_CHECK_OR_LOG_ERROR(
          !remove_error,
          "Failed to remove temporary shared library %s: %s",
          handle->so_path.c_str(),
          remove_error.message().c_str());
    }

    delete handle;
  }

 private:
  mutable std::mutex skip_copy_method_mutex_;
  std::string skip_copy_method_;

  // Shared CUDA stream for all methods. When set (non-null), all methods use
  // the same stream to ensure proper ordering (critical for skip-copy
  // optimization). Created when use_shared_cuda_stream option is set to true.
  // Managed via shared_ptr so it's automatically cleaned up when last handle
  // is destroyed.
  mutable std::mutex cuda_stream_mutex_;
  std::shared_ptr<cudaStream_t> shared_cuda_stream_ = nullptr;

  // Cached output tensors for skip-copy optimization.
  // When skip-copy is enabled, output SlimTensors are cached here to keep
  // the underlying GPU memory alive while the caller processes the results.
  // Maps each CudaDelegateHandle* to its vector of cached output tensors.
  mutable std::mutex cached_outputs_mutex_;
  mutable std::
      unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
          cached_outputs_;

  // Cross-method constant sharing state.
  // When multiple AOTI containers share mutable buffers (e.g., KV cache),
  // the first container's constants are extracted and stored here. Subsequent
  // containers with matching FQNs share the same GPU tensors via
  // UpdateUserManagedConstantBufferPairs.
  mutable std::mutex shared_constants_mutex_;
  // FQN → AtenTensorHandle from the source (first) container.
  // The tensor handles are owned by the source container (which is never
  // explicitly deleted — see destroy() comment).
  mutable std::unordered_map<std::string, AtenTensorHandle>
      shared_constant_tensors_;
  // Whether we've already extracted constants from a source container.
  mutable bool constants_extracted_ = false;
};
} // namespace executorch::backends::cuda
namespace executorch::backends {
namespace {
auto cls = cuda::CudaBackend();
executorch::runtime::Backend backend{"CudaBackend", &cls};
static executorch::runtime::Error success_with_compiler =
register_backend(backend);
} // namespace
} // namespace executorch::backends