Skip to content

Commit 5cf39dd

Browse files
committed
cleanup
1 parent 68d9f66 commit 5cf39dd

4 files changed

Lines changed: 101 additions & 92 deletions

File tree

ext/MicrofloatsExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@ ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) = ct.F8E5M2(
1515
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E8M0FNU}) = ct.F8E8M0FNU(table)
1616
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float4_E2M1FN}) = ct.F4E2M1FN(table)
1717

18-
# Microfloats are byte-storage primitives (`sizeof == 1`), so cuTile's default
19-
# `bitwidth` (8 * sizeof) over-counts the sub-byte formats. Forward to
20-
# `Microfloats.bitwidth`, which derives the true width from the format's bit
21-
# fields (e.g. `Float4_E2M1FN` → 4), so whole-tile `reinterpret` packs/unpacks
22-
# them through `UInt8` correctly.
18+
# Microfloats are byte-storage primitives, so cuTile's default
19+
# `bitwidth` (8 * sizeof) over-counts the sub-byte formats.
2320
ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth(T)
2421

2522
# E8M0FNU has no sign bit and represents a power of two; tileiras rejects

src/compiler/intrinsics/conversions.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,6 @@ end
5656
@inline lookup_bitwidth(@nospecialize(T::Type)) =
5757
Base.invokelatest(bitwidth, T)::Int
5858

