Skip to content

Commit 68e03c3

Browse files
committed
add support for Float16, UInt8, UInt16, Int8, Int16 (was bored)
1 parent 48cd247 commit 68e03c3

5 files changed

Lines changed: 70 additions & 21 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AcceleratedKernels"
22
uuid = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
3-
version = "0.4.3"
43
authors = ["Andrei-Leonard Nicusan <leonard@evophase.co.uk> and contributors"]
4+
version = "0.4.3"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

docs/src/api/rand.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ Counter-based random generation for CPU and GPU backends with deterministic beha
55

66
Use an explicit `CounterRNG(seed; alg=...)` when reproducibility matters. For convenience,
77
`AK.rand!(x)` creates a fresh `CounterRNG()` on each call using one auto-seeded
8-
`Base.rand(Random.default_rng(), UInt64)` draw, so repeated calls produce different outputs unless Random.seed!() is used.
8+
`Random.rand(Random.default_rng(), UInt64)` draw, so repeated calls produce different outputs unless Random.seed!() is used.
99

10-
Supported output element types:
11-
- `UInt32`, `UInt64`
12-
- `Int32`, `Int64`
13-
- `Float32`, `Float64`
10+
Supported element types:
11+
- `UInt8`, `UInt16`, `UInt32`, `UInt64`
12+
- `Int8`, `Int16`, `Int32`, `Int64`
13+
- `Float16`, `Float32`, `Float64`
1414
- `Bool`
1515

