Skip to content

Commit dfed4a4

Browse files
committed
cleanup
1 parent d5e4cf5 commit dfed4a4

4 files changed

Lines changed: 101 additions & 92 deletions

File tree

ext/MicrofloatsExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@ ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) = ct.F8E5M2(
1515
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E8M0FNU}) = ct.F8E8M0FNU(table)
1616
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float4_E2M1FN}) = ct.F4E2M1FN(table)
1717

18-
# Microfloats are byte-storage primitives (`sizeof == 1`), so cuTile's default
19-
# `bitwidth` (8 * sizeof) over-counts the sub-byte formats. Forward to
20-
# `Microfloats.bitwidth`, which derives the true width from the format's bit
21-
# fields (e.g. `Float4_E2M1FN` → 4), so whole-tile `reinterpret` packs/unpacks
22-
# them through `UInt8` correctly.
18+
# Microfloats are byte-storage primitives, so cuTile's default
19+
# `bitwidth` (8 * sizeof) over-counts the sub-byte formats.
2320
ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth(T)
2421

2522
# E8M0FNU has no sign bit and represents a power of two; tileiras rejects

src/compiler/intrinsics/conversions.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,6 @@ end
5656
@inline lookup_bitwidth(@nospecialize(T::Type)) =
5757
Base.invokelatest(bitwidth, T)::Int
5858

