Skip to content

Commit fc5ede9

Browse files
AntonOresten and claude
authored
Add reinterpret as interface to bitcast, pack, and unpack (#238)
* Add `reinterpret` as interface to `bitcast`, `pack`, and `unpack`

* reinterpret: centralize validation in emit so invalid casts fail cleanly

  The shape helpers and pack/unpack tfuncs ran inside the kernel-inferred path, where two failure modes produced confusing errors:

  - A tfunc returning `nothing` on an indivisible width left the result untypable, surfacing downstream as `internal error: invalid terminators`.
  - A `throw(ArgumentError(...))` in a shape helper became an unsupported `String` in kernel IR (`format_string`/`unsupported String` error), masking the intended message.

  Make both layers total: pack/unpack tfuncs always return a concrete type (via `fld`), and the shape helpers are pure arithmetic. Validation now lives solely in the pack/unpack/reshape emit, which throws a clear `IRError` (e.g. "unpack: 1 bytes do not evenly divide into Float32"). Valid reinterprets are unchanged.

  Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* cleanup

* Make bytes in 2D narrowing test ascending

* Move `Base.reinterpret` methods to operations.jl

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 1a3958e commit fc5ede9

8 files changed

Lines changed: 490 additions & 2 deletions

File tree

ext/MicrofloatsExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module MicrofloatsExt
22

33
import cuTile as ct
4+
import Microfloats
45

56
using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN
67

@@ -14,6 +15,10 @@ ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) = ct.F8E5M2(
1415
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E8M0FNU}) = ct.F8E8M0FNU(table)
1516
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float4_E2M1FN}) = ct.F4E2M1FN(table)
1617

# Every Microfloat is stored in whole bytes, so the generic cuTile fallback
# (`8 * sizeof(T)`) reports too many bits for the sub-byte formats. Forward to
# Microfloats' own `bitwidth` instead.
function ct.bitwidth(::Type{F}) where {F<:Microfloats.Microfloat}
    return Microfloats.bitwidth(F)
end
21+
1722
# E8M0FNU has no sign bit and represents a power of two; tileiras rejects
1823
# nearest-even on f32→E8M0FNU (only `zero` and `positive_inf` are valid).
1924
ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero

src/bytecode/encodings.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ module Opcode
9696
const XOrIOp = 108
9797
const YieldOp = 109
9898
const Atan2Op = 110 # since 13.2
99+
const PackOp = 111 # since 13.3
100+
const UnpackOp = 112 # since 13.3
99101
end
100102

101103
# Enums for operation attributes
@@ -1796,6 +1798,36 @@ function encode_BitcastOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
17961798
return new_op!(cb)
17971799
end
17981800

"""
    encode_PackOp!(cb, result_type, source) -> Value

Emit a `PackOp` (opcode 111, available since 13.3), which views a rank-1 numeric
tile as raw bytes and yields a rank-1 `tile<i8>`.

The operation is not element-wise the way `bitcast` is: the tile as a whole is
reinterpreted as a byte array, so the result has as many elements as the input
has bytes. 8-bit source element types are not allowed here; use `bitcast` for
those.
"""
function encode_PackOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
    buf = cb.buf
    encode_varint!(buf, Opcode.PackOp)   # opcode
    encode_typeid!(buf, result_type)     # type of the resulting byte tile
    encode_operand!(buf, source)         # tile being packed
    return new_op!(cb)
end
1815+
"""
    encode_UnpackOp!(cb, result_type, source) -> Value

Emit an `UnpackOp` (opcode 112, available since 13.3), the inverse of
[`encode_PackOp!`](@ref): a rank-1 `tile<i8>` is turned back into a rank-1
numeric tile.

The number of input bytes has to match the total byte count of the result.
8-bit result element types are not allowed here; use `bitcast` for those.
"""
function encode_UnpackOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
    buf = cb.buf
    encode_varint!(buf, Opcode.UnpackOp)   # opcode
    encode_typeid!(buf, result_type)       # type of the resulting numeric tile
    encode_operand!(buf, source)           # byte tile being unpacked
    return new_op!(cb)
