Skip to content

Commit 5535d78

Browse files
authored
Merge branch 'gasoonjia/flashdecoding-pp-async-softmax' into fused-deltanet-decode
2 parents 2ca1b22 + 1a79d9d commit 5535d78

11 files changed

Lines changed: 469 additions & 97 deletions

File tree

backends/cuda/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
107107
runtime/shims/cuda_guard.cpp
108108
)
109109

110-
# Only build int4mm shim when CUDA language/toolchain is available.
110+
# Only build CUDA shims when CUDA language/toolchain is available.
111111
if(CMAKE_CUDA_COMPILER)
112112
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
113-
runtime/shims/sort.cu
113+
runtime/shims/sort.cu runtime/shims/rand.cu
114114
)
115115
endif()
116116

@@ -162,7 +162,7 @@ else()
162162
aoti_cuda_shims
163163
PRIVATE cuda_platform
164164
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
165-
CUDA::cudart ${CMAKE_DL_LIBS}
165+
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
166166
)
167167
endif()
168168

backends/cuda/cuda_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
146146
return {
147147
"at::_ops::_weight_int4pack_mm::call": None,
148148
"at::_ops::sort_stable::call": None,
149+
"aoti_torch_cuda_randint_low_out": None,
149150
}
150151

151152
@classmethod

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ class ET_EXPERIMENTAL CudaBackend final
693693

694694
gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
695695
static_ptr, cpu_tensor);
696+
696697
continue;
697698
}
698699

