Skip to content

Commit a977de2

Browse files
Make PassthroughRNG dispatch survive overlay method-table shadowing
PassthroughRNG previously defined only the three no-second-arg methods (rand, randexp, randn). On Julia 1.12+, GPU back ends like CUDA.jl install device-side overlay tables via Base.Experimental.@consistent_overlay. Julia's OverlayMethodTable.findall returns overlay matches *without consulting the base method table* whenever the overlay fully covers the signature, so an overlay method like CUDA.jl's `@device_override Random.randexp(rng::AbstractRNG)` shadows our specific `Random.randexp(::PassthroughRNG)` on the device. The override's body then runs with rng::PassthroughRNG and calls `Random.rand(rng, UInt52Raw())`. The stdlib Sampler chain for that bottoms out at `_rand52(r, rng_native_52(r)) → rand(r, UInt64)`; PassthroughRNG had no `rng_native_52` and no typed-arg rand, so the chain statically reached `throw(MethodError, ...)`, which GPUCompiler refuses to lower (see SciML/JumpProcesses.jl#588 for the original repro). Add minimal forwarding methods so the chain still reaches bare rand(T): Random.rng_native_52(::PassthroughRNG) = UInt64 Random.rand(rng::PassthroughRNG, ::Type{T}) where {T} = rand(T) These keep PassthroughRNG's "use whatever default_rng() returns here" semantics — bare rand(T) goes through default_rng(), which GPU back ends device-override to their device RNG (Philox2x32 in CUDA.jl). Verified on Julia 1.12.6 that rand(PassthroughRNG(), UInt52Raw()), rand(PassthroughRNG(), UInt64), etc. all resolve cleanly after this change. Bump to 0.4.8. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent c401dab commit a977de2

3 files changed

Lines changed: 26 additions & 1 deletion

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "PoissonRandom"
22
uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
3-
version = "0.4.7"
3+
version = "0.4.8"
44

55
[deps]
66
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"

src/PoissonRandom.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ Random.rand(rng::PassthroughRNG) = rand()
1313
Random.randexp(rng::PassthroughRNG) = randexp()
1414
Random.randn(rng::PassthroughRNG) = randn()
1515

16+
# When an overlay method table (e.g. CUDA.jl's `@device_override
17+
# Random.randexp(::AbstractRNG)`) shadows the methods above, the overlay body
18+
# runs with rng::PassthroughRNG and may call `Random.rand(rng, UInt52Raw())`
19+
# or `Random.rand(rng, T)`. The stdlib Sampler chain bottoms out at
20+
# `_rand52(r, rng_native_52(r))` → `rand(r, UInt64)`; provide those so the
21+
# chain still reaches bare rand(T) and the device-side default_rng path.
22+
Random.rng_native_52(::PassthroughRNG) = UInt64
23+
Random.rand(rng::PassthroughRNG, ::Type{T}) where {T} = rand(T)
24+
1625
count_rand::Real) = count_rand(Random.default_rng(), λ)
1726
function count_rand(rng::AbstractRNG, λ::Real)
1827
n = 0

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,22 @@ end
116116
end
117117
end
118118

119+
@testset "PassthroughRNG dispatch" begin
120+
using Random: Random, UInt52Raw
121+
prng = PassthroughRNG()
122+
# The CUDA.jl @device_override Random.randexp(::AbstractRNG) shadows our
123+
# specific Random.randexp(::PassthroughRNG) on the GPU because Julia's
124+
# OverlayMethodTable returns overlay matches without consulting the base
125+
# table when the overlay fully covers the signature. The override body
126+
# then calls these against PassthroughRNG; if they MethodError, kernel
127+
# compilation fails with InvalidIRError on jl_f_throw_methoderror.
128+
@test Random.rng_native_52(prng) === UInt64
129+
@test Random.rand(prng, UInt52Raw()) isa UInt64
130+
@test Random.rand(prng, UInt64) isa UInt64
131+
@test Random.rand(prng, Float32) isa Float32
132+
@test Random.rand(prng, Float64) isa Float64
133+
end
134+
119135
if get(ENV, "GROUP", "all") == "all" || get(ENV, "GROUP", "all") == "nopre"
120136
@testset "Allocation Tests" begin
121137
include("alloc_tests.jl")

0 commit comments

Comments
 (0)