end
1830+
17991831
"""
18001832
encode_BroadcastOp!(cb, result_type, source) -> Value
18011833

src/compiler/intrinsics/conversions.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,105 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.bitcast), args)
5353
CGVal(result_v, result_type_id, result_jltype, source.shape)
5454
end
5555

# Ask `bitwidth` for the width of `T` and assert that the answer is an `Int`.
# The call goes through `invokelatest` — presumably so `bitwidth` methods added
# later (e.g. by a package extension) are picked up; confirm against callers.
@inline function lookup_bitwidth(@nospecialize(T::Type))
    width = Base.invokelatest(bitwidth, T)
    return width::Int
end
58+
"""
    Intrinsics.pack(x::Tile{S,Tuple{N}}) -> Tile{UInt8,Tuple{N*bitwidth(S)÷8}}

Pack a rank-1 numeric tile into a rank-1 `UInt8` tile (the tile's bits viewed as
a byte array); lowers to `cuda_tile.pack`. `S` must not be 8-bit (use `bitcast`).
Requires Tile IR bytecode v13.3+.
"""
@intrinsic pack(x)

# Return-type function for `pack`.
#
# Yields `Tile{UInt8, Tuple{M}}` with `M = fld(N * bitwidth(S), 8)` for a fully
# parameterized rank-1 `Tile{S, Tuple{N}}`, and `nothing` whenever the argument
# type does not have that form. The result is deliberately computed with `fld`
# so that an indivisible width still gives a concrete type; rejecting such a
# cast is left to the `pack` emitter, which reports it as an `IRError`.
function tfunc(𝕃, ::typeof(Intrinsics.pack), @nospecialize(x))
    src = CC.widenconst(x)
    # `.parameters` is a field of `DataType` only. A `UnionAll` (e.g. a `Tile`
    # with a free shape parameter) or `Union{}` also satisfies `<: Tile`, but
    # reading `.parameters` from it throws, so bail out on those first.
    (src isa DataType && src <: Tile) || return nothing
    S = src.parameters[1]
    Shape = src.parameters[2]
    (S isa Type && Shape isa DataType) || return nothing
    dims = Shape.parameters
    length(dims) == 1 || return nothing
    n = dims[1]
    # A non-`Int` shape parameter (TypeVar, Vararg, ...) means the length is not
    # known here; give up instead of failing a type assertion.
    n isa Int || return nothing
    bs = lookup_bitwidth(S)
    return Tile{UInt8, Tuple{fld(n * bs, 8)}}
end
# Lower `Intrinsics.pack` to a `PackOp`.
#
# Checks run in a fixed order and each failure raises an `IRError`: the source
# must resolve, the bytecode version must be at least 13.3, the source must be
# rank-1, its element type must not be 8 bits wide, and its total bit count must
# be a multiple of 8. The result is a rank-1 `UInt8` tile holding one element per
# byte of the source.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.pack), args)
    builder = ctx.cb
    types = ctx.tt

    source = @something emit_value!(ctx, args[1]) throw(IRError("pack: cannot resolve source"))
    if !(types.version >= v"13.3")
        throw(IRError("cuda_tile.pack requires Tile IR bytecode v13.3+, got v$(types.version)"))
    end
    rank = length(source.shape)
    if rank != 1
        throw(IRError("pack: requires a rank-1 tile, got a $(rank)-D tile"))
    end

    elem = eltype(CC.widenconst(source.jltype))
    elem_bits = lookup_bitwidth(elem)
    if elem_bits == 8
        throw(IRError("pack: 8-bit element type $elem should be reinterpreted via bitcast, not packed"))
    end
    len = source.shape[1]
    total_bits = len * elem_bits
    if total_bits % 8 != 0
        throw(IRError("pack: a $(len)-element $elem tile ($total_bits bits) is not a whole number of bytes"))
    end
    nbytes = total_bits ÷ 8

    byte_shape = RowMajorShape([nbytes])
    byte_type_id = tile_type!(types, lookup_dtype!(types, UInt8), byte_shape)
    packed = encode_PackOp!(builder, byte_type_id, source.v)
    return CGVal(packed, byte_type_id, Tile{UInt8, Tuple{nbytes}}, byte_shape)
end
104+
"""
    Intrinsics.unpack(x::Tile{UInt8,Tuple{N}}, ::Type{T}) -> Tile{T,Tuple{N*8÷bitwidth(T)}}

