Skip to content

Commit ed8ddc1

Browse files
committed
Add reinterpret as interface to bitcast, pack, and unpack
1 parent 4a3cc60 commit ed8ddc1

7 files changed

Lines changed: 498 additions & 2 deletions

File tree

ext/MicrofloatsExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module MicrofloatsExt
22

33
import cuTile as ct
4+
import Microfloats
45

56
using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN
67

@@ -14,6 +15,13 @@ ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) = ct.F8E5M2(
1415
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E8M0FNU}) = ct.F8E8M0FNU(table)
1516
ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float4_E2M1FN}) = ct.F4E2M1FN(table)
1617

18+
# Microfloats are byte-storage primitives (`sizeof == 1`), so cuTile's default
19+
# `bitwidth` (8 * sizeof) over-counts the sub-byte formats. Forward to
20+
# `Microfloats.bitwidth`, which derives the true width from the format's bit
21+
# fields (e.g. `Float4_E2M1FN` → 4), so whole-tile `reinterpret` packs/unpacks
22+
# them through `UInt8` correctly.
23+
ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth(T)
24+
1725
# E8M0FNU has no sign bit and represents a power of two; tileiras rejects
1826
# nearest-even on f32→E8M0FNU (only `zero` and `positive_inf` are valid).
1927
ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero

src/bytecode/encodings.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ module Opcode
9696
const XOrIOp = 108
9797
const YieldOp = 109
9898
const Atan2Op = 110 # since 13.2
99+
const PackOp = 111 # since 13.3
100+
const UnpackOp = 112 # since 13.3
99101
end
100102

101103
# Enums for operation attributes
@@ -1796,6 +1798,36 @@ function encode_BitcastOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
17961798
return new_op!(cb)
17971799
end
17981800