@@ -805,6 +806,7 @@ class ET_EXPERIMENTAL CudaBackend final
805806
// End capture → instantiate graph
806807
cudaError_t gerr =
807808
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph);
809+
808810
ET_CHECK_OR_RETURN_ERROR(
809811
gerr == cudaSuccess,
810812
Internal,
@@ -814,6 +816,7 @@ class ET_EXPERIMENTAL CudaBackend final
814816
gerr = cudaGraphInstantiate(
815817
&handle->cuda_graph_state.graph_exec,
816818
handle->cuda_graph_state.graph,
819+
817820
cudaGraphInstantiateFlagAutoFreeOnLaunch);
818821
ET_CHECK_OR_RETURN_ERROR(
819822
gerr == cudaSuccess,

backends/cuda/runtime/cuda_delegate_handle.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,44 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
149149

150150
// CUDA graph state (warmup, capture, replay, static buffers)
151151
CudaGraphState cuda_graph_state;
152+
// --- CUDA graph state ---
153+
// Phase: 0=disabled, 1=warmup, 2=captured (replay mode)
154+
int cuda_graph_phase = 0;
155+
int cuda_graph_warmup_remaining = 0;
156+
157+
// Captured graph and executable instance
158+
cudaGraph_t cuda_graph = nullptr;
159+
cudaGraphExec_t cuda_graph_exec = nullptr;
160+
161+
// Static input/output GPU buffers pinned during capture.
162+
// These hold the tensor metadata; the underlying data pointers are fixed
163+
// addresses that CUDA graph replay will write to / read from.
164+
// SlimTensor pointers — owned by this handle.
165+
std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
166+
std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
167+
std::vector<std::vector<int64_t>> static_input_sizes;
168+
std::vector<std::vector<int64_t>> static_input_strides;
169+
std::vector<std::vector<int64_t>> static_output_sizes;
170+
std::vector<std::vector<int64_t>> static_output_strides;
171+
std::vector<int> static_input_scalar_types;
172+
std::vector<int> static_output_scalar_types;
173+
std::vector<size_t> static_input_nbytes;
174+
std::vector<size_t> static_output_nbytes;
175+
176+
~CudaDelegateHandle() {
177+
if (cuda_graph_exec) {
178+
cudaGraphExecDestroy(cuda_graph_exec);
179+
}
180+
if (cuda_graph) {
181+
cudaGraphDestroy(cuda_graph);
182+
}
183+
// Only free input buffers — output buffers are owned by the AOTI runtime
184+
// (allocated during graph capture via the caching allocator).
185+
for (auto* ptr : static_input_ptrs) {
186+
if (ptr)
187+
cudaFree(ptr);
188+
}
189+
}
152190
};
153191

154192
} // namespace cuda
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cuda/runtime/shims/rand.h>
10+
11+
#include <executorch/backends/aoti/slim/cuda/guard.h>
12+
#include <executorch/backends/aoti/slim/factory/empty.h>
13+
#include <executorch/backends/aoti/slim/util/size_util.h>
14+
#include <executorch/runtime/platform/assert.h>
15+
#include <executorch/runtime/platform/log.h>
16+
17+
#include <cuda_runtime.h>
18+
#include <curand_kernel.h>
19+
20+
#include <cstdint>
21+
#include <ctime>
22+
#include <vector>
23+
24+
namespace executorch::backends::cuda {
25+
26+
namespace c10 = executorch::backends::aoti::slim::c10;
27+
using c10::Device;
28+
using c10::DeviceIndex;
29+
using c10::DeviceType;
30+
using c10::ScalarType;
31+
using executorch::backends::aoti::slim::empty_strided;
32+
using executorch::backends::aoti::slim::IntArrayRef;
33+
using executorch::backends::aoti::slim::makeArrayRef;
34+
35+
namespace {
36+
37+
// ---- GPU-resident RNG state ----
38+
// Seed and counter live in device memory allocated during the first call
39+
// (warmup phase, before CUDA graph capture). The counter is atomically
40+
// advanced by each kernel invocation on-device, so it automatically
41+
// produces different random sequences on every CUDA graph replay.
42+
43+
struct RngState {
44+
unsigned long long seed;
45+
unsigned long long counter;
46+
};
47+
48+
static RngState* d_rng = nullptr;
49+
static bool g_rng_init_done = false;
50+
51+
// Initialize RNG state on the given stream.
52+
// Must be called during warmup (before graph capture).
53+
void ensure_rng_init(cudaStream_t stream) {
54+
if (!g_rng_init_done) {
55+
cudaMallocAsync(&d_rng, sizeof(RngState), stream);
56+
RngState h;
57+
h.seed = static_cast<unsigned long long>(time(nullptr));
58+
h.counter = 0;
59+
cudaMemcpyAsync(
60+
d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
61+
// Synchronize to ensure the copy completes before we return
62+
// (the host-side RngState `h` is on the stack).
63+
cudaStreamSynchronize(stream);
64+
g_rng_init_done = true;
65+
}
66+
}
67+
68+
// Philox-based randint kernel that reads seed from device-resident state
69+
// and atomically advances the counter. The counter pointer survives CUDA
70+
// graph replay, so each replay produces different values.
71+
__global__ void philox_randint_graph_kernel(
72+
int64_t* __restrict__ out,
73+
int64_t numel,
74+
int64_t low,
75+
int64_t range,
76+
RngState* __restrict__ rng) {
77+
// Each thread reads the seed and computes its unique offset.
78+
// The "base offset" is read from rng->counter. We can't atomicAdd per
79+
// thread, so we use a two-pass approach: first a single-thread kernel
80+
// advances the counter, then the main kernel uses the old value.
81+
// But that requires two kernel launches...
82+
//
83+
// Simpler: since numel=1 for randint seed generation, just one thread.
84+
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
85+
if (idx < numel) {
86+
// Each invocation atomically grabs `numel` slots from the counter.
87+
// For numel=1, this is just one atomicAdd.
88+
unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
89+
curandStatePhilox4_32_10_t state;
90+
curand_init(rng->seed, idx, my_offset, &state);
91+
double val = curand_uniform_double(&state);
92+
int64_t ival = static_cast<int64_t>(val * range);
93+
out[idx] = low + (ival >= range ? range - 1 : ival);
94+
}
95+
}
96+
97+
// Philox-based uniform float32 generator (graph-safe version).
98+
__global__ void philox_rand_float_graph_kernel(
99+
float* __restrict__ out,
100+
int64_t numel,
101+
RngState* __restrict__ rng) {
102+
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
103+
if (idx < numel) {
104+
unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
105+
curandStatePhilox4_32_10_t state;
106+
curand_init(rng->seed, idx, my_offset, &state);
107+
out[idx] = curand_uniform(&state);
108+
}
109+
}
110+
111+
// Philox-based uniform bfloat16 generator (graph-safe version).
112+
__global__ void philox_rand_bf16_graph_kernel(
113+
uint16_t* __restrict__ out,
114+
int64_t numel,
115+
RngState* __restrict__ rng) {
116+
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
117+
if (idx < numel) {
118+
unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
119+
curandStatePhilox4_32_10_t state;
120+
curand_init(rng->seed, idx, my_offset, &state);
121+
float val = curand_uniform(&state);
122+
uint32_t bits;
123+
memcpy(&bits, &val, sizeof(uint32_t));
124+
uint32_t lsb = (bits >> 16) & 1;
125+
bits += 0x7FFFu + lsb;
126+
out[idx] = static_cast<uint16_t>(bits >> 16);
127+
}
128+
}
129+
130+
} // anonymous namespace
131+
132+
extern "C" {
133+
134+
AOTITorchError aoti_torch_cuda_rand(
135+
const int64_t* size,
136+
int64_t size_len_,
137+
int32_t* dtype,
138+
int32_t* layout,
139+
int32_t* device,
140+
int32_t device_index_,
141+
int32_t* pin_memory,
142+
SlimTensor** ret0) {
143+
(void)layout;
144+
(void)device;
145+
(void)pin_memory;
146+
147+
ET_CHECK_OR_RETURN_ERROR(
148+
ret0 != nullptr,
149+
InvalidArgument,
150+
"aoti_torch_cuda_rand: ret0 is null");
151+
152+
// Default to float32 if dtype not specified.
153+
ScalarType scalar_type = ScalarType::Float;
154+
if (dtype != nullptr) {
155+
scalar_type = static_cast<ScalarType>(*dtype);
156+
}
157+
158+
// Compute contiguous strides and total elements.
159+
std::vector<int64_t> strides(size_len_);
160+
int64_t numel = 1;
161+
for (int64_t i = size_len_ - 1; i >= 0; i--) {
162+
strides[i] = numel;
163+
numel *= size[i];
164+
}
165+
166+
// Allocate output tensor.
167+
IntArrayRef sizes_ref(size, static_cast<size_t>(size_len_));
168+
*ret0 = new SlimTensor(empty_strided(
169+
sizes_ref,
170+
makeArrayRef(strides),
171+
scalar_type,
172+
Device(DeviceType::CUDA, static_cast<DeviceIndex>(device_index_))));
173+
174+
if (numel == 0) {
175+
return Error::Ok;
176+
}
177+
178+
// Get the current CUDA stream.
179+
auto stream_result = getCurrentCUDAStream(0);
180+
ET_CHECK_OR_RETURN_ERROR(
181+
stream_result.ok(),
182+
Internal,
183+
"aoti_torch_cuda_rand: failed to get CUDA stream");
184+
cudaStream_t stream = stream_result.get();
185+
186+
ensure_rng_init(stream);
187+
188+
constexpr int kThreads = 256;
189+
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
190+
191+
if (scalar_type == ScalarType::Float) {
192+
philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
193+
static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
194+
} else if (scalar_type == ScalarType::BFloat16) {
195+
philox_rand_bf16_graph_kernel<<<blocks, kThreads, 0, stream>>>(
196+
static_cast<uint16_t*>((*ret0)->data_ptr()), numel, d_rng);
197+
} else {
198+
ET_LOG(
199+
Error,
200+
"aoti_torch_cuda_rand: unsupported dtype %d",
201+
static_cast<int>(scalar_type));
202+
return Error::NotSupported;
203+
}
204+
205+
return Error::Ok;
206+
}
207+
208+
AOTITorchError aoti_torch_cuda_randint_low_out(
209+
SlimTensor* out,
210+
int64_t low,
211+
int64_t high,
212+
const int64_t* size,
213+
int64_t size_len_) {
214+
ET_CHECK_OR_RETURN_ERROR(
215+
out != nullptr,
216+
InvalidArgument,
217+
"aoti_torch_cuda_randint_low_out: out tensor is null");
218+
219+
ET_CHECK_OR_RETURN_ERROR(
220+
high > low,
221+
InvalidArgument,
222+
"aoti_torch_cuda_randint_low_out: requires high > low");
223+
224+
int64_t numel = 1;
225+
for (int64_t i = 0; i < size_len_; i++) {
226+
numel *= size[i];
227+
}
228+
if (numel == 0) {
229+
return Error::Ok;
230+
}
231+
232+
// Get the current CUDA stream.
233+
auto stream_result = getCurrentCUDAStream(0);
234+
ET_CHECK_OR_RETURN_ERROR(
235+
stream_result.ok(),
236+
Internal,
237+
"aoti_torch_cuda_randint_low_out: failed to get CUDA stream");
238+
cudaStream_t stream = stream_result.get();
239+
240+
ensure_rng_init(stream);
241+
242+
int64_t range = high - low;
243+
int64_t* out_data = static_cast<int64_t*>(out->data_ptr());
244+
245+
constexpr int kThreads = 256;
246+
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
247+
philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
248+
out_data, numel, low, range, d_rng);
249+
250+
return Error::Ok;
251+
}
252+
253+
} // extern "C"
254+
255+
} // namespace executorch::backends::cuda

0 commit comments

Comments
 (0)