
Commit c1731fd

gasoonjia authored
Add GPU-side Gumbel-max sampling for CUDA graph compatibility (#18844)
This PR replaces the CPU sampler with a CUDA sampler and fuses the sampler into the forward method, both to eliminate unnecessary data transfer and to improve sampling efficiency. Decode performance increases from 113.8 tokens/s to 119.5 tokens/s. Once we land the device support pipeline, we should split sampling back out of the forward method. --------- Co-authored-by: gasoonjia <gasoonjia@fb.com>
1 parent b4d4507 commit c1731fd

15 files changed

Lines changed: 1026 additions & 42 deletions
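
For context on the commit title: the Gumbel-max trick turns categorical sampling into noise generation plus an argmax, which is why it pairs well with CUDA graphs (no host-side RNG inside the captured region). A minimal sketch of the idea follows; it is a hypothetical illustration, not code from this diff:

```cuda
// Hypothetical illustration of the Gumbel-max trick (not from this commit):
// if u_i ~ Uniform(0, 1], then argmax_i(logits[i] - log(-log(u_i))) is
// distributed as a sample from softmax(logits). Sampling thus reduces to a
// pointwise kernel plus an argmax reduction, both CUDA-graph capturable.
__global__ void add_gumbel_noise_kernel(
    const float* __restrict__ logits,
    const float* __restrict__ uniform, // e.g. output of a uniform RNG shim
    float* __restrict__ perturbed,
    int64_t vocab_size) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < vocab_size) {
    // g = -log(-log(u)) is a standard Gumbel(0, 1) sample; note u == 1.0
    // maps to +inf, so production code would clamp u away from 1.
    float g = -logf(-logf(uniform[i]));
    perturbed[i] = logits[i] + g; // argmax of this array is the sampled token
  }
}
```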


.github/workflows/cuda.yml

Lines changed: 2 additions & 2 deletions
@@ -145,8 +145,8 @@ jobs:
           # Run CUDA backend Python tests
           python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="

-          # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
-          python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
+          # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
+          python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

   export-model-cuda-artifact:
     name: export-model-cuda-artifact
backends/cuda/CMakeLists.txt

Lines changed: 5 additions & 4 deletions
@@ -107,10 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
     runtime/shims/cuda_guard.cpp
 )

-# Only build int4mm shim when CUDA language/toolchain is available.
+# Only build CUDA shims when CUDA language/toolchain is available.
 if(CMAKE_CUDA_COMPILER)
   list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
-       runtime/shims/sort.cu
+       runtime/shims/sort.cu runtime/shims/rand.cu
   )
 endif()

@@ -152,7 +152,8 @@ endif()
 # retention.
 if(_cuda_is_msvc_toolchain)
   target_link_libraries(
-    aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
+    aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
+    ${CMAKE_DL_LIBS}
   )
   # Link object library directly so symbols are pulled exactly once while
   # avoiding duplicate static/object inclusion and interface leakage.
@@ -162,7 +163,7 @@ else()
     aoti_cuda_shims
     PRIVATE cuda_platform
     PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
-    CUDA::cudart ${CMAKE_DL_LIBS}
+    CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
   )
 endif()
backends/cuda/cuda_backend.py

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
         return {
             "at::_ops::_weight_int4pack_mm::call": None,
             "at::_ops::sort_stable::call": None,
+            "aoti_torch_cuda_randint_low_out": None,
         }

     @classmethod

backends/cuda/runtime/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@ runtime.cxx_library(
         "shims/cuda_guard.cpp",
         "shims/int4mm.cu",
         "shims/memory.cpp",
+        "shims/rand.cu",
         "shims/sort.cu",
         "shims/tensor_attribute.cpp",
     ],
@@ -41,6 +42,7 @@ runtime.cxx_library(
         "shims/int4mm.cuh",
         "shims/int4mm.h",
         "shims/memory.h",
+        "shims/rand.h",
         "shims/sort.h",
         "shims/tensor_attribute.h",
         "utils.h",
backends/cuda/runtime/shims/rand.cu

Lines changed: 273 additions & 0 deletions

@@ -0,0 +1,273 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cuda/runtime/shims/rand.h>

#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/backends/aoti/slim/factory/empty.h>
#include <executorch/backends/aoti/slim/util/size_util.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cstdint>
#include <ctime>
#include <mutex>
#include <vector>

namespace executorch::backends::cuda {

namespace c10 = executorch::backends::aoti::slim::c10;
using c10::Device;
using c10::DeviceIndex;
using c10::DeviceType;
using c10::ScalarType;
using executorch::backends::aoti::slim::empty_strided;
using executorch::backends::aoti::slim::IntArrayRef;
using executorch::backends::aoti::slim::makeArrayRef;

namespace {

// ---- GPU-resident RNG state ----
// Seed and counter live in device memory allocated during the first call
// (warmup phase, before CUDA graph capture). The counter is atomically
// advanced by each kernel invocation on-device, so it automatically
// produces different random sequences on every CUDA graph replay.

struct RngState {
  unsigned long long seed;
  unsigned long long counter;
  // Per-launch scratch — written by advance_counter_kernel and read by
  // the main RNG kernels. Single-threaded host driver is assumed
  // (typical inference / CUDA-graph replay use case).
  unsigned long long base_scratch;
};

static RngState* d_rng = nullptr;
// std::call_once guarantees one-shot initialization even when shims are
// invoked from multiple host threads (e.g. concurrent models / streams).
static std::once_flag g_rng_init_flag;

// Initialize RNG state on the given stream.
// Must be called during warmup (before graph capture). Subsequent calls
// from any thread are no-ops thanks to std::call_once.
void ensure_rng_init(cudaStream_t stream) {
  std::call_once(g_rng_init_flag, [&]() {
    cudaMallocAsync(&d_rng, sizeof(RngState), stream);
    RngState h;
    h.seed = static_cast<unsigned long long>(time(nullptr));
    h.counter = 0;
    h.base_scratch = 0;
    cudaMemcpyAsync(
        d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
    // Synchronize to ensure the copy completes before we return
    // (the host-side RngState `h` is on the stack).
    cudaStreamSynchronize(stream);
  });
}

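A note on the Philox parameters used by the kernels below (a summary of cuRAND semantics, stated for the reader rather than taken from the diff): `curand_init(seed, subsequence, offset, &state)` positions a counter-based generator deterministically, so each output is a pure function of `(seed, idx, base_scratch)`. No per-thread state persists between launches; freshness comes entirely from `advance_counter_kernel` bumping the offset.
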
// Philox-based randint kernel. Reads its base offset from `rng->base_scratch`
// (populated by `advance_counter_kernel` immediately before this launch).
// This replaces the previous per-element atomicAdd contention with a single
// atomic per kernel launch.
__global__ void philox_randint_graph_kernel(
    int64_t* __restrict__ out,
    int64_t numel,
    int64_t low,
    int64_t range,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    curandStatePhilox4_32_10_t state;
    curand_init(rng->seed, idx, rng->base_scratch, &state);
    double val = curand_uniform_double(&state);
    int64_t ival = static_cast<int64_t>(val * range);
    out[idx] = low + (ival >= range ? range - 1 : ival);
  }
}

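Note that `curand_uniform_double` returns values in (0, 1], so `val * range` can equal `range` exactly; the final clamp keeps results inside `[low, high)`.
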
// Philox-based uniform float32 generator (graph-safe version).
__global__ void philox_rand_float_graph_kernel(
    float* __restrict__ out,
    int64_t numel,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    curandStatePhilox4_32_10_t state;
    curand_init(rng->seed, idx, rng->base_scratch, &state);
    out[idx] = curand_uniform(&state);
  }
}

// Philox-based uniform bfloat16 generator (graph-safe version).
__global__ void philox_rand_bf16_graph_kernel(
    uint16_t* __restrict__ out,
    int64_t numel,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    curandStatePhilox4_32_10_t state;
    curand_init(rng->seed, idx, rng->base_scratch, &state);
    float val = curand_uniform(&state);
    uint32_t bits;
    memcpy(&bits, &val, sizeof(uint32_t));
    uint32_t lsb = (bits >> 16) & 1;
    bits += 0x7FFFu + lsb;
    out[idx] = static_cast<uint16_t>(bits >> 16);
  }
}

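The bit manipulation in the bf16 kernel is a float-to-bfloat16 round-to-nearest-even conversion. A host-side sketch of the same arithmetic, for illustration only (not part of the diff; `float_to_bf16_rne` is a hypothetical name):

```cpp
#include <cstdint>
#include <cstring>

// Round-to-nearest-even float -> bfloat16, mirroring the kernel above:
// add 0x7FFF plus the lowest surviving mantissa bit, then truncate to the
// top 16 bits of the float's bit pattern.
static uint16_t float_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t lsb = (bits >> 16) & 1u; // lowest bit kept after truncation
  bits += 0x7FFFu + lsb;            // ties round toward the even mantissa
  return static_cast<uint16_t>(bits >> 16);
}
// Example: 1.0f (0x3F800000) yields 0x3F80, which is exactly 1.0 in bf16.
```
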
// Single-thread helper that grabs a contiguous range of `numel` offsets
// from the on-device counter and writes the base into `rng->base_scratch`.
// Replaces `numel` per-element atomics with a single atomic per launch
// while staying graph-capturable.
__global__ void advance_counter_kernel(
    RngState* __restrict__ rng,
    unsigned long long numel) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    rng->base_scratch = atomicAdd(&rng->counter, numel);
  }
}

} // anonymous namespace

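Taken together, `advance_counter_kernel` plus the Philox kernels make the sampling path graph-capturable: all RNG state lives in `d_rng` on the device, so replaying an identical graph still advances the counter. A hedged host-side sketch of the capture/replay flow (hypothetical driver code, not part of this diff; `run_decode_step_with_sampler` and `num_new_tokens` are placeholders for the fused forward+sampler launch sequence):

```cpp
// Sketch: capture one decode step (forward + GPU sampler) into a CUDA
// graph, then replay it. Because the captured kernels read seed/counter
// from device memory, every cudaGraphLaunch sees a larger counter value
// and therefore draws fresh random numbers -- no re-capture needed.
cudaStream_t stream;
cudaStreamCreate(&stream);

run_decode_step_with_sampler(stream); // warmup; also triggers ensure_rng_init

cudaGraph_t graph;
cudaGraphExec_t exec;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
run_decode_step_with_sampler(stream); // enqueues advance_counter + Philox
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&exec, graph, 0); // CUDA 12 signature

for (int step = 0; step < num_new_tokens; ++step) {
  cudaGraphLaunch(exec, stream); // each replay samples a different token
}
cudaStreamSynchronize(stream);
```
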
extern "C" {
142+
143+
AOTITorchError aoti_torch_cuda_rand(
144+
const int64_t* size,
145+
int64_t size_len_,
146+
int32_t* dtype,
147+
int32_t* layout,
148+
int32_t* device,
149+
int32_t device_index_,
150+
int32_t* pin_memory,
151+
SlimTensor** ret0) {
152+
(void)layout;
153+
(void)device;
154+
(void)pin_memory;
155+
156+
ET_CHECK_OR_RETURN_ERROR(
157+
ret0 != nullptr,
158+
InvalidArgument,
159+
"aoti_torch_cuda_rand: ret0 is null");
160+
161+
// Default to float32 if dtype not specified.
162+
ScalarType scalar_type = ScalarType::Float;
163+
if (dtype != nullptr) {
164+
scalar_type = static_cast<ScalarType>(*dtype);
165+
}
166+
167+
// Compute contiguous strides and total elements.
168+
std::vector<int64_t> strides(size_len_);
169+
int64_t numel = 1;
170+
for (int64_t i = size_len_ - 1; i >= 0; i--) {
171+
strides[i] = numel;
172+
numel *= size[i];
173+
}
174+
175+
// Allocate output tensor.
176+
IntArrayRef sizes_ref(size, static_cast<size_t>(size_len_));
177+
*ret0 = new SlimTensor(empty_strided(
178+
sizes_ref,
179+
makeArrayRef(strides),
180+
scalar_type,
181+
Device(DeviceType::CUDA, static_cast<DeviceIndex>(device_index_))));
182+
183+
if (numel == 0) {
184+
return Error::Ok;
185+
}
186+
187+
// Get the current CUDA stream.
188+
auto stream_result = getCurrentCUDAStream(0);
189+
ET_CHECK_OR_RETURN_ERROR(
190+
stream_result.ok(),
191+
Internal,
192+
"aoti_torch_cuda_rand: failed to get CUDA stream");
193+
cudaStream_t stream = stream_result.get();
194+
195+
ensure_rng_init(stream);
196+
197+
constexpr int kThreads = 256;
198+
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
199+
200+
// Single atomicAdd per launch — grabs `numel` consecutive counter slots
201+
// for the kernel below, eliminating per-element contention on the GPU
202+
// counter.
203+
advance_counter_kernel<<<1, 1, 0, stream>>>(
204+
d_rng, static_cast<unsigned long long>(numel));
205+
206+
if (scalar_type == ScalarType::Float) {
207+
philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
208+
static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
209+
} else if (scalar_type == ScalarType::BFloat16) {
210+
philox_rand_bf16_graph_kernel<<<blocks, kThreads, 0, stream>>>(
211+
static_cast<uint16_t*>((*ret0)->data_ptr()), numel, d_rng);
212+
} else {
213+
ET_LOG(
214+
Error,
215+
"aoti_torch_cuda_rand: unsupported dtype %d",
216+
static_cast<int>(scalar_type));
217+
return Error::NotSupported;
218+
}
219+
220+
return Error::Ok;
221+
}
222+
223+
AOTITorchError aoti_torch_cuda_randint_low_out(
224+
SlimTensor* out,
225+
int64_t low,
226+
int64_t high,
227+
const int64_t* size,
228+
int64_t size_len_) {
229+
ET_CHECK_OR_RETURN_ERROR(
230+
out != nullptr,
231+
InvalidArgument,
232+
"aoti_torch_cuda_randint_low_out: out tensor is null");
233+
234+
ET_CHECK_OR_RETURN_ERROR(
235+
high > low,
236+
InvalidArgument,
237+
"aoti_torch_cuda_randint_low_out: requires high > low");
238+
239+
int64_t numel = 1;
240+
for (int64_t i = 0; i < size_len_; i++) {
241+
numel *= size[i];
242+
}
243+
if (numel == 0) {
244+
return Error::Ok;
245+
}
246+
247+
// Get the current CUDA stream.
248+
auto stream_result = getCurrentCUDAStream(0);
249+
ET_CHECK_OR_RETURN_ERROR(
250+
stream_result.ok(),
251+
Internal,
252+
"aoti_torch_cuda_randint_low_out: failed to get CUDA stream");
253+
cudaStream_t stream = stream_result.get();
254+
255+
ensure_rng_init(stream);
256+
257+
int64_t range = high - low;
258+
int64_t* out_data = static_cast<int64_t*>(out->data_ptr());
259+
260+
constexpr int kThreads = 256;
261+
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
262+
// One atomicAdd per launch; subsequent kernel reads `rng->base_scratch`.
263+
advance_counter_kernel<<<1, 1, 0, stream>>>(
264+
d_rng, static_cast<unsigned long long>(numel));
265+
philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
266+
out_data, numel, low, range, d_rng);
267+
268+
return Error::Ok;
269+
}
270+
271+
} // extern "C"
272+
273+
} // namespace executorch::backends::cuda
