Skip to content

Commit ac9efa7

Browse files
authored
turn rand.h into [low, high) following torch.rand pattern (#19468)
This diff makes the AOTI rand shim follow the torch.rand pattern so that generated outputs fall in the half-open range [low, high). Reviewed By: GregoryComer Differential Revision: D104723400
1 parent a6bfb3e commit ac9efa7

3 files changed

Lines changed: 48 additions & 17 deletions

File tree

backends/cuda/runtime/shims/rand.cu

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ void ensure_rng_init(cudaStream_t stream) {
7777
// (populated by `advance_counter_kernel` immediately before this launch).
7878
// This replaces the previous per-element atomicAdd contention with a single
7979
// atomic per kernel launch.
80+
//
81+
// Matches PyTorch's `transformation::uniform_int_from_to` semantics: builds
82+
// a 64-bit random value from two 32-bit curand draws, then takes
83+
// `val % range + low` so the output lies in [low, high).
8084
__global__ void philox_randint_graph_kernel(
8185
int64_t* __restrict__ out,
8286
int64_t numel,
@@ -87,13 +91,27 @@ __global__ void philox_randint_graph_kernel(
8791
if (idx < numel) {
8892
curandStatePhilox4_32_10_t state;
8993
curand_init(rng->seed, idx, rng->base_scratch, &state);
90-
double val = curand_uniform_double(&state);
91-
int64_t ival = static_cast<int64_t>(val * range);
92-
out[idx] = low + (ival >= range ? range - 1 : ival);
94+
uint32_t hi = curand(&state);
95+
uint32_t lo = curand(&state);
96+
uint64_t rval = (static_cast<uint64_t>(hi) << 32) | static_cast<uint64_t>(lo);
97+
uint64_t urange = static_cast<uint64_t>(range);
98+
out[idx] = low + static_cast<int64_t>(rval % urange);
9399
}
94100
}
95101

96-
// Philox-based uniform float32 generator (graph-safe version).
102+
// Maps a uniformly distributed uint32 to a float32 in [0, 1) following the
103+
// pattern used by PyTorch's `transformation::uniform_real` in
104+
// aten/src/ATen/native/cuda/DistributionTemplates.h: keep the low 24 bits of
104+
// the draw (float32's mantissa width) and divide by 2^24.
106+
__device__ inline float uniform_real_from_uint32(uint32_t val) {
107+
// std::numeric_limits<float>::digits == 24
108+
constexpr uint32_t kMantissaMask = (1u << 24) - 1;
109+
constexpr float kDivisor = 1.0f / static_cast<float>(1u << 24);
110+
return static_cast<float>(val & kMantissaMask) * kDivisor;
111+
}
112+
113+
// Philox-based uniform float32 generator (graph-safe version). Produces
114+
// values in [0, 1) to match torch.rand semantics.
97115
__global__ void philox_rand_float_graph_kernel(
98116
float* __restrict__ out,
99117
int64_t numel,
@@ -102,11 +120,12 @@ __global__ void philox_rand_float_graph_kernel(
102120
if (idx < numel) {
103121
curandStatePhilox4_32_10_t state;
104122
curand_init(rng->seed, idx, rng->base_scratch, &state);
105-
out[idx] = curand_uniform(&state);
123+
out[idx] = uniform_real_from_uint32(curand(&state));
106124
}
107125
}
108126

109-
// Philox-based uniform bfloat16 generator (graph-safe version).
127+
// Philox-based uniform bfloat16 generator (graph-safe version). Produces a
128+
// float in [0, 1) and rounds to bfloat16 with round-to-nearest-even.
110129
__global__ void philox_rand_bf16_graph_kernel(
111130
uint16_t* __restrict__ out,
112131
int64_t numel,
@@ -115,7 +134,7 @@ __global__ void philox_rand_bf16_graph_kernel(
115134
if (idx < numel) {
116135
curandStatePhilox4_32_10_t state;
117136
curand_init(rng->seed, idx, rng->base_scratch, &state);
118-
float val = curand_uniform(&state);
137+
float val = uniform_real_from_uint32(curand(&state));
119138
uint32_t bits;
120139
memcpy(&bits, &val, sizeof(uint32_t));
121140
uint32_t lsb = (bits >> 16) & 1;

backends/cuda/runtime/shims/rand.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,19 @@ using SlimTensor = executorch::backends::aoti::slim::SlimTensor;
2525
extern "C" {
2626

2727
/**
28-
* Generates a tensor filled with uniform random values in [0, 1).
28+
* Generates a tensor filled with uniform random values in [0, 1), matching
29+
* the behavior of torch.rand / aten::rand (see
30+
* aten/src/ATen/native/cuda/DistributionUniform.cu and the
31+
* `transformation::uniform_real` helper in
32+
* aten/src/ATen/native/cuda/DistributionTemplates.h).
2933
*
3034
* Implements the AOTI shim for aten::rand.default on CUDA. Uses cuRAND
31-
* Philox counter-based RNG with GPU-resident state. The counter is
32-
* atomically advanced by each kernel invocation on-device, making it
33-
* fully compatible with CUDA graph capture and replay — each replay
34-
* produces different values because the counter increments on the GPU.
35+
* Philox counter-based RNG with GPU-resident state, then maps the random
36+
* uint32 to [0, 1) using PyTorch's bit-mask + divisor formulation rather
37+
* than curand_uniform (which returns (0, 1]). The counter is atomically
38+
* advanced by each kernel invocation on-device, making it fully compatible
39+
* with CUDA graph capture and replay — each replay produces different
40+
* values because the counter increments on the GPU.
3541
*
3642
* Supports float32 and bfloat16 output dtypes.
3743
*/
@@ -46,7 +52,10 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_rand(
4652
SlimTensor** ret0);
4753

4854
/**
49-
* Fills a pre-allocated int64 tensor with random integers in [low, high).
55+
* Fills a pre-allocated int64 tensor with random integers in [low, high),
56+
* matching the behavior of torch.randint / aten::randint.low_out (see
57+
* `transformation::uniform_int_from_to` in
58+
* aten/src/ATen/native/cuda/DistributionTemplates.h).
5059
*
5160
* Implements the AOTI shim for aten::randint.low_out on CUDA. Used by
5261
* Inductor's Philox RNG to generate random seeds. Each thread atomically

backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_rand.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class AOTITorchCudaRandTest : public ::testing::Test {
9191
// aoti_torch_cuda_rand tests
9292
// ----------------------------------------------------------------------------
9393

94-
// Basic float32 rand: produces a tensor in [0, 1).
94+
// Basic float32 rand: produces a tensor in [0, 1) to match torch.rand.
9595
TEST_F(AOTITorchCudaRandTest, RandFloat32Basic) {
9696
std::vector<int64_t> sizes = {4, 8};
9797
int64_t numel = 4 * 8;
@@ -144,7 +144,9 @@ TEST_F(AOTITorchCudaRandTest, RandDefaultDtypeIsFloat) {
144144
EXPECT_EQ(out->numel(), 16);
145145
}
146146

147-
// BFloat16 rand: values must lie in [0, 1).
147+
// BFloat16 rand: values are drawn from [0, 1) to match torch.rand. Note that
148+
// bfloat16 has only 8 mantissa bits so a float in [0, 1) close to 1.0 may
149+
// round up to bfloat16 1.0; we accept that as PyTorch does.
148150
TEST_F(AOTITorchCudaRandTest, RandBFloat16Basic) {
149151
std::vector<int64_t> sizes = {32};
150152
int64_t numel = 32;
@@ -171,7 +173,7 @@ TEST_F(AOTITorchCudaRandTest, RandBFloat16Basic) {
171173
for (int64_t i = 0; i < numel; ++i) {
172174
float v = bfloat16_bits_to_float(host[i]);
173175
EXPECT_GE(v, 0.0f) << "bf16 value at " << i << " < 0";
174-
EXPECT_LT(v, 1.0f) << "bf16 value at " << i << " >= 1";
176+
EXPECT_LE(v, 1.0f) << "bf16 value at " << i << " > 1";
175177
}
176178
}
177179

@@ -287,7 +289,8 @@ TEST_F(AOTITorchCudaRandTest, RandTwoCallsProduceDifferentValues) {
287289
// aoti_torch_cuda_randint_low_out tests
288290
// ----------------------------------------------------------------------------
289291

290-
// Basic randint into a pre-allocated int64 tensor; values lie in [low, high).
292+
// Basic randint into a pre-allocated int64 tensor; values lie in [low, high)
293+
// to match torch.randint semantics.
291294
TEST_F(AOTITorchCudaRandTest, RandintBasicRange) {
292295
std::vector<int64_t> sizes = {32};
293296
int64_t numel = 32;

0 commit comments

Comments
 (0)