diff --git a/src/host/random.jl b/src/host/random.jl index ee285ab3..5beb7d31 100644 --- a/src/host/random.jl +++ b/src/host/random.jl @@ -41,22 +41,27 @@ end ctr end +@inline function philox4x32_10(ctr0::UInt64, ctr1::UInt64, key::UInt64)::NTuple{4, UInt32} + philox4x32_10((ctr0%UInt32, (ctr0>>32)%UInt32, ctr1%UInt32, (ctr1>>32)%UInt32), (key%UInt32, (key>>32)%UInt32)) +end ## Float conversions: unsigned integer → uniform float, strictly positive -## (Float32 can round up to exactly 1.0; Float64 stays strictly below 1.0.) +## (can round up to exactly 1.0) -@inline function u01(::Type{Float32}, u::UInt32) - fma(Float32(u), Float32(2)^(-32), Float32(2)^(-33)) -end +""" + u01(F, u::Union{UInt32, UInt64})::F + +Convert an unsigned integer to a float of type `F` uniformly distributed in (0, 1]. -@inline function u01(::Type{Float64}, u::UInt64) - # Bit-pattern construction avoids the expensive Float64(::UInt64) conversion - # and fma on consumer GPUs (where FP64 throughput is as low as 1:64). The - # low bit of the mantissa is forced set so the result is strictly in (0, 1), - # which is required by Box-Muller's log(u). - reinterpret(Float64, ((u >> 12) | 0x1) | 0x3ff0000000000000) - 1.0 +Ported from [Random123 uniform.hpp](https://github.com/DEShawResearch/random123/blob/v1.14.0/include/Random123/uniform.hpp#L175). +""" +@inline function u01(::Type{F}, u::UInt32)::F where F + fma(F(u), F(2)^Int32(-32), F(2)^Int32(-33)) end +@inline function u01(::Type{F}, u::UInt64)::F where F + fma(F(u), F(2)^Int32(-64), F(2)^Int32(-65)) +end ## Fast sincospi for Box-Muller # @@ -69,91 +74,225 @@ end # select one of 8 octants (swap/negate), upper 29 bits give the reduced # argument (+0.5-biased so y ≠ 0). # -# Float64 keeps using Base.sincospi: backends that support FP64 all have -# intrinsics, and the polynomial alternative is ~8× slower on consumer GPUs -# with low FP64 throughput. +# Vendored contents of: +# https://github.com/medyan-dev/PhiloxRNG.jl/blob/v1.1.1/src/fastsincospi.jl +# With minor style changes. + +# ============================================================ +# Fast sincospi for Box-Muller +# +# Computes (sin(θ), cos(θ)) from a uniform UInt32 (or UInt64), +# placing 2^N points uniformly around the unit circle with +# no point landing exactly on an axis. +# +# The bottom 3 bits of u select one of 8 octants (π/4 each). +# The upper bits give the reduced argument y ∈ (0, 0.25), +# with a +0.5 bias to avoid y = 0. The polynomials evaluate +# sin(πy) and cos(πy) — with π baked into the coefficients — +# then each octant bit directly controls one operation: +# bit 0 → swap sin/cos +# bit 1 → negate sin +# bit 2 → negate cos +# +# The octants are not in geometric order, but the 2^N points +# are uniformly distributed around the unit circle regardless. +# ============================================================ +# --- Float32 minimax coefficients for sin(πy)/y and cos(πy) in y² --- +# +# 4-term (degree 3 in y²) minimax via Remez algorithm on [0, 0.0625]. const SP_F32 = (3.1415927f0, -5.167708f0, 2.5497673f0, -0.58907866f0) const CP_F32 = (1.0f0, -4.934788f0, 4.057578f0, -1.3061346f0) -@inline function fast_sincospi(::Type{Float32}, u::UInt32) +@inline function fast_sincospi(::Type{Float32}, u::Union{UInt32, UInt64}) oct = (u % Int32) & Int32(7) - y = fma(Float32(u & ~UInt32(7)), Float32(2)^(-34), Float32(2)^(-32)) + y = fma(Float32(u & ~oftype(u, 7)), Float32(2)^Int32(-(sizeof(u)*8+2)), Float32(2)^Int32(-(sizeof(u)*8))) + sp = y * evalpoly(y * y, SP_F32) cp = evalpoly(y * y, CP_F32) + swap = !iszero(oct & Int32(1)) sin_neg = !iszero(oct & Int32(2)) cos_neg = !iszero(oct & Int32(4)) + s_raw = ifelse(swap, cp, sp) c_raw = ifelse(swap, sp, cp) - (ifelse(sin_neg, -s_raw, s_raw), ifelse(cos_neg, -c_raw, c_raw)) + sin_val = ifelse(sin_neg, -s_raw, s_raw) + cos_val = ifelse(cos_neg, -c_raw, c_raw) + (sin_val, cos_val) end +# ============================================================ +# Float64 / UInt64 version +# +# Same structure as Float32: bottom 3 bits → octant, upper +# 61 bits → reduced argument, +0.5 bias, direct bit mapping. +# ============================================================ + +const SP_F64 = (3.141592653589793, -5.167712780049954, 2.5501640398733785, + -0.5992645289398095, 0.08214586918507949, -0.007370021659123395, + 0.0004615322405282014) +const CP_F64 = (1.0, -4.934802200544605, 4.0587121263978485, + -1.3352627670374702, 0.23533054723811608, -0.025804938901032953, + 0.0019068114005246046) + +@inline function fast_sincospi(::Type{Float64}, u::Union{UInt32, UInt64}) + oct = (u % Int32) & Int32(7) + y = fma(Float64(u & ~oftype(u, 7)), Float64(2)^Int32(-(sizeof(u)*8+2)), Float64(2)^Int32(-(sizeof(u)*8))) + + sp = y * evalpoly(y * y, SP_F64) + cp = evalpoly(y * y, CP_F64) + + swap = !iszero(oct & Int32(1)) + sin_neg = !iszero(oct & Int32(2)) + cos_neg = !iszero(oct & Int32(4)) + + s_raw = ifelse(swap, cp, sp) + c_raw = ifelse(swap, sp, cp) + sin_val = ifelse(sin_neg, -s_raw, s_raw) + cos_val = ifelse(cos_neg, -c_raw, c_raw) + (sin_val, cos_val) +end + +# End of vendored https://github.com/medyan-dev/PhiloxRNG.jl/blob/v1.1.1/src/fastsincospi.jl ## Fast log for Box-Muller # # Base.log(::Float32) widens to Float64 internally (see base/special/log.jl:242 # `Float64(2f0*f)/(2.0+f)`), same Metal / FP64-emulation problem as sincospi. # -# Vendored from PhiloxRNG.jl (MIT), ported from fdlibm's e_logf.c: a Float32 -# minimax polynomial. Takes the raw Philox UInt32 output; the u01 conversion -# is folded into the first fma so there's no intermediate float. -# -# Same Float64-path reasoning as the sincospi block above. +# Vendored contents of: +# https://github.com/medyan-dev/PhiloxRNG.jl/blob/v1.1.1/src/fastlog.jl +# With minor style changes. -const SQRT_HALF_I32 = reinterpret(Int32, Float32(sqrt(0.5))) -const LOG_ODD_F32 = (reinterpret(Float32, Int32(0x3f2aaaaa)), - reinterpret(Float32, Int32(0x3e91e9ee))) -const LOG_EVEN_F32 = (reinterpret(Float32, Int32(0x3eccce13)), - reinterpret(Float32, Int32(0x3e789e26))) +# Core log algorithm (polynomial coefficients, ln2 splitting, and reconstruction) +# adapted from fdlibm's e_log.c / e_logf.c (Sun Microsystems, 1993). +# See: https://github.com/JuliaMath/openlibm/blob/v0.8.7/src/e_log.c +# https://github.com/JuliaMath/openlibm/blob/v0.8.7/src/e_logf.c -@inline function fast_log(::Type{Float32}, u::UInt32) - x = fma(Float32(u), Float32(2)^(-32), Float32(2)^(-33)) +const SQRT_HALF_I32 = reinterpret(Int32, Float32(sqrt(0.5))) +const LOG_POLY_F32 = (0.6666666f0, 0.40000972f0, 0.28498787f0, 0.24279079f0) +const LN2_HI_F32 = 0.6931381f0 +const LN2_LO_F32 = 9.058001f-6 + +@inline function fast_log(::Type{Float32}, u::Union{UInt32, UInt64}) + x = u01(Float32, u) + + # Goal: find k and f such that + # x = 2^k * (1+f) + # where sqrt(2)/2 ≤ (1+f) < sqrt(2) + # if k is zero + # we calculate f by -u01(Float32, ~u) which is more accurate for x near 1 + + # Float32 has 23 fractional bits. + # x is ordered by value in Int32 space. + # Starting from x=1, k starts at 0, then ix becomes negative at x = prevfloat(sqrt(0.5f0)) + # making k = -1. For each power of 2 scale in x, + # k changes by one, because we shift out the 23 fraction bits. ix = reinterpret(Int32, x) - SQRT_HALF_I32 k = ix >> Int32(23) - f_std = reinterpret(Float32, (ix & Int32(0x007fffff)) + SQRT_HALF_I32) - 1.0f0 - f_comp = -fma(Float32(~u), Float32(2)^(-32), Float32(2)^(-33)) + + # `f_plus_one_std` will have the same fraction bits as `x` + # because `- SQRT_HALF_I32` and `+ SQRT_HALF_I32` cancel out in the low 23 bits. + # `& Int32(0x007fffff)` clears the exponent and sign fields. + # `f_plus_one_std` must either have an exponent of -1 or 0. + # If x's fractional bits are less than the fractional bits of SQRT_HALF_I32 + # the `- SQRT_HALF_I32` borrows a 2^23 from the exponent field of x, + # which then shows up as an extra 2^23 in the low 23 bits after masking. + # When adding SQRT_HALF_I32 this extra 2^23 propagates up and + # bumps the exponent from -1 to 0. + f_plus_one_std = reinterpret(Float32, (ix & Int32(0x007fffff)) + SQRT_HALF_I32) + f_std = f_plus_one_std - 1.0f0 + + f_comp = -u01(Float32, ~u) f = ifelse(k == Int32(0), f_comp, f_std) + + # Goal: get log(1+f) via a polynomial approx. + # Let s = f/(2+f), z = s², and log_poly(z) ≈ evalpoly(z, LOG_POLY_F32) + # log(1+f) = 2s + s³*log_poly(s²) + # R = s²*log_poly(s²) + # log(1+f) = f - f²/2 + s*(f²/2 + R) s = f / (2.0f0 + f) - z = s * s; w = z * z - R = z * evalpoly(w, LOG_ODD_F32) + w * evalpoly(w, LOG_EVEN_F32) + z = s * s + R = z * evalpoly(z, LOG_POLY_F32) hfsq = 0.5f0 * f * f - Float32(k) * reinterpret(Float32, Int32(0x3f317180)) - - ((hfsq - (s * (hfsq + R) + - Float32(k) * reinterpret(Float32, Int32(0x3717f7d1)))) - f) + + # log(x) = k*log(2) + log(1+f) + k_f32 = Float32(k) + # Simpler version, but fails the mean test by 2E-9 + # fma(k_f32, 0.6931472f0 #= log(2) =#, fma(s, R-f, f)) + # log(2) = LN2_HI_F32 + LN2_LO_F32 + fma(k_f32, LN2_HI_F32, + f - (hfsq - fma(s, (hfsq + R), k_f32 * LN2_LO_F32)) + ) end +const SQRT_HALF_I64 = reinterpret(Int64, sqrt(0.5)) +const LOG_POLY_F64 = ( + 6.666666666666735130e-01, + 3.999999999940941908e-01, + 2.857142874366239149e-01, + 2.222219843214978396e-01, + 1.818357216161805012e-01, + 1.531383769920937332e-01, + 1.479819860511658591e-01, +) +const LN2_HI_F64 = 6.93147180369123816490e-01 +const LN2_LO_F64 = 1.90821492927058770002e-10 + +@inline function fast_log(::Type{Float64}, u::Union{UInt32, UInt64}) + # See Float32 version for commentary + x = u01(Float64, u) + + ix = reinterpret(Int64, x) - SQRT_HALF_I64 + k = ix >> Int64(52) + f_std = reinterpret(Float64, (ix & Int64(0x000fffffffffffff)) + SQRT_HALF_I64) - 1.0 + + f_comp = -u01(Float64, ~u) + f = ifelse(k == Int64(0), f_comp, f_std) + + s = f / (2.0 + f) + z = s * s + R = z * evalpoly(z, LOG_POLY_F64) + hfsq = 0.5 * f * f + + # log(x) = k*ln2 + log(1+f) + k_f64 = Float64(k) + fma(k_f64, LN2_HI_F64, + f - (hfsq - fma(s, (hfsq + R), k_f64 * LN2_LO_F64)) + ) +end -## Box-Muller transform +# End of vendored https://github.com/medyan-dev/PhiloxRNG.jl/blob/v1.1.1/src/fastlog.jl -using Base.FastMath + +## Box-Muller transform # ≤32-bit float output: both log and sincospi go through the Float32 # polynomials above. +# Using Base.sqrt_llvm to avoid the DomainError check. @inline function boxmuller(::Type{F}, u1::UInt32, u2::UInt32) where F <: Union{Float16,Float32} - r = sqrt(-2f0 * fast_log(Float32, u2)) + r = Base.sqrt_llvm(-2f0 * fast_log(Float32, u2)) s, c = fast_sincospi(Float32, u1) (F(r * s), F(r * c)) end -# Float64: Base.log_fast / Base.sincospi have FP64 intrinsics on the backends -# that support it. -@inline function boxmuller(::Type{Float64}, u1::Float64, u2::Float64) - r = sqrt(-2.0 * FastMath.log_fast(u1)) - s, c = sincospi(2 * u2) +@inline function boxmuller(::Type{Float64}, u1::UInt64, u2::UInt64) + r = Base.sqrt_llvm(-2.0 * fast_log(Float64, u2)) + s, c = fast_sincospi(Float64, u1) (r * s, r * c) end # For complex normals each component has variance 1/2, so the radius is # sqrt(-log(U)) rather than sqrt(-2·log(U)). @inline function boxmuller(::Type{Complex{F}}, u1::UInt32, u2::UInt32) where F <: Union{Float16,Float32} - r = sqrt(-fast_log(Float32, u2)) + r = Base.sqrt_llvm(-fast_log(Float32, u2)) s, c = fast_sincospi(Float32, u1) complex(F(r * s), F(r * c)) end -@inline function boxmuller(::Type{Complex{Float64}}, u1::Float64, u2::Float64) - r = sqrt(FastMath.neg_float_fast(FastMath.log_fast(u1))) - s, c = sincospi(2 * u2) +@inline function boxmuller(::Type{Complex{Float64}}, u1::UInt64, u2::UInt64) + r = Base.sqrt_llvm(-fast_log(Float64, u2)) + s, c = fast_sincospi(Float64, u1) complex(r * s, r * c) end @@ -167,24 +306,20 @@ end mutable struct RNG{AT} <: AbstractRNG seed::UInt64 - counter::UInt32 + counter::UInt64 end -RNG{AT}() where {AT} = RNG{AT}(rand(Random.RandomDevice(), UInt64), UInt32(0)) -RNG{AT}(seed::Integer) where {AT} = RNG{AT}(seed % UInt64, UInt32(0)) +RNG{AT}() where {AT} = RNG{AT}(rand(Random.RandomDevice(), UInt64), rand(Random.RandomDevice(), UInt64)) +RNG{AT}(seed::Integer) where {AT} = RNG{AT}(seed % UInt64, UInt64(0)) -Random.seed!(rng::RNG) = (rng.seed = rand(Random.RandomDevice(), UInt64); rng.counter = 0; rng) -Random.seed!(rng::RNG, seed::Integer) = (rng.seed = seed % UInt64; rng.counter = 0; rng) +Random.seed!(rng::RNG) = (rng.seed = rand(Random.RandomDevice(), UInt64); rng.counter = rand(Random.RandomDevice(), UInt64); rng) +Random.seed!(rng::RNG, seed::Integer) = (rng.seed = seed % UInt64; rng.counter = UInt64(0); rng) function advance_counter!(rng::RNG) - rng.counter += one(UInt32) - rng.counter == 0 && (rng.seed += one(UInt64)) + rng.counter += one(UInt64) rng end -# Split the 64-bit seed into the two 32-bit lanes of the Philox key. -@inline philox_key(seed::UInt64) = (seed % UInt32, (seed >> 32) % UInt32) - ## Specialized rand! kernels for common types @@ -228,22 +363,18 @@ vals_per_call(::Type{T}) where T = sizeof(T) <= 4 ? 4 : sizeof(T) <= 8 ? 2 : 1 vals_per_call(::Type{Complex{T}}) where T = sizeof(T) <= 4 ? 2 : 1 # Batched kernel: N values per work item from one Philox call -@kernel function rand_batched_kernel!(@Const(seed), @Const(counter), A::AbstractArray{T}) where T +@kernel function rand_batched_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T gid = @index(Global, Linear) N = vals_per_call(T) idx = N * gid len = length(A) if idx <= len - vals = philox_to_vals(T, philox4x32_10( - (gid % UInt32, UInt32(0), counter, UInt32(0)), - philox_key(seed))...) + vals = philox_to_vals(T, philox4x32_10(gid % UInt64, counter, seed)...) for j in 1:N @inbounds A[idx - N + j] = vals[j] end elseif idx - N < len - vals = philox_to_vals(T, philox4x32_10( - (gid % UInt32, UInt32(0), counter, UInt32(0)), - philox_key(seed))...) + vals = philox_to_vals(T, philox4x32_10(gid % UInt64, counter, seed)...) for j in 1:min(N, len - idx + N) @inbounds A[idx - N + j] = vals[j] end @@ -278,19 +409,17 @@ end struct ElementRNG <: AbstractRNG seed::UInt64 - counter::UInt32 - gid::UInt32 - subctr_ptr::Ptr{UInt32} + counter::UInt64 + nthreads::UInt64 + ctr0_ptr::Ptr{UInt64} end @inline Random.rng_native_52(::ElementRNG) = UInt64 @inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt64}) - sc = unsafe_load(rng.subctr_ptr) + UInt32(1) - unsafe_store!(rng.subctr_ptr, sc) - a1, a2, _, _ = philox4x32_10( - (rng.gid, sc, rng.counter, UInt32(0)), - philox_key(rng.seed)) + sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads + unsafe_store!(rng.ctr0_ptr, sc) + a1, a2, _, _ = philox4x32_10(sc, rng.counter, rng.seed) UInt64(a1) | UInt64(a2) << 32 end @@ -305,13 +434,16 @@ end ## Generic rand! fallback via ElementRNG -@kernel function rand_generic_kernel!(@Const(seed), @Const(counter), A::AbstractArray{T}) where T +@kernel function rand_generic_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T gid = @index(Global, Linear) - if gid <= length(A) - subctr = Ref{UInt32}(0) - rng = ElementRNG(seed, counter, gid % UInt32, - Base.unsafe_convert(Ptr{UInt32}, subctr)) - @inbounds A[gid] = rand(rng, T) + len_A = length(A) + if gid <= len_A + subctr = Ref{UInt64}(gid%UInt64) + GC.@preserve subctr begin + rng = ElementRNG(seed, counter, len_A % UInt64, + Base.unsafe_convert(Ptr{UInt64}, subctr)) + @inbounds A[gid] = rand(rng, T) + end end end @@ -342,9 +474,8 @@ end # output. # Convert Philox UInt32 outputs to N normally-distributed values of type T. -# ≤32-bit float targets pass UInt32s to boxmuller directly (the polynomial -# sincospi extracts bits itself; log still goes through u01 for now). 64-bit -# targets assemble UInt64s and convert to Float64 on the way in. +# ≤32-bit float targets pass UInt32s to boxmuller directly. 64-bit +# targets use UInt64s for finer sampling. for T in (Float16, Float32) @eval @inline function philox_to_normals(::Type{$T}, a1, a2, a3, a4) n1, n2 = boxmuller($T, a1, a2) @@ -354,33 +485,29 @@ for T in (Float16, Float32) end @inline function philox_to_normals(::Type{Float64}, a1, a2, a3, a4) boxmuller(Float64, - u01(Float64, UInt64(a1) | UInt64(a2) << 32), - u01(Float64, UInt64(a3) | UInt64(a4) << 32)) + UInt64(a1) | UInt64(a2) << 32, + UInt64(a3) | UInt64(a4) << 32) end @inline philox_to_normals(::Type{Complex{Float32}}, a1, a2, a3, a4) = (boxmuller(Complex{Float32}, a1, a2), boxmuller(Complex{Float32}, a3, a4)) @inline philox_to_normals(::Type{Complex{Float64}}, a1, a2, a3, a4) = (boxmuller(Complex{Float64}, - u01(Float64, UInt64(a1) | UInt64(a2) << 32), - u01(Float64, UInt64(a3) | UInt64(a4) << 32)),) + UInt64(a1) | UInt64(a2) << 32, + UInt64(a3) | UInt64(a4) << 32),) -@kernel function randn_batched_kernel!(@Const(seed), @Const(counter), A::AbstractArray{T}) where T +@kernel function randn_batched_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T gid = @index(Global, Linear) N = vals_per_call(T) idx = N * gid len = length(A) if idx <= len - vals = philox_to_normals(T, philox4x32_10( - (gid % UInt32, UInt32(0), counter, UInt32(0)), - philox_key(seed))...) + vals = philox_to_normals(T, philox4x32_10(gid % UInt64, counter, seed)...) for j in 1:N @inbounds A[idx - N + j] = vals[j] end elseif idx - N < len - vals = philox_to_normals(T, philox4x32_10( - (gid % UInt32, UInt32(0), counter, UInt32(0)), - philox_key(seed))...) + vals = philox_to_normals(T, philox4x32_10(gid % UInt64, counter, seed)...) for j in 1:min(N, len - idx + N) @inbounds A[idx - N + j] = vals[j] end @@ -424,15 +551,18 @@ end @inline Random.randn(rng::ElementRNG, ::Type{Float32}) = first(boxmuller(Float32, rand(rng, UInt32), rand(rng, UInt32))) @inline Random.randn(rng::ElementRNG, ::Type{Float64}) = - first(boxmuller(Float64, u01(Float64, rand(rng, UInt64)), u01(Float64, rand(rng, UInt64)))) + first(boxmuller(Float64, rand(rng, UInt64), rand(rng, UInt64))) -@kernel function randn_generic_kernel!(@Const(seed), @Const(counter), A::AbstractArray{T}) where T +@kernel function randn_generic_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T gid = @index(Global, Linear) - if gid <= length(A) - subctr = Ref{UInt32}(0) - rng = ElementRNG(seed, counter, gid % UInt32, - Base.unsafe_convert(Ptr{UInt32}, subctr)) - @inbounds A[gid] = randn(rng, T) + len_A = length(A) + if gid <= len_A + subctr = Ref{UInt64}(gid%UInt64) + GC.@preserve subctr begin + rng = ElementRNG(seed, counter, len_A % UInt64, + Base.unsafe_convert(Ptr{UInt64}, subctr)) + @inbounds A[gid] = randn(rng, T) + end end end