59-
# ── Low-level pack / unpack intrinsics (1:1 with cuda_tile.pack / unpack) ──────
60-
#
61-
# Both Tile IR ops are rank-1 → rank-1 and reinterpret the *whole tile* as a byte
62-
# array (not element-wise like bitcast). They are the primitives that the
63-
# whole-tile `Base.reinterpret` below composes (with `reshape`/`bitcast`) into a
64-
# Julia-semantics reinterpret of any rank. The 8-bit element types are handled by
65-
# `bitcast`, never pack/unpack — matching cutile-python's `pack_to_bytes` /
66-
# `unpack_from_bytes`.
67-
6859
"""
6960
Intrinsics.pack(x::Tile{S,Tuple{N}}) -> Tile{UInt8,Tuple{N*bitwidth(S)÷8}}
7061
@@ -161,17 +152,6 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
161152
CGVal(result_v, result_type_id, Tile{target_type, Tuple{new_n}}, new_shape)
162153
end
163154

164-
# ── Whole-tile reinterpret (Julia semantics, any rank) ────────────────────────
165-
#
166-
# A tile's row-major byte stream equals the column-major stream of its Julia
167-
# shape (we store the reversed shape), so flatten → width-convert → reshape
168-
# reproduces `Base.reinterpret` element-for-element. Equal widths bitcast;
169-
# crossing 8 bits goes through `pack`/`unpack`; a non-byte → non-byte change
170-
# routes through bytes (pack then unpack). All the shape arithmetic is
171-
# compile-time constant, so the intermediate `reshape`s fold away (identity
172-
# reshapes are eliminated by the canonicalizer) — a rank-1 FP4 reinterpret lowers
173-
# to a single pack/unpack.
174-
175155
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
176156
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
177157
bs = bitwidth(S)

test/codegen/operations.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,15 +1181,36 @@ end
11811181
end
11821182

11831183
# pack/unpack require v13.3 — older bytecode rejects with a clear error.
1184-
let kernel = (a, b) -> begin
1184+
# (`literal` since the `+` in the message is a regex metachar to FileCheck.)
1185+
@test @filecheck throws=ct.IRError begin
1186+
@check literal=true "v13.3+"
1187+
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
1188+
bytecode_version=v"13.2") do a, b
11851189
pid = ct.bid(1)
11861190
tile = ct.load(a, pid, (16,))
11871191
ct.store(b, pid, reinterpret(UInt16, tile))
11881192
return
11891193
end
1190-
@test_throws "v13.3+" code_tiled(devnull, kernel,
1191-
Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
1192-
bytecode_version=v"13.2")
1194+
end
1195+
1196+
# Rank-1 scaled: one UInt8 (8 bits) can't fill a UInt16; caught by unpack.
1197+
@test @filecheck throws=ct.IRError begin
1198+
@check "do not evenly divide"
1199+
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do a, b
1200+
pid = ct.bid(1)
1201+
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (1,))))
1202+
return
1203+
end
1204+
end
1205+
1206+
# reshape-widen: leading dim must equal the ratio (2); 1 fails the final reshape.
1207+
@test @filecheck throws=ct.IRError begin
1208+
@check "same number of elements"
1209+
code_tiled(Tuple{ct.TileArray{UInt8,2,spec2d}, ct.TileArray{UInt16,2,spec2d}}) do a, b
1210+
pid = ct.bid(1)
1211+
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (1, 4))))
1212+
return
1213+
end
11931214
end
11941215
end
11951216

test/device/tile.jl

Lines changed: 74 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,78 +1164,89 @@ end
11641164
end
11651165
end
11661166

1167-
# Whole-tile `reinterpret` (cuda_tile.pack/unpack, 13.3+) must match Julia's own
1168-
# `reinterpret` element-for-element — not merely round-trip. An asymmetric check
1169-
# (kernel result vs host `reinterpret`) is what actually pins down the byte order
1170-
# and the column-major dimension scaling; a symmetric X→Y→X round-trip would pass
1171-
# under any self-consistent convention.
11721167
@testset "reinterpret matches Base.reinterpret" begin
1173-
# 2-D widen→narrow: UInt16 (2,4) → UInt8 (4,4). Exercises both the column-major
1174-
# leading-dim scaling and the within-element (little-endian) byte order.
1175-
function u16_to_u8(a::ct.TileArray{UInt16,2}, b::ct.TileArray{UInt8,2})
1176-
pid = ct.bid(1)
1177-
ct.store(b, pid, reinterpret(UInt8, ct.load(a, pid, (2, 4))))
1178-
return
1179-
end
1180-
let M = reshape(UInt16[0x0102, 0x0304, 0x0506, 0x0708,
1181-
0x090a, 0x0b0c, 0x0d0e, 0x0f10], 2, 4)
1182-
a = CuArray(M)
1183-
b = CUDA.zeros(UInt8, 4, 4)
1184-
@cuda backend=cuTile blocks=1 u16_to_u8(a, b)
1185-
@test Array(b) == Array(reinterpret(UInt8, M))
1168+
@testset "2D narrowing" begin
1169+
function u16_to_u8(a::ct.TileArray{UInt16,2}, b::ct.TileArray{UInt8,2})
1170+
pid = ct.bid(1)
1171+
ct.store(b, pid, reinterpret(UInt8, ct.load(a, pid, (2, 4))))
1172+
return
1173+
end
1174+
let M = reshape(UInt16[0x0102, 0x0304, 0x0506, 0x0708,
1175+
0x090a, 0x0b0c, 0x0d0e, 0x0f10], 2, 4)
1176+
a = CuArray(M)
1177+
b = CUDA.zeros(UInt8, 4, 4)
1178+
@cuda backend=cuTile blocks=1 u16_to_u8(a, b)
1179+
@test Array(b) == Array(reinterpret(UInt8, M))
1180+
end
11861181
end
11871182

1188-
# 1-D narrow→widen the other direction: UInt8 (8,) → UInt16 (4,).
1189-
function u8_to_u16(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt16,1})
1190-
pid = ct.bid(1)
1191-
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (8,))))
1192-
return
1193-
end
1194-
let v = UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
1195-
a = CuArray(v)
1196-
b = CUDA.zeros(UInt16, 4)
1197-
@cuda backend=cuTile blocks=1 u8_to_u16(a, b)
1198-
@test Array(b) == reinterpret(UInt16, v)
1183+
@testset "1D widening" begin
1184+
function u8_to_u16(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt16,1})
1185+
pid = ct.bid(1)
1186+
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (8,))))
1187+
return
1188+
end
1189+
let v = UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
1190+
a = CuArray(v)
1191+
b = CUDA.zeros(UInt16, 4)
1192+
@cuda backend=cuTile blocks=1 u8_to_u16(a, b)
1193+
@test Array(b) == reinterpret(UInt16, v)
1194+
end
11991195
end
12001196

1201-
# `reshape`-form: widening drops the leading dim. UInt8 (2,4) → UInt16 (4,).
1202-
function u8_reshape_u16(a::ct.TileArray{UInt8,2}, b::ct.TileArray{UInt16,1})
1203-
pid = ct.bid(1)
1204-
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (2, 4))))
1205-
return
1206-
end
1207-
let M = reshape(UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], 2, 4)
1208-
a = CuArray(M)
1209-
b = CUDA.zeros(UInt16, 4)
1210-
@cuda backend=cuTile blocks=1 u8_reshape_u16(a, b)
1211-
@test Array(b) == reinterpret(reshape, UInt16, M)
1197+
@testset "narrowing: reshape argument drops dim" begin
1198+
function u8_reshape_u16(a::ct.TileArray{UInt8,2}, b::ct.TileArray{UInt16,1})
1199+
pid = ct.bid(1)
1200+
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (2, 4))))
1201+
return
1202+
end
1203+
let M = reshape(UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], 2, 4)
1204+
a = CuArray(M)
1205+
b = CUDA.zeros(UInt16, 4)
1206+
@cuda backend=cuTile blocks=1 u8_reshape_u16(a, b)
1207+
@test Array(b) == reinterpret(reshape, UInt16, M)
1208+
end
12121209
end
12131210

1214-
# Equal-width route lowers to `bitcast` (shape preserved), not pack/unpack.
1215-
# UInt32 → Float32 is a real bitcast (distinct Tile IR dtypes i32 vs f32).
1216-
function u32_to_f32(a::ct.TileArray{UInt32,1}, b::ct.TileArray{Float32,1})
1217-
pid = ct.bid(1)
1218-
ct.store(b, pid, reinterpret(Float32, ct.load(a, pid, (16,))))
1219-
return
1220-
end
1221-
let v = rand(UInt32, 16)
1222-
a = CuArray(v)
1223-
b = CUDA.zeros(Float32, 16)
1224-
@cuda backend=cuTile blocks=1 u32_to_f32(a, b)
1225-
@test reinterpret(UInt32, Array(b)) == v # bit-exact (avoids NaN ≠ NaN)
1211+
@testset "widening: reshape argument inserts dim" begin
1212+
function u16_reshape_u8(a::ct.TileArray{UInt16,1}, b::ct.TileArray{UInt8,2})
1213+
pid = ct.bid(1)
1214+
ct.store(b, pid, reinterpret(reshape, UInt8, ct.load(a, pid, (4,))))
1215+
return
1216+
end
1217+
let M = UInt16[0x0201, 0x0403, 0x0605, 0x0807]
1218+
a = CuArray(M)
1219+
b = CUDA.zeros(UInt8, 2, 4)
1220+
@cuda backend=cuTile blocks=1 u16_reshape_u8(a, b)
1221+
@test Array(b) == reinterpret(reshape, UInt8, M)
1222+
end
12261223
end
12271224

1228-
# Signless integer no-op (Int32 ↔ UInt32 are both i32): emits no op, but the
1229-
# result must still equal Julia's reinterpret, with the 2-D shape preserved.
1230-
function i32_to_u32_2d(a::ct.TileArray{Int32,2}, b::ct.TileArray{UInt32,2})
1231-
pid = ct.bid(1)
1232-
ct.store(b, pid, reinterpret(UInt32, ct.load(a, pid, (4, 4))))
1233-
return
1225+
@testset "Equal-with round-trip preserves values and shape" begin
1226+
function u32_to_f32(a::ct.TileArray{UInt32,1}, b::ct.TileArray{Float32,1})
1227+
pid = ct.bid(1)
1228+
ct.store(b, pid, reinterpret(Float32, ct.load(a, pid, (16,))))
1229+
return
1230+
end
1231+
let v = rand(UInt32, 16)
1232+
a = CuArray(v)
1233+
b = CUDA.zeros(Float32, 16)
1234+
@cuda backend=cuTile blocks=1 u32_to_f32(a, b)
1235+
@test reinterpret(UInt32, Array(b)) == v # bit-exact (avoids NaN ≠ NaN)
1236+
end
12341237
end
1235-
let M = reshape(Int32.(-8:7), 4, 4)
1236-
a = CuArray(M)
1237-
b = CUDA.zeros(UInt32, 4, 4)
1238-
@cuda backend=cuTile blocks=1 i32_to_u32_2d(a, b)
1239-
@test Array(b) == reinterpret(UInt32, M)
1238+
1239+
@testset "Int32 to UInt32" begin
1240+
function i32_to_u32_2d(a::ct.TileArray{Int32,2}, b::ct.TileArray{UInt32,2})
1241+
pid = ct.bid(1)
1242+
ct.store(b, pid, reinterpret(UInt32, ct.load(a, pid, (4, 4))))
1243+
return
1244+
end
1245+
let M = reshape(Int32.(-8:7), 4, 4)
1246+
a = CuArray(M)
1247+
b = CUDA.zeros(UInt32, 4, 4)
1248+
@cuda backend=cuTile blocks=1 i32_to_u32_2d(a, b)
1249+
@test Array(b) == reinterpret(UInt32, M)
1250+
end
12401251
end
12411252
end

0 commit comments

Comments
 (0)