Skip to content

Commit 2828ba9

Browse files
authored
Merge branch 'cuda-graph-sampling' into gasoonjia/flashdecoding-pp-async-softmax
2 parents 1a79d9d + 61d47aa commit 2828ba9

12 files changed

Lines changed: 667 additions & 38 deletions

File tree

.github/workflows/cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ jobs:
145145
# Run CUDA backend Python tests
146146
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
147147
148-
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
149-
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
148+
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
149+
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="
150150
151151
export-model-cuda-artifact:
152152
name: export-model-cuda-artifact

backends/cuda/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ endif()
152152
# retention.
153153
if(_cuda_is_msvc_toolchain)
154154
target_link_libraries(
155-
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
155+
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
156+
${CMAKE_DL_LIBS}
156157
)
157158
# Link object library directly so symbols are pulled exactly once while
158159
# avoiding duplicate static/object inclusion and interface leakage.

backends/cuda/runtime/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ runtime.cxx_library(
3333
"shims/cuda_guard.cpp",
3434
"shims/int4mm.cu",
3535
"shims/memory.cpp",
36+
"shims/rand.cu",
3637
"shims/sort.cu",
3738
"shims/tensor_attribute.cpp",
3839
],
@@ -41,6 +42,7 @@ runtime.cxx_library(
4142
"shims/int4mm.cuh",
4243
"shims/int4mm.h",
4344
"shims/memory.h",
45+
"shims/rand.h",
4446
"shims/sort.h",
4547
"shims/tensor_attribute.h",
4648
"utils.h",

backends/cuda/runtime/shims/rand.cu

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cstdint>
2121
#include <ctime>
22+
#include <mutex>
2223
#include <vector>
2324

2425
namespace executorch::backends::cuda {
@@ -43,51 +44,49 @@ namespace {
4344
struct RngState {
4445
unsigned long long seed;
4546
unsigned long long counter;
47+
// Per-launch scratch — written by advance_counter_kernel and read by
48+
// the main RNG kernels. Single-threaded host driver is assumed
49+
// (typical inference / CUDA-graph replay use case).
50+
unsigned long long base_scratch;
4651
};
4752

4853
static RngState* d_rng = nullptr;
49-
static bool g_rng_init_done = false;
54+
// std::call_once guarantees one-shot initialization even when shims are
55+
// invoked from multiple host threads (e.g. concurrent models / streams).
56+
static std::once_flag g_rng_init_flag;
5057

5158
// Initialize RNG state on the given stream.
52-
// Must be called during warmup (before graph capture).
59+
// Must be called during warmup (before graph capture). Subsequent calls
60+
// from any thread are no-ops thanks to std::call_once.
5361
void ensure_rng_init(cudaStream_t stream) {
54-
if (!g_rng_init_done) {
62+
std::call_once(g_rng_init_flag, [&]() {
5563
cudaMallocAsync(&d_rng, sizeof(RngState), stream);
5664
RngState h;
5765
h.seed = static_cast<unsigned long long>(time(nullptr));
5866
h.counter = 0;
67+
h.base_scratch = 0;
5968
cudaMemcpyAsync(
6069
d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
6170
// Synchronize to ensure the copy completes before we return
6271
// (the host-side RngState `h` is on the stack).
6372
cudaStreamSynchronize(stream);
64-
g_rng_init_done = true;
65-
}
73+
});
6674
}
6775

68-
// Philox-based randint kernel that reads seed from device-resident state
69-
// and atomically advances the counter. The counter pointer survives CUDA
70-
// graph replay, so each replay produces different values.
76+
// Philox-based randint kernel. Reads its base offset from `rng->base_scratch`
77+
// (populated by `advance_counter_kernel` immediately before this launch).
78+
// This replaces the previous per-element atomicAdd contention with a single
79+
// atomic per kernel launch.
7180
__global__ void philox_randint_graph_kernel(
7281
int64_t* __restrict__ out,
7382
int64_t numel,
7483
int64_t low,
7584
int64_t range,
7685
RngState* __restrict__ rng) {
77-
// Each thread reads the seed and computes its unique offset.
78-
// The "base offset" is read from rng->counter. We can't atomicAdd per
79-
// thread, so we use a two-pass approach: first a single-thread kernel
80-
// advances the counter, then the main kernel uses the old value.
81-
// But that requires two kernel launches...
82-
//
83-
// Simpler: since numel=1 for randint seed generation, just one thread.
8486
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
8587
if (idx < numel) {
86-
// Each invocation atomically grabs `numel` slots from the counter.
87-
// For numel=1, this is just one atomicAdd.
88-
unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
8988
curandStatePhilox4_32_10_t state;
90-
curand_init(rng->seed, idx, my_offset, &state);
89+
curand_init(rng->seed, idx, rng->base_scratch, &state);
9190
double val = curand_uniform_double(&state);
9291
int64_t ival = static_cast<int64_t>(val * range);
9392
out[idx] = low + (ival >= range ? range - 1 : ival);
@@ -101,9 +100,8 @@ __global__ void philox_rand_float_graph_kernel(
101100
RngState* __restrict__ rng) {
102101
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
103102
if (idx < numel) {
104-
unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
105103
curandStatePhilox4_32_10_t state;
106-
curand_init(rng->seed, idx, my_offset, &state);
104+
curand_init(rng->seed, idx, rng->base_scratch, &state);
107105
out[idx] = curand_uniform(&state);
108106
}
109107
}
@@ -115,9 +113,8 @@ __global__ void philox_rand_bf16_graph_kernel(
115113
RngState* __restrict__ rng) {
116114
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
117115
if (idx < numel) {
118-
unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
119116
curandStatePhilox4_32_10_t state;
120-
curand_init(rng->seed, idx, my_offset, &state);
117+
curand_init(rng->seed, idx, rng->base_scratch, &state);
121118
float val = curand_uniform(&state);
122119
uint32_t bits;
123120
memcpy(&bits, &val, sizeof(uint32_t));
@@ -127,6 +124,18 @@ __global__ void philox_rand_bf16_graph_kernel(
127124
}
128125
}
129126

127+
// Single-thread helper that grabs a contiguous range of `numel` offsets
128+
// from the on-device counter and writes the base into `rng->base_scratch`.
129+
// Replaces `numel` per-element atomics with a single atomic per launch
130+
// while staying graph-capturable.
131+
__global__ void advance_counter_kernel(
132+
RngState* __restrict__ rng,
133+
unsigned long long numel) {
134+
if (blockIdx.x == 0 && threadIdx.x == 0) {
135+
rng->base_scratch = atomicAdd(&rng->counter, numel);
136+
}
137+
}
138+
130139
} // anonymous namespace
131140

132141
extern "C" {
@@ -188,6 +197,12 @@ AOTITorchError aoti_torch_cuda_rand(
188197
constexpr int kThreads = 256;
189198
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
190199

200+
// Single atomicAdd per launch — grabs `numel` consecutive counter slots
201+
// for the kernel below, eliminating per-element contention on the GPU
202+
// counter.
203+
advance_counter_kernel<<<1, 1, 0, stream>>>(
204+
d_rng, static_cast<unsigned long long>(numel));
205+
191206
if (scalar_type == ScalarType::Float) {
192207
philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
193208
static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
@@ -244,6 +259,9 @@ AOTITorchError aoti_torch_cuda_randint_low_out(
244259

245260
constexpr int kThreads = 256;
246261
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
262+
// One atomicAdd per launch; subsequent kernel reads `rng->base_scratch`.
263+
advance_counter_kernel<<<1, 1, 0, stream>>>(
264+
d_rng, static_cast<unsigned long long>(numel));
247265
philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
248266
out_data, numel, low, range, d_rng);
249267

backends/cuda/runtime/shims/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ set(CUDA_SHIM_TESTS
4242
test_aoti_torch_delete_tensor_object
4343
test_aoti_torch__reinterpret_tensor
4444
test_aoti_torch_copy_
45+
test_aoti_torch_cuda_rand
4546
test_aoti_torch_new_tensor_handle
4647
test_aoti_torch_item_bool
4748
test_aoti_torch_assign_tensors_out

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def define_common_targets():
3838
cuda_shim_cpp_unittest("aoti_torch_copy_")
3939
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
4040
cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
41+
cuda_shim_cpp_unittest("aoti_torch_cuda_rand")
4142
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
4243
cuda_shim_cpp_unittest("aoti_torch_item_bool")
4344
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")

0 commit comments

Comments
 (0)