@@ -77,6 +77,10 @@ void ensure_rng_init(cudaStream_t stream) {
7777// (populated by `advance_counter_kernel` immediately before this launch).
7878// This replaces the previous per-element atomicAdd contention with a single
7979// atomic per kernel launch.
80+ //
81+ // Matches PyTorch's `transformation::uniform_int_from_to` semantics: builds
82+ // a 64-bit random value from two 32-bit curand draws, then takes
83+ // `val % range + low` so the output lies in [low, high).
8084__global__ void philox_randint_graph_kernel (
8185 int64_t * __restrict__ out,
8286 int64_t numel,
@@ -87,13 +91,27 @@ __global__ void philox_randint_graph_kernel(
8791 if (idx < numel) {
8892 curandStatePhilox4_32_10_t state;
8993 curand_init (rng->seed , idx, rng->base_scratch , &state);
90- double val = curand_uniform_double (&state);
91- int64_t ival = static_cast <int64_t >(val * range);
92- out[idx] = low + (ival >= range ? range - 1 : ival);
94+ uint32_t hi = curand (&state);
95+ uint32_t lo = curand (&state);
96+ uint64_t rval = (static_cast <uint64_t >(hi) << 32 ) | static_cast <uint64_t >(lo);
97+ uint64_t urange = static_cast <uint64_t >(range);
98+ out[idx] = low + static_cast <int64_t >(rval % urange);
9399 }
94100}
95101
96- // Philox-based uniform float32 generator (graph-safe version).
102+ // Maps a uniformly distributed uint32 to a float32 in [0, 1) following the
103+ // pattern used by PyTorch's `transformation::uniform_real` in
104+ // aten/src/ATen/native/cuda/DistributionTemplates.h: keep the low 24 mantissa
105+ // bits and divide by 2^24.
106+ __device__ inline float uniform_real_from_uint32 (uint32_t val) {
107+ // std::numeric_limits<float>::digits == 24
108+ constexpr uint32_t kMantissaMask = (1u << 24 ) - 1 ;
109+ constexpr float kDivisor = 1 .0f / static_cast <float >(1u << 24 );
110+ return static_cast <float >(val & kMantissaMask ) * kDivisor ;
111+ }
112+
113+ // Philox-based uniform float32 generator (graph-safe version). Produces
114+ // values in [0, 1) to match torch.rand semantics.
97115__global__ void philox_rand_float_graph_kernel (
98116 float * __restrict__ out,
99117 int64_t numel,
@@ -102,11 +120,12 @@ __global__ void philox_rand_float_graph_kernel(
102120 if (idx < numel) {
103121 curandStatePhilox4_32_10_t state;
104122 curand_init (rng->seed , idx, rng->base_scratch , &state);
105- out[idx] = curand_uniform ( &state);
123+ out[idx] = uniform_real_from_uint32 ( curand ( &state) );
106124 }
107125}
108126
109- // Philox-based uniform bfloat16 generator (graph-safe version).
127+ // Philox-based uniform bfloat16 generator (graph-safe version). Produces a
128+ // float in [0, 1) and rounds to bfloat16 with round-to-nearest-even.
110129__global__ void philox_rand_bf16_graph_kernel (
111130 uint16_t * __restrict__ out,
112131 int64_t numel,
@@ -115,7 +134,7 @@ __global__ void philox_rand_bf16_graph_kernel(
115134 if (idx < numel) {
116135 curandStatePhilox4_32_10_t state;
117136 curand_init (rng->seed , idx, rng->base_scratch , &state);
118- float val = curand_uniform ( &state);
137+ float val = uniform_real_from_uint32 ( curand ( &state) );
119138 uint32_t bits;
120139 memcpy (&bits, &val, sizeof (uint32_t ));
121140 uint32_t lsb = (bits >> 16 ) & 1 ;
0 commit comments