1919
2020#include < cstdint>
2121#include < ctime>
22+ #include < mutex>
2223#include < vector>
2324
2425namespace executorch ::backends::cuda {
@@ -43,51 +44,49 @@ namespace {
// Device-resident RNG state shared by the Philox RNG kernels in this file.
// A single instance lives in device memory (see `d_rng` below); because the
// pointer is stable across CUDA graph capture/replay, each replay draws
// fresh values instead of repeating the captured ones.
struct RngState {
  // Philox seed. Written once on the host (from time(nullptr)) during
  // ensure_rng_init and never changed afterwards.
  unsigned long long seed;
  // Monotonically increasing offset counter. Advanced atomically by
  // advance_counter_kernel; never reset for the lifetime of the process.
  unsigned long long counter;
  // Per-launch scratch — written by advance_counter_kernel and read by
  // the main RNG kernels. Single-threaded host driver is assumed
  // (typical inference / CUDA-graph replay use case).
  unsigned long long base_scratch;
};
4752
// Lazily allocated device pointer to the single RngState instance.
// Remains nullptr until ensure_rng_init has run once.
static RngState* d_rng = nullptr;
// std::call_once guarantees one-shot initialization even when shims are
// invoked from multiple host threads (e.g. concurrent models / streams).
static std::once_flag g_rng_init_flag;
5057
// Initialize RNG state on the given stream.
// Must be called during warmup (before graph capture). Subsequent calls
// from any thread are no-ops thanks to std::call_once.
//
// Allocates the device-side RngState with the stream-ordered allocator,
// seeds it from wall-clock time, and blocks until the host->device copy of
// the stack-resident staging struct has completed.
void ensure_rng_init (cudaStream_t stream) {
  std::call_once (g_rng_init_flag, [&]() {
    // NOTE(review): the cudaMallocAsync / cudaMemcpyAsync /
    // cudaStreamSynchronize return codes are unchecked. If the allocation
    // fails, d_rng stays nullptr and call_once will not retry, so later
    // kernel launches dereference null — confirm the caller validates
    // d_rng, or add error propagation here.
    cudaMallocAsync (&d_rng, sizeof (RngState), stream);
    RngState h;
    // time(nullptr) gives second-granularity entropy; two processes started
    // in the same second share a seed. Presumably acceptable for this use
    // case — TODO confirm.
    h.seed = static_cast<unsigned long long>(time (nullptr));
    h.counter = 0;
    h.base_scratch = 0;
    cudaMemcpyAsync (
        d_rng, &h, sizeof (RngState), cudaMemcpyHostToDevice, stream);
    // Synchronize to ensure the copy completes before we return
    // (the host-side RngState `h` is on the stack).
    cudaStreamSynchronize (stream);
  });
}
6775
// Philox-based randint kernel. Reads its base offset from `rng->base_scratch`
// (populated by `advance_counter_kernel` immediately before this launch on
// the same stream — stream ordering guarantees the write is visible here).
// This replaces the previous per-element atomicAdd contention with a single
// atomic per kernel launch.
//
// Grid/block layout: 1-D launch, one thread per output element; callers use
// ceil-div blocks of 256 threads. Writes `numel` values in [low, low+range)
// to `out`. Per-thread uniqueness comes from using `idx` as the Philox
// subsequence; cross-launch uniqueness comes from base_scratch advancing.
__global__ void philox_randint_graph_kernel (
    int64_t* __restrict__ out,
    int64_t numel,
    int64_t low,
    int64_t range,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    curandStatePhilox4_32_10_t state;
    // subsequence = idx (unique per thread), offset = per-launch base.
    curand_init (rng->seed, idx, rng->base_scratch, &state);
    // curand_uniform_double returns values in (0, 1], so val * range can
    // reach `range` exactly; the clamp below folds that edge into range-1.
    double val = curand_uniform_double (&state);
    int64_t ival = static_cast<int64_t>(val * range);
    out[idx] = low + (ival >= range ? range - 1 : ival);
  }
}
@@ -101,9 +100,8 @@ __global__ void philox_rand_float_graph_kernel(
101100 RngState* __restrict__ rng) {
102101 int64_t idx = static_cast <int64_t >(blockIdx .x ) * blockDim .x + threadIdx .x ;
103102 if (idx < numel) {
104- unsigned long long my_offset = atomicAdd (&rng->counter , 1ULL );
105103 curandStatePhilox4_32_10_t state;
106- curand_init (rng->seed , idx, my_offset , &state);
104+ curand_init (rng->seed , idx, rng-> base_scratch , &state);
107105 out[idx] = curand_uniform (&state);
108106 }
109107}
@@ -115,9 +113,8 @@ __global__ void philox_rand_bf16_graph_kernel(
115113 RngState* __restrict__ rng) {
116114 int64_t idx = static_cast <int64_t >(blockIdx .x ) * blockDim .x + threadIdx .x ;
117115 if (idx < numel) {
118- unsigned long long my_offset = atomicAdd (&rng->counter , 1ULL );
119116 curandStatePhilox4_32_10_t state;
120- curand_init (rng->seed , idx, my_offset , &state);
117+ curand_init (rng->seed , idx, rng-> base_scratch , &state);
121118 float val = curand_uniform (&state);
122119 uint32_t bits;
123120 memcpy (&bits, &val, sizeof (uint32_t ));
@@ -127,6 +124,18 @@ __global__ void philox_rand_bf16_graph_kernel(
127124 }
128125}
129126
// Single-thread helper: atomically reserves a contiguous range of `numel`
// offsets from the device-resident counter and publishes the base of that
// range into `rng->base_scratch`, where the RNG kernel launched next on the
// same stream picks it up. This costs one atomic per launch (instead of one
// per element) and remains capturable into a CUDA graph.
__global__ void advance_counter_kernel (
    RngState* __restrict__ rng,
    unsigned long long numel) {
  // Only one lane does the work. Callers launch this <<<1, 1>>>, but the
  // guard keeps the kernel correct under any launch configuration.
  const bool is_leader = (blockIdx.x == 0) && (threadIdx.x == 0);
  if (is_leader) {
    const unsigned long long base = atomicAdd (&rng->counter, numel);
    rng->base_scratch = base;
  }
}
138+
130139} // anonymous namespace
131140
132141extern " C" {
@@ -188,6 +197,12 @@ AOTITorchError aoti_torch_cuda_rand(
188197 constexpr int kThreads = 256 ;
189198 int blocks = static_cast <int >((numel + kThreads - 1 ) / kThreads );
190199
200+ // Single atomicAdd per launch — grabs `numel` consecutive counter slots
201+ // for the kernel below, eliminating per-element contention on the GPU
202+ // counter.
203+ advance_counter_kernel<<<1 , 1 , 0 , stream>>> (
204+ d_rng, static_cast <unsigned long long >(numel));
205+
191206 if (scalar_type == ScalarType::Float) {
192207 philox_rand_float_graph_kernel<<<blocks, kThreads , 0 , stream>>> (
193208 static_cast <float *>((*ret0)->data_ptr ()), numel, d_rng);
@@ -244,6 +259,9 @@ AOTITorchError aoti_torch_cuda_randint_low_out(
244259
245260 constexpr int kThreads = 256 ;
246261 int blocks = static_cast <int >((numel + kThreads - 1 ) / kThreads );
262+ // One atomicAdd per launch; subsequent kernel reads `rng->base_scratch`.
263+ advance_counter_kernel<<<1 , 1 , 0 , stream>>> (
264+ d_rng, static_cast <unsigned long long >(numel));
247265 philox_randint_graph_kernel<<<blocks, kThreads , 0 , stream>>> (
248266 out_data, numel, low, range, d_rng);
249267
0 commit comments