Unpack a rank-1 `UInt8` tile into a rank-1 numeric tile of element type `T` (the
inverse of [`pack`](@ref Intrinsics.pack)); lowers to `cuda_tile.unpack`. `T`
must be a compile-time constant and must not be 8-bit (use `bitcast`). Requires
Tile IR bytecode v13.3+.
"""
@intrinsic unpack(x, ::Type{T}) where {T}

# Return-type function for `unpack`.
#
# Yields `Tile{T, Tuple{M}}` with `M = fld(N * 8, bitwidth(T))` for a fully
# parameterized rank-1 source tile of length `N`, and `nothing` whenever the
# target type or the source type cannot be pinned down. The result is
# deliberately computed with `fld` so that an indivisible width still gives a
# concrete type; rejecting such a cast is left to the `unpack` emitter, which
# reports it as an `IRError`.
function tfunc(𝕃, ::typeof(Intrinsics.unpack), @nospecialize(x), @nospecialize(target_type))
    T = instanceof_tfunc(target_type)
    T === nothing && return nothing
    src = CC.widenconst(x)
    # `.parameters` is a field of `DataType` only. A `UnionAll` (e.g. a `Tile`
    # with a free shape parameter) or `Union{}` also satisfies `<: Tile`, but
    # reading `.parameters` from it throws, so bail out on those first.
    (src isa DataType && src <: Tile) || return nothing
    Shape = src.parameters[2]
    Shape isa DataType || return nothing
    dims = Shape.parameters
    length(dims) == 1 || return nothing
    n = dims[1]
    # A non-`Int` shape parameter (TypeVar, Vararg, ...) means the length is not
    # known here; give up instead of failing a type assertion.
    n isa Int || return nothing
    bt = lookup_bitwidth(T)
    return Tile{T, Tuple{fld(n * 8, bt)}}
end
# Lower `Intrinsics.unpack` to an `UnpackOp`.
#
# Checks run in a fixed order and each failure raises an `IRError`: the source
# must resolve, the target type must be a compile-time constant, the bytecode
# version must be at least 13.3, the source must be a rank-1 `UInt8` tile, the
# target must not be 8 bits wide, and the source's bit count must split evenly
# into target elements. The result is a rank-1 tile of the target type.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
    builder = ctx.cb
    types = ctx.tt

    source = @something emit_value!(ctx, args[1]) throw(IRError("unpack: cannot resolve source"))
    target_type = @something get_constant(ctx, args[2]) throw(IRError("unpack: requires compile-time target type"))
    if !(types.version >= v"13.3")
        throw(IRError("cuda_tile.unpack requires Tile IR bytecode v13.3+, got v$(types.version)"))
    end
    rank = length(source.shape)
    if rank != 1
        throw(IRError("unpack: requires a rank-1 tile, got a $(rank)-D tile"))
    end

    src_elem = eltype(CC.widenconst(source.jltype))
    if src_elem !== UInt8
        throw(IRError("unpack: requires a UInt8 tile, got $(src_elem)"))
    end
    target_bits = lookup_bitwidth(target_type)
    if target_bits == 8
        throw(IRError("unpack: 8-bit target $target_type should be reinterpreted via bitcast, not unpacked"))
    end
    nbytes = source.shape[1]
    total_bits = nbytes * 8
    if total_bits % target_bits != 0
        throw(IRError("unpack: $nbytes bytes ($total_bits bits) do not evenly divide into $target_type ($(target_bits)-bit) elements"))
    end
    out_len = total_bits ÷ target_bits

    out_shape = RowMajorShape([out_len])
    out_type_id = tile_type!(types, lookup_dtype!(types, target_type), out_shape)
    unpacked = encode_UnpackOp!(builder, out_type_id, source.v)
    return CGVal(unpacked, out_type_id, Tile{target_type, Tuple{out_len}}, out_shape)
end
154+
56155
"""
57156
Intrinsics.exti(x::Tile{<:Integer}, ::Type{T}, s::Signedness.T) -> Tile{T} where {T<:Integer}
58157

src/language/operations.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,82 @@ Equivalent to single-arg `permutedims`.
932932
end
933933
end
934934

