Skip to content

Commit 98aee58

Browse files
Merge pull request #72 from ChrisRackauckas-Claude/fix-passthroughrng-overlay-shadowing
Make PassthroughRNG dispatch survive overlay method-table shadowing
2 parents c401dab + a977de2 commit 98aee58

3 files changed

Lines changed: 26 additions & 1 deletion

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "PoissonRandom"
22
uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
3-
version = "0.4.7"
3+
version = "0.4.8"
44

55
[deps]
66
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"

src/PoissonRandom.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ Random.rand(rng::PassthroughRNG) = rand()
1313
Random.randexp(rng::PassthroughRNG) = randexp()
1414
Random.randn(rng::PassthroughRNG) = randn()
1515

16+
# When an overlay method table (e.g. CUDA.jl's `@device_override
17+
# Random.randexp(::AbstractRNG)`) shadows the methods above, the overlay body
18+
# runs with rng::PassthroughRNG and may call `Random.rand(rng, UInt52Raw())`
19+
# or `Random.rand(rng, T)`. The stdlib Sampler chain bottoms out at
20+
# `_rand52(r, rng_native_52(r))` → `rand(r, UInt64)`; provide those so the
21+
# chain still reaches bare rand(T) and the device-side default_rng path.
22+
Random.rng_native_52(::PassthroughRNG) = UInt64
23+
Random.rand(rng::PassthroughRNG, ::Type{T}) where {T} = rand(T)
24+
1625
count_rand::Real) = count_rand(Random.default_rng(), λ)
1726
function count_rand(rng::AbstractRNG, λ::Real)
1827
n = 0

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,22 @@ end
116116
end
117117
end
118118

119+
@testset "PassthroughRNG dispatch" begin
120+
using Random: Random, UInt52Raw
121+
prng = PassthroughRNG()
122+
# The CUDA.jl @device_override Random.randexp(::AbstractRNG) shadows our
123+
# specific Random.randexp(::PassthroughRNG) on the GPU because Julia's
124+
# OverlayMethodTable returns overlay matches without consulting the base
125+
# table when the overlay fully covers the signature. The override body
126+
# then calls these against PassthroughRNG; if they MethodError, kernel
127+
# compilation fails with InvalidIRError on jl_f_throw_methoderror.
128+
@test Random.rng_native_52(prng) === UInt64
129+
@test Random.rand(prng, UInt52Raw()) isa UInt64
130+
@test Random.rand(prng, UInt64) isa UInt64
131+
@test Random.rand(prng, Float32) isa Float32
132+
@test Random.rand(prng, Float64) isa Float64
133+
end
134+
119135
if get(ENV, "GROUP", "all") == "all" || get(ENV, "GROUP", "all") == "nopre"
120136
@testset "Allocation Tests" begin
121137
include("alloc_tests.jl")

0 commit comments

Comments
 (0)