1801+
"""
1802+
encode_PackOp!(cb, result_type, source) -> Value
1803+
1804+
Pack a rank-1 numeric tile into a rank-1 `tile<i8>`. Unlike `bitcast`, this is
1805+
not element-wise: the whole tile is reinterpreted as a byte array, so the result
1806+
length is the input's total byte count. The source must not be an 8-bit type
1807+
(use `bitcast`). Since 13.3. Opcode: 111
1808+
"""
1809+
function encode_PackOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
1810+
encode_varint!(cb.buf, Opcode.PackOp)
1811+
encode_typeid!(cb.buf, result_type)
1812+
encode_operand!(cb.buf, source)
1813+
return new_op!(cb)
1814+
end
1815+
1816+
"""
1817+
encode_UnpackOp!(cb, result_type, source) -> Value
1818+
1819+
Unpack a rank-1 `tile<i8>` into a rank-1 numeric tile (the inverse of
1820+
[`encode_PackOp!`](@ref)). The input byte count must equal the output's total
1821+
byte count. The result must not be an 8-bit type (use `bitcast`). Since 13.3.
1822+
Opcode: 112
1823+
"""
1824+
function encode_UnpackOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
1825+
encode_varint!(cb.buf, Opcode.UnpackOp)
1826+
encode_typeid!(cb.buf, result_type)
1827+
encode_operand!(cb.buf, source)
1828+
return new_op!(cb)
1829+
end
1830+
17991831
"""
18001832
encode_BroadcastOp!(cb, result_type, source) -> Value
18011833

src/compiler/intrinsics/conversions.jl

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,218 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.bitcast), args)
5353
CGVal(result_v, result_type_id, result_jltype, source.shape)
5454
end
5555

56+
@inline lookup_bitwidth(@nospecialize(T::Type)) =
57+
Base.invokelatest(bitwidth, T)::Int
58+
59+
# ── Low-level pack / unpack intrinsics (1:1 with cuda_tile.pack / unpack) ──────
60+
#
61+
# Both Tile IR ops are rank-1 → rank-1 and reinterpret the *whole tile* as a byte
62+
# array (not element-wise like bitcast). They are the primitives that the
63+
# whole-tile `Base.reinterpret` below composes (with `reshape`/`bitcast`) into a
64+
# Julia-semantics reinterpret of any rank. The 8-bit element types are handled by
65+
# `bitcast`, never pack/unpack — matching cutile-python's `pack_to_bytes` /
66+
# `unpack_from_bytes`.
67+
68+
"""
69+
Intrinsics.pack(x::Tile{S,Tuple{N}}) -> Tile{UInt8,Tuple{N*bitwidth(S)÷8}}
70+
71+
Pack a rank-1 numeric tile into a rank-1 `UInt8` tile (the tile's bits viewed as
72+
a byte array); lowers to `cuda_tile.pack`. `S` must not be 8-bit (use `bitcast`).
73+
Requires Tile IR bytecode v13.3+.
74+
"""
75+
@intrinsic pack(x)
76+
function tfunc(𝕃, ::typeof(Intrinsics.pack), @nospecialize(x))
77+
src = CC.widenconst(x)
78+
src <: Tile || return nothing
79+
S = src.parameters[1]
80+
Shape = src.parameters[2]
81+
(S isa Type && Shape isa Type) || return nothing
82+
dims = Shape.parameters
83+
length(dims) == 1 || return nothing
84+
n = dims[1]::Int
85+
bs = lookup_bitwidth(S)
86+
(n * bs) % 8 == 0 || return nothing
87+
return Tile{UInt8, Tuple{(n * bs) ÷ 8}}
88+
end
89+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.pack), args)
90+
cb = ctx.cb
91+
tt = ctx.tt
92+
93+
source = @something emit_value!(ctx, args[1]) throw(IRError("pack: cannot resolve source"))
94+
tt.version >= v"13.3" ||
95+
throw(IRError("cuda_tile.pack requires Tile IR bytecode v13.3+, got v$(tt.version)"))
96+
length(source.shape) == 1 ||
97+
throw(IRError("pack: requires a rank-1 tile, got a $(length(source.shape))-D tile"))
98+
99+
src_type = CC.widenconst(source.jltype)
100+
S = eltype(src_type)
101+
sbits = lookup_bitwidth(S)
102+
sbits == 8 &&
103+
throw(IRError("pack: 8-bit element type $S should be reinterpreted via bitcast, not packed"))
104+
n = source.shape[1]
105+
(n * sbits) % 8 == 0 ||
106+
throw(IRError("pack: a $n-element $S tile ($(n * sbits) bits) is not a whole number of bytes"))
107+
new_n = (n * sbits) ÷ 8
108+
109+
new_shape = RowMajorShape([new_n])
110+
result_type_id = tile_type!(tt, lookup_dtype!(tt, UInt8), new_shape)
111+
result_v = encode_PackOp!(cb, result_type_id, source.v)
112+
CGVal(result_v, result_type_id, Tile{UInt8, Tuple{new_n}}, new_shape)
113+
end
114+
115+
"""
116+
Intrinsics.unpack(x::Tile{UInt8,Tuple{N}}, ::Type{T}) -> Tile{T,Tuple{N*8÷bitwidth(T)}}
117+
118+
Unpack a rank-1 `UInt8` tile into a rank-1 numeric tile of element type `T` (the
119+
inverse of [`pack`](@ref Intrinsics.pack)); lowers to `cuda_tile.unpack`. `T`
120+
must be a compile-time constant and must not be 8-bit (use `bitcast`). Requires
121+
Tile IR bytecode v13.3+.
122+
"""
123+
@intrinsic unpack(x, ::Type{T}) where {T}
124+
function tfunc(𝕃, ::typeof(Intrinsics.unpack), @nospecialize(x), @nospecialize(target_type))
125+
T = instanceof_tfunc(target_type)
126+
T === nothing && return nothing
127+
src = CC.widenconst(x)
128+
src <: Tile || return nothing
129+
Shape = src.parameters[2]
130+
Shape isa Type || return nothing
131+
dims = Shape.parameters
132+
length(dims) == 1 || return nothing
133+
n = dims[1]::Int
134+
bt = lookup_bitwidth(T)
135+
(n * 8) % bt == 0 || return nothing
136+
return Tile{T, Tuple{(n * 8) ÷ bt}}
137+
end
138+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
139+
cb = ctx.cb
140+
tt = ctx.tt
141+
142+
source = @something emit_value!(ctx, args[1]) throw(IRError("unpack: cannot resolve source"))
143+
target_type = @something get_constant(ctx, args[2]) throw(IRError("unpack: requires compile-time target type"))
144+
tt.version >= v"13.3" ||
145+
throw(IRError("cuda_tile.unpack requires Tile IR bytecode v13.3+, got v$(tt.version)"))
146+
length(source.shape) == 1 ||
147+
throw(IRError("unpack: requires a rank-1 tile, got a $(length(source.shape))-D tile"))
148+
149+
src_type = CC.widenconst(source.jltype)
150+
eltype(src_type) === UInt8 ||
151+
throw(IRError("unpack: requires a UInt8 tile, got $(eltype(src_type))"))
152+
tbits = lookup_bitwidth(target_type)
153+
tbits == 8 &&
154+
throw(IRError("unpack: 8-bit target $target_type should be reinterpreted via bitcast, not unpacked"))
155+
n = source.shape[1]
156+
(n * 8) % tbits == 0 ||
157+
throw(IRError("unpack: $n bytes ($(n * 8) bits) do not evenly divide into $target_type ($tbits-bit) elements"))
158+
new_n = (n * 8) ÷ tbits
159+
160+
new_shape = RowMajorShape([new_n])
161+
result_type_id = tile_type!(tt, lookup_dtype!(tt, target_type), new_shape)
162+
result_v = encode_UnpackOp!(cb, result_type_id, source.v)
163+
CGVal(result_v, result_type_id, Tile{target_type, Tuple{new_n}}, new_shape)
164+
end
165+
166+
# ── Whole-tile reinterpret (Julia semantics, any rank) ────────────────────────
167+
#
168+
# A tile's row-major byte stream equals the column-major stream of its Julia
169+
# shape (we store the reversed shape), so flatten → width-convert → reshape
170+
# reproduces `Base.reinterpret` element-for-element. Equal widths bitcast;
171+
# crossing 8 bits goes through `pack`/`unpack`; a non-byte → non-byte change
172+
# routes through bytes (pack then unpack). All the shape arithmetic is
173+
# compile-time constant, so the intermediate `reshape`s fold away (identity
174+
# reshapes are eliminated by the canonicalizer) — a rank-1 FP4 reinterpret lowers
175+
# to a single pack/unpack.
176+
177+
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
178+
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
179+
bs = bitwidth(S)
180+
bt = bitwidth(T)
181+
if bs == bt
182+
return Intrinsics.bitcast(flat, T) # same width
183+
elseif bt == 8
184+
return Intrinsics.bitcast(Intrinsics.pack(flat), T) # S → bytes → T8
185+
elseif bs == 8
186+
return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T) # S8 → bytes → T
187+
else
188+
return Intrinsics.unpack(Intrinsics.pack(flat), T) # S → bytes → T
189+
end
190+
end
191+
192+
# Result shape for `reinterpret(T, x)`: rescale the leading (column-major)
193+
# dimension by the element-width ratio, like `reinterpret(T, ::AbstractArray)`.
194+
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
195+
bs = bitwidth(S)
196+
bt = bitwidth(T)
197+
if N == 0
198+
bs == bt || throw(ArgumentError("reinterpret: a 0-D $S tile ($bs bits) cannot be reinterpreted as $T ($bt bits)"))
199+
return ()
200+
end
201+
(sz[1] * bs) % bt == 0 ||
202+
throw(ArgumentError("reinterpret: leading dimension $(sz[1]) of a $S tile ($(sz[1] * bs) bits) is not divisible by $bt-bit $T"))
203+
return (div(sz[1] * bs, bt), Base.tail(sz)...)
204+
end
205+
206+
# Result shape for `reinterpret(reshape, T, x)`: drop the leading dim on widening
207+
# (it must equal the ratio), prepend one on narrowing, like the array version.
208+
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
209+
bs = bitwidth(S)
210+
bt = bitwidth(T)
211+
bs == bt && return sz
212+
if bt > bs
213+
bt % bs == 0 ||
214+
throw(ArgumentError("reinterpret(reshape, $T, x): $T ($bt bits) is not a whole multiple of $S ($bs bits)"))
215+
r = div(bt, bs)
216+
(N >= 1 && sz[1] == r) ||
217+
throw(ArgumentError("reinterpret(reshape, $T, x): leading dimension must be $r, got $(N == 0 ? () : sz[1])"))
218+
return Base.tail(sz)
219+
else
220+
bs % bt == 0 ||
221+
throw(ArgumentError("reinterpret(reshape, $T, x): $S ($bs bits) is not a whole multiple of $T ($bt bits)"))
222+
return (div(bs, bt), sz...)
223+
end
224+
end
225+
226+
"""
227+
Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}
228+
229+
Reinterpret the *whole tile* `x` as a tile of element type `T`, like
230+
`reinterpret(T, ::AbstractArray)`: the underlying bits are viewed as a contiguous
231+
(column-major) block and the leading dimension is rescaled by the ratio of
232+
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
233+
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.
234+
235+
This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
236+
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
237+
in a `UInt8` array. The total bit-width is preserved, so it must divide evenly.
238+
239+
Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
240+
which keeps the shape and requires `T` to be the same width as `eltype(x)`.
241+
242+
```julia
243+
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
244+
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
245+
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
246+
```
247+
"""
248+
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
249+
rshape = reinterpret_scaled_shape(T, eltype(x), size(x))
250+
flat = Intrinsics.reshape(x, (prod(size(x)),))
251+
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
252+
end
253+
254+
"""
255+
Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}
256+
257+
The `reshape`-form whole-tile reinterpret, mirroring
258+
`reinterpret(reshape, T, ::AbstractArray)`: instead of rescaling the leading
259+
dimension it *removes* it when widening (the leading dim must equal
260+
`bitwidth(T) ÷ bitwidth(eltype(x))`) and *prepends* one when narrowing.
261+
"""
262+
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
263+
rshape = reinterpret_reshape_shape(T, eltype(x), size(x))
264+
flat = Intrinsics.reshape(x, (prod(size(x)),))
265+
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
266+
end
267+
56268
"""
57269
Intrinsics.exti(x::Tile{<:Integer}, ::Type{T}, s::Signedness.T) -> Tile{T} where {T<:Integer}
58270