# Change the element type of a rank-1 tile to `T`, producing another rank-1 tile.
# Equal widths need only a bitcast. Otherwise the data goes through a `UInt8`
# tile: a non-8-bit side uses `pack`/`unpack`, an 8-bit side uses `bitcast`.
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
    src_bits = bitwidth(S)
    dst_bits = bitwidth(T)

    # Same width: element-wise bitcast.
    src_bits == dst_bits && return Intrinsics.bitcast(flat, T)

    if dst_bits == 8
        # S → bytes → 8-bit T
        bytes = Intrinsics.pack(flat)
        return Intrinsics.bitcast(bytes, T)
    end

    if src_bits == 8
        # 8-bit S → bytes → T
        bytes = Intrinsics.bitcast(flat, UInt8)
        return Intrinsics.unpack(bytes, T)
    end

    # Neither side is 8-bit: S → bytes → T
    bytes = Intrinsics.pack(flat)
    return Intrinsics.unpack(bytes, T)
end
949+
# Shape of `reinterpret(T, x)` for a source of element type `S` and size `sz`:
# as with `reinterpret(T, ::AbstractArray)`, only the leading (column-major)
# dimension changes, scaled by the ratio of the element widths. Pure arithmetic;
# an invalid combination is rejected later, at emit.
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
    src_bits = bitwidth(S)
    dst_bits = bitwidth(T)
    if N == 0
        # 0-D: only an equal-width cast is valid; a cross-width one is caught at emit.
        return ()
    end
    lead = fld(sz[1] * src_bits, dst_bits)
    return (lead, Base.tail(sz)...)
end
958+
# Shape of `reinterpret(reshape, T, x)` for a source of element type `S` and size
# `sz`, following the array version: widening drops the leading dimension (which
# has to equal the width ratio), narrowing adds a new leading dimension equal to
# the ratio, and equal widths leave the shape alone.
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
    src_bits = bitwidth(S)
    dst_bits = bitwidth(T)
    if src_bits == dst_bits
        return sz
    elseif N == 0
        # A cross-width cast of a 0-D tile is invalid; it is caught at emit.
        return ()
    elseif dst_bits > src_bits
        return Base.tail(sz)
    else
        return (div(src_bits, dst_bits), sz...)
    end
end
968+
"""
    Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}

View the bits of the *entire tile* `x` as elements of type `T`, the way
`reinterpret(T, ::AbstractArray)` does: the data is treated as one contiguous
(column-major) block, and the leading dimension is scaled by the ratio of the
two element widths. Equal widths lower to `cuda_tile.bitcast`; differing widths
lower to `cuda_tile.pack`/`unpack`, with a `reshape` to rank-1 around them.

Sub-byte formats travel through global memory this way: a `Tile{UInt8,(N,)}`
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back again, which lets FP4 data
live in a `UInt8` array. The total number of bits stays the same, so it has to
divide evenly.

Do not confuse this with `reinterpret.(T, x)` (dotted), the unrelated
*element-wise* broadcast: that one keeps the shape and needs `T` to be exactly
as wide as `eltype(x)`.

```julia
bytes = ct.load(a, pid, (8,))               # Tile{UInt8,(8,)}
fp4   = reinterpret(Float4_E2M1FN, bytes)   # Tile{Float4_E2M1FN,(16,)}
vals  = convert(ct.Tile{Float32}, fp4)      # widen for compute
```
"""
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
    dims = size(x)
    out_dims = reinterpret_scaled_shape(T, eltype(x), dims)
    # Flatten, convert the element width on the rank-1 tile, then restore the rank.
    linear = Intrinsics.reshape(x, (prod(dims),))
    converted = reinterpret_width(T, linear)
    return Intrinsics.reshape(converted, out_dims)
end
996+
"""
    Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}

Whole-tile reinterpret in its `reshape` form, the counterpart of
`reinterpret(reshape, T, ::AbstractArray)`. The leading dimension is not
rescaled: widening *drops* it (it has to equal
`bitwidth(T) ÷ bitwidth(eltype(x))`), and narrowing *adds* a new leading
dimension in front.
"""
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
    dims = size(x)
    out_dims = reinterpret_reshape_shape(T, eltype(x), dims)
    # Flatten, convert the element width on the rank-1 tile, then restore the rank.
    linear = Intrinsics.reshape(x, (prod(dims),))
    converted = reinterpret_width(T, linear)
    return Intrinsics.reshape(converted, out_dims)
end
1010+
9351011
@inline Base.convert(::Type{Tile{T}}, tile::Tile{T}) where {T} = tile
9361012
@inline Base.convert(::Type{Tile{T2}}, tile::Tile{T1, Shape}) where {T1, T2, Shape} =
9371013
map(T2, tile)