59-
# ── Low-level pack / unpack intrinsics (1:1 with cuda_tile.pack / unpack) ──────
60-
#
61-
# Both Tile IR ops are rank-1 → rank-1 and reinterpret the *whole tile* as a byte
62-
# array (not element-wise like bitcast). They are the primitives that the
63-
# whole-tile `Base.reinterpret` below composes (with `reshape`/`bitcast`) into a
64-
# Julia-semantics reinterpret of any rank. The 8-bit element types are handled by
65-
# `bitcast`, never pack/unpack — matching cutile-python's `pack_to_bytes` /
66-
# `unpack_from_bytes`.
67-
6859
"""
6960
Intrinsics.pack(x::Tile{S,Tuple{N}}) -> Tile{UInt8,Tuple{N*bitwidth(S)÷8}}
7061
@@ -161,17 +152,6 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
161152
CGVal(result_v, result_type_id, Tile{target_type, Tuple{new_n}}, new_shape)
162153
end
163154

164-
# ── Whole-tile reinterpret (Julia semantics, any rank) ────────────────────────
165-
#
166-
# A tile's row-major byte stream equals the column-major stream of its Julia
167-
# shape (we store the reversed shape), so flatten → width-convert → reshape
168-
# reproduces `Base.reinterpret` element-for-element. Equal widths bitcast;
169-
# crossing 8 bits goes through `pack`/`unpack`; a non-byte → non-byte change
170-
# routes through bytes (pack then unpack). All the shape arithmetic is
171-
# compile-time constant, so the intermediate `reshape`s fold away (identity
172-
# reshapes are eliminated by the canonicalizer) — a rank-1 FP4 reinterpret lowers
173-
# to a single pack/unpack.
174-
175155
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
176156
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
177157
bs = bitwidth(S)

test/codegen/operations.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,15 +1181,36 @@ end
11811181
end
11821182

11831183
# pack/unpack require v13.3 — older bytecode rejects with a clear error.
1184-
let kernel = (a, b) -> begin
1184+
# (`literal=true` because the `+` in the message is a regex metacharacter to FileCheck.)
1185+
@test @filecheck throws=ct.IRError begin
1186+
@check literal=true "v13.3+"
1187+
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
1188+
bytecode_version=v"13.2") do a, b
11851189
pid = ct.bid(1)
11861190
tile = ct.load(a, pid, (16,))
11871191
ct.store(b, pid, reinterpret(UInt16, tile))
11881192
return
11891193
end
1190-
@test_throws "v13.3+" code_tiled(devnull, kernel,
1191-
Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
1192-
bytecode_version=v"13.2")
1194+
end
1195+
1196+
# Rank-1 scaled: one UInt8 (8 bits) can't fill a UInt16; caught by unpack.
1197+
@test @filecheck throws=ct.IRError begin
1198+
@check "do not evenly divide"
1199+
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do a, b
1200+
pid = ct.bid(1)
1201+
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (1,))))
1202+
return
1203+
end
1204+
end
1205+
1206+
# reshape-widen: leading dim must equal the ratio (2); 1 fails the final reshape.
1207+
@test @filecheck throws=ct.IRError begin
1208+
@check "same number of elements"
1209+
code_tiled(Tuple{ct.TileArray{UInt8,2,spec2d}, ct.TileArray{UInt16,2,spec2d}}) do a, b
1210+
pid = ct.bid(1)
1211+
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (1, 4))))
1212+
return
1213+
end
11931214
end
11941215
end
11951216

test/device/tile.jl

Lines changed: 74 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,78 +1233,89 @@ end
12331233
end
12341234
end
12351235

1236-
# Whole-tile `reinterpret` (cuda_tile.pack/unpack, 13.3+) must match Julia's own
1237-
# `reinterpret` element-for-element — not merely round-trip. An asymmetric check
1238-
# (kernel result vs host `reinterpret`) is what actually pins down the byte order
1239-
# and the column-major dimension scaling; a symmetric X→Y→X round-trip would pass
1240-
# under any self-consistent convention.
12411236
@testset "reinterpret matches Base.reinterpret" begin
1242-
# 2-D widen→narrow: UInt16 (2,4) → UInt8 (4,4). Exercises both the column-major
1243-
# leading-dim scaling and the within-element (little-endian) byte order.
1244-
function u16_to_u8(a::ct.TileArray{UInt16,2}, b::ct.TileArray{UInt8,2})
1245-
pid = ct.bid(1)
1246-
ct.store(b, pid, reinterpret(UInt8, ct.load(a, pid, (2, 4))))
1247-
return
1248-
end
1249-
let M = reshape(UInt16[0x0102, 0x0304, 0x0506, 0x0708,
1250-
0x090a, 0x0b0c, 0x0d0e, 0x0f10], 2, 4)
1251-
a = CuArray(M)
1252-
b = CUDA.zeros(UInt8, 4, 4)
1253-
@cuda backend=cuTile blocks=1 u16_to_u8(a, b)
1254-
@test Array(b) == Array(reinterpret(UInt8, M))
1237+
@testset "2D narrowing" begin
1238+
function u16_to_u8(a::ct.TileArray{UInt16,2}, b::ct.TileArray{UInt8,2})
1239+
pid = ct.bid(1)
1240+
ct.store(b, pid, reinterpret(UInt8, ct.load(a, pid, (2, 4))))
1241+
return
1242+
end
1243+
let M = reshape(UInt16[0x0102, 0x0304, 0x0506, 0x0708,
1244+
0x090a, 0x0b0c, 0x0d0e, 0x0f10], 2, 4)
1245+
a = CuArray(M)
1246+
b = CUDA.zeros(UInt8, 4, 4)
1247+
@cuda backend=cuTile blocks=1 u16_to_u8(a, b)
1248+
@test Array(b) == Array(reinterpret(UInt8, M))
1249+
end
12551250
end
12561251

1257-
# 1-D narrow→widen the other direction: UInt8 (8,) → UInt16 (4,).
1258-
function u8_to_u16(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt16,1})
1259-
pid = ct.bid(1)
1260-
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (8,))))
1261-
return
1262-
end
1263-
let v = UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
1264-
a = CuArray(v)
1265-
b = CUDA.zeros(UInt16, 4)
1266-
@cuda backend=cuTile blocks=1 u8_to_u16(a, b)
1267-
@test Array(b) == reinterpret(UInt16, v)
1252+
@testset "1D widening" begin
1253+
function u8_to_u16(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt16,1})
1254+
pid = ct.bid(1)
1255+
ct.store(b, pid, reinterpret(UInt16, ct.load(a, pid, (8,))))
1256+
return
1257+
end
1258+
let v = UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
1259+
a = CuArray(v)
1260+
b = CUDA.zeros(UInt16, 4)
1261+
@cuda backend=cuTile blocks=1 u8_to_u16(a, b)
1262+
@test Array(b) == reinterpret(UInt16, v)
1263+
end
12681264
end
12691265

1270-
# `reshape`-form: widening drops the leading dim. UInt8 (2,4) → UInt16 (4,).
1271-
function u8_reshape_u16(a::ct.TileArray{UInt8,2}, b::ct.TileArray{UInt16,1})
1272-
pid = ct.bid(1)
1273-
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (2, 4))))
1274-
return
1275-
end
1276-
let M = reshape(UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], 2, 4)
1277-
a = CuArray(M)
1278-
b = CUDA.zeros(UInt16, 4)
1279-
@cuda backend=cuTile blocks=1 u8_reshape_u16(a, b)
1280-
@test Array(b) == reinterpret(reshape, UInt16, M)
1266+
@testset "widening: reshape argument drops dim" begin
1267+
function u8_reshape_u16(a::ct.TileArray{UInt8,2}, b::ct.TileArray{UInt16,1})
1268+
pid = ct.bid(1)
1269+
ct.store(b, pid, reinterpret(reshape, UInt16, ct.load(a, pid, (2, 4))))
1270+
return
1271+
end
1272+
let M = reshape(UInt8[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], 2, 4)
1273+
a = CuArray(M)
1274+
b = CUDA.zeros(UInt16, 4)
1275+
@cuda backend=cuTile blocks=1 u8_reshape_u16(a, b)
1276+
@test Array(b) == reinterpret(reshape, UInt16, M)
1277+
end
12811278
end
12821279

1283-
# Equal-width route lowers to `bitcast` (shape preserved), not pack/unpack.
1284-
# UInt32 → Float32 is a real bitcast (distinct Tile IR dtypes i32 vs f32).
1285-
function u32_to_f32(a::ct.TileArray{UInt32,1}, b::ct.TileArray{Float32,1})
1286-
pid = ct.bid(1)
1287-
ct.store(b, pid, reinterpret(Float32, ct.load(a, pid, (16,))))
1288-
return
1289-
end
1290-
let v = rand(UInt32, 16)
1291-
a = CuArray(v)
1292-
b = CUDA.zeros(Float32, 16)
1293-
@cuda backend=cuTile blocks=1 u32_to_f32(a, b)
1294-
@test reinterpret(UInt32, Array(b)) == v # bit-exact (avoids NaN ≠ NaN)
1280+
@testset "narrowing: reshape argument inserts dim" begin
1281+
function u16_reshape_u8(a::ct.TileArray{UInt16,1}, b::ct.TileArray{UInt8,2})
1282+
pid = ct.bid(1)
1283+
ct.store(b, pid, reinterpret(reshape, UInt8, ct.load(a, pid, (4,))))
1284+
return
1285+
end
1286+
let M = UInt16[0x0201, 0x0403, 0x0605, 0x0807]
1287+
a = CuArray(M)
1288+
b = CUDA.zeros(UInt8, 2, 4)
1289+
@cuda backend=cuTile blocks=1 u16_reshape_u8(a, b)
1290+
@test Array(b) == reinterpret(reshape, UInt8, M)
1291+
end
12951292
end
12961293

1297-
# Signless integer no-op (Int32 ↔ UInt32 are both i32): emits no op, but the
1298-
# result must still equal Julia's reinterpret, with the 2-D shape preserved.
1299-
function i32_to_u32_2d(a::ct.TileArray{Int32,2}, b::ct.TileArray{UInt32,2})
1300-
pid = ct.bid(1)
1301-
ct.store(b, pid, reinterpret(UInt32, ct.load(a, pid, (4, 4))))
1302-
return
1294+
@testset "Equal-width round-trip preserves values and shape" begin
1295+
function u32_to_f32(a::ct.TileArray{UInt32,1}, b::ct.TileArray{Float32,1})
1296+
pid = ct.bid(1)
1297+
ct.store(b, pid, reinterpret(Float32, ct.load(a, pid, (16,))))
1298+
return
1299+
end
1300+
let v = rand(UInt32, 16)
1301+
a = CuArray(v)
1302+
b = CUDA.zeros(Float32, 16)
1303+
@cuda backend=cuTile blocks=1 u32_to_f32(a, b)
1304+
@test reinterpret(UInt32, Array(b)) == v # bit-exact (avoids NaN ≠ NaN)
1305+
end
13031306
end
1304-
let M = reshape(Int32.(-8:7), 4, 4)
1305-
a = CuArray(M)
1306-
b = CUDA.zeros(UInt32, 4, 4)
1307-
@cuda backend=cuTile blocks=1 i32_to_u32_2d(a, b)
1308-
@test Array(b) == reinterpret(UInt32, M)
1307+
1308+
@testset "Int32 to UInt32" begin
1309+
function i32_to_u32_2d(a::ct.TileArray{Int32,2}, b::ct.TileArray{UInt32,2})
1310+
pid = ct.bid(1)
1311+
ct.store(b, pid, reinterpret(UInt32, ct.load(a, pid, (4, 4))))
1312+
return
1313+
end
1314+
let M = reshape(Int32.(-8:7), 4, 4)
1315+
a = CuArray(M)
1316+
b = CUDA.zeros(UInt32, 4, 4)
1317+
@cuda backend=cuTile blocks=1 i32_to_u32_2d(a, b)
1318+
@test Array(b) == reinterpret(UInt32, M)
1319+
end
13091320
end
13101321
end

0 commit comments

Comments
 (0)