src/language/types.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,23 @@ similar_type(::Type{Tile{T, Shape}}, ::Type{U}, new_shape::Tuple) where {T, Shap
377377
similar_type(::Type{<:Tile{T}}, ::Type{U}) where {T, U} = Tile{U}
378378
similar_type(::Type, ::Type{T}) where {T} = T # fallback for non-Tile types
379379

380+
"""
381+
bitwidth(::Type{T}) -> Int
382+
383+
Number of bits a single element of `T` occupies in a Tile IR tile. Used by the
384+
whole-tile [`reinterpret`](@ref Base.reinterpret(::Type, ::Tile)) to scale the
385+
tile shape across a change of element width (e.g. `UInt8` ↔ `Float4_E2M1FN`,
386+
8 bits ↔ 4 bits).
387+
388+
The default is `8 * sizeof(T)`, which is correct for the standard integer and
389+
floating-point types and for the byte-wide `Float8_*` formats. Sub-byte formats
390+
whose `sizeof` rounds up to a whole byte (e.g. `Float4_E2M1FN`, 4 bits but
391+
`sizeof == 1`) override this; the `Microfloats` extension forwards to
392+
`Microfloats.bitwidth`, which derives the true width from the format's bit
393+
fields. Matches the `bitwidth` convention used by `Microfloats`/`Narrow`.
394+
"""
395+
bitwidth(::Type{T}) where {T} = 8 * sizeof(T)
396+
380397

381398
"""
382399
TFloat32 <: AbstractFloat

test/codegen/operations.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,59 @@ end
11401140
end
11411141
end
11421142

1143+
@testset "reinterpret (whole-tile)" begin
1144+
# Equal width, different Tile IR dtype (Float16 -> Int16): whole-tile
1145+
# reinterpret is a plain bitcast — no pack/unpack, shape preserved.
1146+
@test @filecheck begin
1147+
@check_label "entry"
1148+
code_tiled(Tuple{ct.TileArray{Float16,1,spec1d}, ct.TileArray{Int16,1,spec1d}}) do a, b
1149+
pid = ct.bid(1)
1150+
tile = ct.load(a, pid, (16,))
1151+
@check "bitcast"
1152+
@check_not "pack"
1153+
ct.store(b, pid, reinterpret(Int16, tile))
1154+
return
1155+
end
1156+
end
1157+
1158+
# Widen UInt8 -> UInt16 (1D): lowers to a single unpack, identity reshapes
1159+
# folded away.
1160+
@test @filecheck begin
1161+
@check_label "entry"
1162+
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do a, b
1163+
pid = ct.bid(1)
1164+
tile = ct.load(a, pid, (16,))
1165+
@check "unpack"
1166+
ct.store(b, pid, reinterpret(UInt16, tile))
1167+
return
1168+
end
1169+
end
1170+
1171+
# Narrow UInt16 -> UInt8 (1D): lowers to a single pack.
1172+
@test @filecheck begin
1173+
@check_label "entry"
1174+
code_tiled(Tuple{ct.TileArray{UInt16,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}) do a, b
1175+
pid = ct.bid(1)
1176+
tile = ct.load(a, pid, (8,))
1177+
@check "pack"
1178+
ct.store(b, pid, reinterpret(UInt8, tile))
1179+
return
1180+
end
1181+
end
1182+
1183+
# pack/unpack require v13.3 — older bytecode rejects with a clear error.
1184+
let kernel = (a, b) -> begin
1185+
pid = ct.bid(1)
1186+
tile = ct.load(a, pid, (16,))
1187+
ct.store(b, pid, reinterpret(UInt16, tile))
1188+
return
1189+
end
1190+
@test_throws "v13.3+" code_tiled(devnull, kernel,
1191+
Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
1192+
bytecode_version=v"13.2")
1193+
end
1194+
end
1195+
11431196
# TODO: exti - sign/zero extend integer
11441197
# TODO: ftoi - float to integer
11451198
# TODO: itof - integer to float

0 commit comments

Comments
 (0)