16-
The core of the random number generation produces a `UInt` of the requested scalar width.
16+
The core of the random number generation produces either a `UInt32` or `UInt64` depending on the width of the requested element type.
1717
That `UInt` is then either:
18-
- Unsigned integers: returned as-is
19-
- Signed integers: reinterpreted as a signed integer bit pattern.
18+
- Unsigned integers: returned as-is or truncated if necessary.
19+
- Signed integers: reinterpreted as a signed integer bit pattern and truncated if necessary.
2020
- Floats: mantissa construction into a uniform grid in `[0, 1)` ([read more](https://lomont.org/posts/2017/unit-random/)).
2121
- Bool: `true` if the `UInt` draw is odd (`isodd(u)`), otherwise `false`.
2222

src/rand/rand.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ Fill `x` in-place with pseudo-random values using a stateless counter-based RNG.
8282
counter is exactly `UInt64(i - 1)` in linear indexing order.
8383
8484
Supported scalar element types are:
85-
- `UInt32`, `UInt64`
86-
- `Int32`, `Int64`
87-
- `Float32`, `Float64`
85+
- `UInt8`, `UInt16`, `UInt32`, `UInt64`
86+
- `Int8`, `Int16`, `Int32`, `Int64`
87+
- `Float16`, `Float32`, `Float64`
8888
- `Bool`
8989
9090
Semantics:

src/rand/utilities.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,36 @@
1515

1616
# Internal scalar eltypes currently supported by rand!.
1717
const ALLOWED_RAND_SCALARS = Union{
18-
UInt32, UInt64,
19-
Int32, Int64,
20-
Float32, Float64,
18+
UInt8, UInt16, UInt32, UInt64,
19+
Int8, Int16, Int32, Int64,
20+
Float16, Float32, Float64,
2121
Bool
2222
}
2323

2424

25+
@inline raw_uint_type(::Type{UInt8}) = UInt32
26+
@inline raw_uint_type(::Type{UInt16}) = UInt32
2527
@inline raw_uint_type(::Type{UInt32}) = UInt32
28+
@inline raw_uint_type(::Type{Int8}) = UInt32
29+
@inline raw_uint_type(::Type{Int16}) = UInt32
2630
@inline raw_uint_type(::Type{Int32}) = UInt32
31+
@inline raw_uint_type(::Type{Float16}) = UInt32
2732
@inline raw_uint_type(::Type{Float32}) = UInt32
2833
@inline raw_uint_type(::Type{UInt64}) = UInt64
2934
@inline raw_uint_type(::Type{Int64}) = UInt64
3035
@inline raw_uint_type(::Type{Float64}) = UInt64
3136
@inline raw_uint_type(::Type{Bool}) = UInt32
3237

3338

39+
@inline from_uint(::Type{UInt8}, u::UInt32)::UInt8 = trunc(UInt8, u >> 24)
40+
@inline from_uint(::Type{UInt16}, u::UInt32)::UInt16 = trunc(UInt16, u >> 16)
3441
@inline from_uint(::Type{UInt32}, u::UInt32)::UInt32 = u
3542
@inline from_uint(::Type{UInt64}, u::UInt64)::UInt64 = u
43+
@inline from_uint(::Type{Int8}, u::UInt32)::Int8 = reinterpret(Int8, trunc(UInt8, u >> 24))
44+
@inline from_uint(::Type{Int16}, u::UInt32)::Int16 = reinterpret(Int16, trunc(UInt16, u >> 16))
3645
@inline from_uint(::Type{Int32}, u::UInt32)::Int32 = reinterpret(Int32, u)
3746
@inline from_uint(::Type{Int64}, u::UInt64)::Int64 = reinterpret(Int64, u)
47+
@inline from_uint(::Type{Float16}, u::UInt32)::Float16 = uint32_to_unit_float16(u)
3848
@inline from_uint(::Type{Float32}, u::UInt32)::Float32 = uint32_to_unit_float32(u)
3949
@inline from_uint(::Type{Float64}, u::UInt64)::Float64 = uint64_to_unit_float64(u)
4050
@inline from_uint(::Type{Bool}, u::UInt32)::Bool = isodd(u)
@@ -81,6 +91,17 @@ end
8191

8292

8393

94+
# Convert random UInt32 bits to Float16 in [0, 1) by mantissa construction.
95+
@inline function uint32_to_unit_float16(u::UInt32)::Float16
96+
# Keep 10 random bits for the mantissa (drop 22 rightmost bits from the UInt32)
97+
# and combine with the bit pattern of Float16(1.0) (sign=0, exponent=15).
98+
bits = UInt16(0x3c00) | UInt16(u >> 22)
99+
100+
# Interpret as 1.mantissa, then subtract 1 for [0, 1)
101+
return reinterpret(Float16, bits) - Float16(1)
102+
end
103+
104+
84105
# Convert random UInt32 bits to Float32 in [0, 1) by mantissa construction.
85106
@inline function uint32_to_unit_float32(u::UInt32)::Float32
86107
# Keep 23 random bits for the mantissa (drop 9 rightmost bits from the UInt32)

test/rand.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
const RAND_ALGS = (AK.SplitMix64(), AK.Philox(), AK.Threefry())
2-
const RAND_SCALAR_TYPES_ALL = (UInt32, UInt64, Int32, Int64, Float32, Float64, Bool)
2+
const RAND_SCALAR_TYPES_ALL = (
3+
UInt8, UInt16, UInt32, UInt64,
4+
Int8, Int16, Int32, Int64,
5+
Float16, Float32, Float64,
6+
Bool,
7+
)
38
const RAND_SCALAR_TYPES_BACKEND = IS_CPU_BACKEND ?
49
RAND_SCALAR_TYPES_ALL :
5-
(UInt32, UInt64, Int32, Int64, Float32, Bool)
10+
(UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float16, Float32, Bool)
611
const RUN_FLOAT64_RAND_TESTS = IS_CPU_BACKEND
712

813

@@ -62,8 +67,13 @@ end
6267
@test AK._counter_from_index(1) == UInt64(0)
6368
@test AK._counter_from_index(17) == UInt64(16)
6469

70+
@test AK.raw_uint_type(UInt8) === UInt32
71+
@test AK.raw_uint_type(UInt16) === UInt32
6572
@test AK.raw_uint_type(UInt32) === UInt32
73+
@test AK.raw_uint_type(Int8) === UInt32
74+
@test AK.raw_uint_type(Int16) === UInt32
6675
@test AK.raw_uint_type(Int32) === UInt32
76+
@test AK.raw_uint_type(Float16) === UInt32
6777
@test AK.raw_uint_type(Float32) === UInt32
6878
@test AK.raw_uint_type(UInt64) === UInt64
6979
@test AK.raw_uint_type(Int64) === UInt64
@@ -72,15 +82,22 @@ end
7282
@test AK.raw_uint_type(Float64) === UInt64
7383
end
7484

85+
@test AK.from_uint(UInt8, UInt32(0xabcdef01)) == UInt8(0xab)
86+
@test AK.from_uint(UInt16, UInt32(0xabcdef01)) == UInt16(0xabcd)
7587
@test AK.from_uint(UInt32, 0b1010 % UInt32) == 0b1010 % UInt32
7688
@test AK.from_uint(UInt64, 0b1010 % UInt64) == 0b1010 % UInt64
89+
@test AK.from_uint(Int8, UInt32(0xff000000)) == Int8(-1)
90+
@test AK.from_uint(Int16, UInt32(0xffff0000)) == Int16(-1)
7791
@test AK.from_uint(Int32, 0b11111111111111111111111111111111 % UInt32) == Int32(-1)
7892
@test AK.from_uint(
7993
Int64, 0b1111111111111111111111111111111111111111111111111111111111111111 % UInt64
8094
) == Int64(-1)
95+
@test AK.from_uint(Float16, UInt32(0)) == Float16(0)
8196
@test AK.from_uint(Bool, UInt32(0)) == false
8297
@test AK.from_uint(Bool, UInt32(1)) == true
8398

99+
@test AK.uint32_to_unit_float16(UInt32(0)) == Float16(0)
100+
@test Float16(0) <= AK.uint32_to_unit_float16(typemax(UInt32)) < Float16(1)
84101
@test AK.uint32_to_unit_float32(UInt32(0)) == 0.0f0
85102
@test 0.0f0 <= AK.uint32_to_unit_float32(typemax(UInt32)) < 1.0f0
86103
if RUN_FLOAT64_RAND_TESTS
@@ -126,7 +143,7 @@ end
126143
s1 = AK.rand_scalar(rng, UInt64(1), T)
127144
@test s0 isa T
128145
@test s1 isa T
129-
if T !== Bool
146+
if !(T in (Bool, Float16, UInt8, UInt16, Int8, Int16))
130147
@test s0 != s1
131148
end
132149
if T <: AbstractFloat
@@ -136,8 +153,19 @@ end
136153
end
137154

138155
c = UInt64(42)
156+
@test AK.rand_scalar(rng, c, UInt8) == trunc(UInt8, AK.rand_uint(rng, c, UInt32) >> 24)
157+
@test AK.rand_scalar(rng, c, UInt16) == trunc(UInt16, AK.rand_uint(rng, c, UInt32) >> 16)
158+
@test AK.rand_scalar(
159+
rng, c, Int8
160+
) == reinterpret(Int8, trunc(UInt8, AK.rand_uint(rng, c, UInt32) >> 24))
161+
@test AK.rand_scalar(
162+
rng, c, Int16
163+
) == reinterpret(Int16, trunc(UInt16, AK.rand_uint(rng, c, UInt32) >> 16))
139164
@test AK.rand_scalar(rng, c, Int32) == reinterpret(Int32, AK.rand_uint(rng, c, UInt32))
140165
@test AK.rand_scalar(rng, c, Int64) == reinterpret(Int64, AK.rand_uint(rng, c, UInt64))
166+
@test AK.rand_scalar(rng, c, Float16) == AK.uint32_to_unit_float16(
167+
AK.rand_uint(rng, c, UInt32)
168+
)
141169
@test AK.rand_scalar(rng, c, Float32) == AK.uint32_to_unit_float32(
142170
AK.rand_uint(rng, c, UInt32)
143171
)
@@ -150,7 +178,7 @@ end
150178
bools = [AK.rand_scalar(rng, UInt64(i), Bool) for i in 0:511]
151179
@test any(identity, bools)
152180
@test any(!, bools)
153-
@test_throws ArgumentError AK.rand_scalar(rng, UInt64(0), UInt16)
181+
@test_throws ArgumentError AK.rand_scalar(rng, UInt64(0), UInt128)
154182
end
155183

156184

@@ -185,7 +213,7 @@ end
185213
@test Array(x1) != Array(x2)
186214
end
187215

188-
for T in (Float32, UInt64, Bool)
216+
for T in (Float16, Float32, UInt64, Bool)
189217
xnd = array_from_host(zeros(T, 7, 11, 5))
190218
_assert_rand_matches_reference!(rng, xnd; prefer_threads, block_size=128)
191219
end
@@ -226,7 +254,7 @@ end
226254
@test Array(x1) == Array(ref1)
227255
@test Array(x2) == Array(ref2)
228256

229-
x_bad = array_from_host(zeros(UInt16, 16))
257+
x_bad = zeros(UInt128, 16)
230258
@test_throws ArgumentError AK.rand!(x_bad; prefer_threads)
231259
@test_throws ArgumentError AK.rand!(AK.CounterRNG(0x1), x_bad; prefer_threads)
232260
end

0 commit comments

Comments
 (0)