src/language/types.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,23 @@ similar_type(::Type{Tile{T, Shape}}, ::Type{U}, new_shape::Tuple) where {T, Shap
377377
similar_type(::Type{<:Tile{T}}, ::Type{U}) where {T, U} = Tile{U}
378378
similar_type(::Type, ::Type{T}) where {T} = T # fallback for non-Tile types
379379

"""
    bitwidth(::Type{T}) -> Int

Return how many bits one element of `T` takes up inside a Tile IR tile. The
whole-tile [`reinterpret`](@ref Base.reinterpret(::Type, ::Tile)) relies on this
to rescale the tile shape when the element width changes (for instance `UInt8` ↔
`Float4_E2M1FN`, i.e. 8 bits ↔ 4 bits).

The fallback is `8 * sizeof(T)`, which is right for the standard integer and
floating-point types and for the byte-wide `Float8_*` formats. Sub-byte formats
whose `sizeof` is rounded up to a full byte (`Float4_E2M1FN` is 4 bits wide yet
has `sizeof == 1`) need their own method; the `Microfloats` extension supplies
one that forwards to `Microfloats.bitwidth`, which derives the true width from
the format's bit fields. This follows the `bitwidth` convention of
`Microfloats`/`Narrow`.
"""
function bitwidth(::Type{T}) where {T}
    return 8 * sizeof(T)
end
396+
380397

381398
"""
382399
TFloat32 <: AbstractFloat

test/codegen/operations.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,80 @@ end
11401140
end
11411141
end
11421142

@testset "reinterpret (whole-tile)" begin
    # Same width but a different Tile IR dtype (Float16 -> Int16): the whole-tile
    # reinterpret is just a bitcast, with no pack/unpack and an unchanged shape.
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{Float16,1,spec1d}, ct.TileArray{Int16,1,spec1d}}) do src, dst
            bid = ct.bid(1)
            halves = ct.load(src, bid, (16,))
            @check "bitcast"
            @check_not "pack"
            ct.store(dst, bid, reinterpret(Int16, halves))
            return
        end
    end

    # Widening UInt8 -> UInt16 in 1D: one unpack, with the identity reshapes
    # folded away.
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do src, dst
            bid = ct.bid(1)
            octets = ct.load(src, bid, (16,))
            @check "unpack"
            ct.store(dst, bid, reinterpret(UInt16, octets))
            return
        end
    end

    # Narrowing UInt16 -> UInt8 in 1D: one pack.
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{UInt16,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}) do src, dst
            bid = ct.bid(1)
            words = ct.load(src, bid, (8,))
            @check "pack"
            ct.store(dst, bid, reinterpret(UInt8, words))
            return
        end
    end

    # pack/unpack need v13.3, so older bytecode must fail with a clear error.
    # (`literal` because the `+` in the message is a regex metachar to FileCheck.)
    @test @filecheck throws=ct.IRError begin
        @check literal=true "v13.3+"
        code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
                   bytecode_version=v"13.2") do src, dst
            bid = ct.bid(1)
            octets = ct.load(src, bid, (16,))
            ct.store(dst, bid, reinterpret(UInt16, octets))
            return
        end
    end

    # Rank-1 scaled form: a single UInt8 (8 bits) cannot fill a UInt16, which the
    # unpack emitter rejects.
    @test @filecheck throws=ct.IRError begin
        @check "do not evenly divide"
        code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do src, dst
            bid = ct.bid(1)
            ct.store(dst, bid, reinterpret(UInt16, ct.load(src, bid, (1,))))
            return
        end
    end

    # reshape-widen: the leading dim has to equal the ratio (2); with 1 the final
    # reshape fails.
    @test @filecheck throws=ct.IRError begin
        @check "same number of elements"
        code_tiled(Tuple{ct.TileArray{UInt8,2,spec2d}, ct.TileArray{UInt16,2,spec2d}}) do src, dst
            bid = ct.bid(1)
            ct.store(dst, bid, reinterpret(reshape, UInt16, ct.load(src, bid, (1, 4))))
            return
        end
    end
end
1216+
11431217
# TODO: exti - sign/zero extend integer
11441218
# TODO: ftoi - float to integer
11451219
# TODO: itof - integer to float

0 commit comments

Comments
 (0)