@@ -53,6 +53,218 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.bitcast), args)
5353 CGVal (result_v, result_type_id, result_jltype, source. shape)
5454end
5555
# Host-side lookup of `bitwidth(T)`. The call is routed through
# `Base.invokelatest`, so it runs in the latest world age; because such a call is
# dynamic, the result is asserted to be an `Int` to keep callers type-stable.
@inline function lookup_bitwidth(@nospecialize(T::Type))
    width = Base.invokelatest(bitwidth, T)
    return width::Int
end
58+
59+ # ── Low-level pack / unpack intrinsics (1:1 with cuda_tile.pack / unpack) ──────
60+ #
61+ # Both Tile IR ops are rank-1 → rank-1 and reinterpret the *whole tile* as a byte
62+ # array (not element-wise like bitcast). They are the primitives that the
63+ # whole-tile `Base.reinterpret` below composes (with `reshape`/`bitcast`) into a
64+ # Julia-semantics reinterpret of any rank. The 8-bit element types are handled by
65+ # `bitcast`, never pack/unpack — matching cutile-python's `pack_to_bytes` /
66+ # `unpack_from_bytes`.
67+
"""
    Intrinsics.pack(x::Tile{S,Tuple{N}}) -> Tile{UInt8,Tuple{N*bitwidth(S)÷8}}

Pack a rank-1 numeric tile into a rank-1 `UInt8` tile (the tile's bits viewed as
a byte array); lowers to `cuda_tile.pack`. `S` must not be 8-bit (use `bitcast`).
Requires Tile IR bytecode v13.3+.
"""
@intrinsic pack(x)

# Return-type function for `pack`.
#
# Computes `Tile{UInt8, Tuple{n * bitwidth(S) ÷ 8}}` from a fully-parameterised
# rank-1 `Tile{S, Tuple{n}}` argument type. Returns `nothing` whenever the
# argument type is not precise enough to compute that (not a concrete `Tile`
# instantiation, non-rank-1 shape, non-integer dimension, or a bit count that is
# not a whole number of bytes). It never throws for an imprecise lattice element.
function tfunc(𝕃, ::typeof(Intrinsics.pack), @nospecialize(x))
    src = CC.widenconst(x)
    # `Tile`, `Tile{S}` (both `UnionAll`) and `Union{}` all satisfy `<: Tile` but
    # have no `parameters` field; only a `DataType` can be taken apart safely.
    (src isa DataType && src <: Tile) || return nothing
    S = src.parameters[1]
    Shape = src.parameters[2]
    # The shape must itself be a concrete `Tuple{...}` instantiation for
    # `Shape.parameters` to exist.
    (S isa Type && Shape isa DataType) || return nothing
    dims = Shape.parameters
    length(dims) == 1 || return nothing
    n = dims[1]
    # A free `TypeVar` (or `Vararg`) dimension means the length is unknown.
    n isa Int || return nothing
    bs = lookup_bitwidth(S)
    (n * bs) % 8 == 0 || return nothing
    return Tile{UInt8, Tuple{(n * bs) ÷ 8}}
end
# Code generation for `Intrinsics.pack`: emits one `PackOp` that turns the rank-1
# source tile into a rank-1 `UInt8` tile holding the same bits.
#
# Throws `IRError` when the source cannot be resolved, the bytecode version is
# below v13.3, the source is not rank-1, its element type is 8-bit, or its total
# bit count is not a multiple of 8.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.pack), args)
    builder = ctx.cb
    typetable = ctx.tt

    src = @something emit_value!(ctx, args[1]) throw(IRError("pack: cannot resolve source"))
    if !(typetable.version >= v"13.3")
        throw(IRError("cuda_tile.pack requires Tile IR bytecode v13.3+, got v$(typetable.version)"))
    end
    rank = length(src.shape)
    if rank != 1
        throw(IRError("pack: requires a rank-1 tile, got a $(rank)-D tile"))
    end

    elty = eltype(CC.widenconst(src.jltype))
    elbits = lookup_bitwidth(elty)
    if elbits == 8
        throw(IRError("pack: 8-bit element type $elty should be reinterpreted via bitcast, not packed"))
    end
    nelem = src.shape[1]
    total_bits = nelem * elbits
    if total_bits % 8 != 0
        throw(IRError("pack: a $nelem-element $elty tile ($total_bits bits) is not a whole number of bytes"))
    end
    nbytes = total_bits ÷ 8

    # Result: rank-1 UInt8 tile with one element per byte of the source.
    out_shape = RowMajorShape([nbytes])
    out_type_id = tile_type!(typetable, lookup_dtype!(typetable, UInt8), out_shape)
    out_v = encode_PackOp!(builder, out_type_id, src.v)
    return CGVal(out_v, out_type_id, Tile{UInt8, Tuple{nbytes}}, out_shape)
end
114+
"""
    Intrinsics.unpack(x::Tile{UInt8,Tuple{N}}, ::Type{T}) -> Tile{T,Tuple{N*8÷bitwidth(T)}}

Unpack a rank-1 `UInt8` tile into a rank-1 numeric tile of element type `T` (the
inverse of [`pack`](@ref Intrinsics.pack)); lowers to `cuda_tile.unpack`. `T`
must be a compile-time constant and must not be 8-bit (use `bitcast`). Requires
Tile IR bytecode v13.3+.
"""
@intrinsic unpack(x, ::Type{T}) where {T}

# Return-type function for `unpack`.
#
# Computes `Tile{T, Tuple{n * 8 ÷ bitwidth(T)}}` from a fully-parameterised
# rank-1 `Tile{_, Tuple{n}}` argument type and a known target type. Returns
# `nothing` whenever the inputs are not precise enough to compute that (unknown
# target type, not a concrete `Tile` instantiation, non-rank-1 shape, non-integer
# dimension, or a bit count that does not divide evenly). It never throws for an
# imprecise lattice element.
function tfunc(𝕃, ::typeof(Intrinsics.unpack), @nospecialize(x), @nospecialize(target_type))
    T = instanceof_tfunc(target_type)
    T === nothing && return nothing
    src = CC.widenconst(x)
    # `Tile`, `Tile{S}` (both `UnionAll`) and `Union{}` all satisfy `<: Tile` but
    # have no `parameters` field; only a `DataType` can be taken apart safely.
    (src isa DataType && src <: Tile) || return nothing
    Shape = src.parameters[2]
    # The shape must itself be a concrete `Tuple{...}` instantiation for
    # `Shape.parameters` to exist.
    Shape isa DataType || return nothing
    dims = Shape.parameters
    length(dims) == 1 || return nothing
    n = dims[1]
    # A free `TypeVar` (or `Vararg`) dimension means the length is unknown.
    n isa Int || return nothing
    bt = lookup_bitwidth(T)
    (n * 8) % bt == 0 || return nothing
    return Tile{T, Tuple{(n * 8) ÷ bt}}
end
# Code generation for `Intrinsics.unpack`: emits one `UnpackOp` that turns the
# rank-1 `UInt8` source tile into a rank-1 tile of the constant target type.
#
# Throws `IRError` when the source or the target type cannot be resolved, the
# bytecode version is below v13.3, the source is not a rank-1 `UInt8` tile, the
# target type is 8-bit, or the byte count does not divide evenly into elements.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
    builder = ctx.cb
    typetable = ctx.tt

    src = @something emit_value!(ctx, args[1]) throw(IRError("unpack: cannot resolve source"))
    target_type = @something get_constant(ctx, args[2]) throw(IRError("unpack: requires compile-time target type"))
    if !(typetable.version >= v"13.3")
        throw(IRError("cuda_tile.unpack requires Tile IR bytecode v13.3+, got v$(typetable.version)"))
    end
    rank = length(src.shape)
    if rank != 1
        throw(IRError("unpack: requires a rank-1 tile, got a $(rank)-D tile"))
    end

    src_elty = eltype(CC.widenconst(src.jltype))
    if src_elty !== UInt8
        throw(IRError("unpack: requires a UInt8 tile, got $(src_elty)"))
    end
    tbits = lookup_bitwidth(target_type)
    if tbits == 8
        throw(IRError("unpack: 8-bit target $target_type should be reinterpreted via bitcast, not unpacked"))
    end
    nbytes = src.shape[1]
    total_bits = nbytes * 8
    if total_bits % tbits != 0
        throw(IRError("unpack: $nbytes bytes ($total_bits bits) do not evenly divide into $target_type ($tbits-bit) elements"))
    end
    nelem = total_bits ÷ tbits

    # Result: rank-1 tile of `target_type` covering the same bits as the source.
    out_shape = RowMajorShape([nelem])
    out_type_id = tile_type!(typetable, lookup_dtype!(typetable, target_type), out_shape)
    out_v = encode_UnpackOp!(builder, out_type_id, src.v)
    return CGVal(out_v, out_type_id, Tile{target_type, Tuple{nelem}}, out_shape)
end
165+
166+ # ── Whole-tile reinterpret (Julia semantics, any rank) ────────────────────────
167+ #
168+ # A tile's row-major byte stream equals the column-major stream of its Julia
169+ # shape (we store the reversed shape), so flatten → width-convert → reshape
170+ # reproduces `Base.reinterpret` element-for-element. Equal widths bitcast;
171+ # crossing 8 bits goes through `pack`/`unpack`; a non-byte → non-byte change
172+ # routes through bytes (pack then unpack). All the shape arithmetic is
173+ # compile-time constant, so the intermediate `reshape`s fold away (identity
174+ # reshapes are eliminated by the canonicalizer) — a rank-1 FP4 reinterpret lowers
175+ # to a single pack/unpack.
176+
# Convert a flattened tile to element type `T`, choosing the route from the two
# element bit widths:
#   * equal widths        → a single `bitcast`
#   * target is 8-bit     → `pack` to bytes, then `bitcast` the bytes to `T`
#   * source is 8-bit     → `bitcast` to `UInt8`, then `unpack` to `T`
#   * neither is 8-bit    → `pack` to bytes, then `unpack` to `T`
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
    src_bits = bitwidth(S)
    dst_bits = bitwidth(T)
    if src_bits == dst_bits
        return Intrinsics.bitcast(flat, T)
    end
    if dst_bits == 8
        return Intrinsics.bitcast(Intrinsics.pack(flat), T)
    end
    if src_bits == 8
        return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T)
    end
    return Intrinsics.unpack(Intrinsics.pack(flat), T)
end
191+
# Shape of `reinterpret(T, x)` for a source of element type `S` and size `sz`:
# the first dimension is multiplied by `bitwidth(S) / bitwidth(T)` and the
# remaining dimensions are kept. A 0-D input is accepted only when both widths
# match. Throws `ArgumentError` when the first dimension's bit count is not a
# multiple of `bitwidth(T)`.
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
    src_bits = bitwidth(S)
    dst_bits = bitwidth(T)
    if N == 0
        if src_bits != dst_bits
            throw(ArgumentError("reinterpret: a 0-D $S tile ($src_bits bits) cannot be reinterpreted as $T ($dst_bits bits)"))
        end
        return ()
    end
    lead_bits = sz[1] * src_bits
    if lead_bits % dst_bits != 0
        throw(ArgumentError("reinterpret: leading dimension $(sz[1]) of a $S tile ($lead_bits bits) is not divisible by $dst_bits-bit $T"))
    end
    return (lead_bits ÷ dst_bits, Base.tail(sz)...)
end
205+
# Shape of `reinterpret(reshape, T, x)` for a source of element type `S` and
# size `sz`:
#   * equal widths → `sz` unchanged;
#   * widening     → the first dimension must equal the width ratio and is dropped;
#   * narrowing    → a new first dimension equal to the width ratio is prepended.
# Throws `ArgumentError` when one width is not a whole multiple of the other, or
# when widening and the first dimension is missing or differs from the ratio.
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
    src_bits = bitwidth(S)
    dst_bits = bitwidth(T)
    if src_bits == dst_bits
        return sz
    elseif dst_bits > src_bits
        if dst_bits % src_bits != 0
            throw(ArgumentError("reinterpret(reshape, $T, x): $T ($dst_bits bits) is not a whole multiple of $S ($src_bits bits)"))
        end
        ratio = dst_bits ÷ src_bits
        if !(N >= 1 && sz[1] == ratio)
            throw(ArgumentError("reinterpret(reshape, $T, x): leading dimension must be $ratio, got $(N == 0 ? () : sz[1])"))
        end
        return Base.tail(sz)
    else
        if src_bits % dst_bits != 0
            throw(ArgumentError("reinterpret(reshape, $T, x): $S ($src_bits bits) is not a whole multiple of $T ($dst_bits bits)"))
        end
        return (src_bits ÷ dst_bits, sz...)
    end
end
225+
"""
    Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}

Reinterpret the bits of the *whole tile* `x` as elements of type `T`, following
`reinterpret(T, ::AbstractArray)`: the tile is treated as one contiguous
(column-major) block of bits and its leading dimension is rescaled by the ratio
of the two element widths. Equal widths lower to `cuda_tile.bitcast`; differing
widths lower to `cuda_tile.pack`/`unpack`, with a `reshape` to rank-1 around them.

The total number of bits is preserved, so the leading dimension must divide
evenly. This is what lets sub-byte formats travel through global memory: a
`Tile{UInt8,(N,)}` reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4
data can live in a `UInt8` array.

The dotted form `reinterpret.(T, x)` is a different operation: an *element-wise*
broadcast that keeps the shape and requires `T` to have the same width as
`eltype(x)`.

```julia
bytes = ct.load(a, pid, (8,))              # Tile{UInt8,(8,)}
fp4   = reinterpret(Float4_E2M1FN, bytes)  # Tile{Float4_E2M1FN,(16,)}
vals  = convert(ct.Tile{Float32}, fp4)     # widen for compute
```
"""
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
    dims = size(x)
    # Validate and compute the result shape first, then flatten, convert the
    # element width, and reshape to the result shape.
    out_dims = reinterpret_scaled_shape(T, eltype(x), dims)
    linear = Intrinsics.reshape(x, (prod(dims),))
    converted = reinterpret_width(T, linear)
    return Intrinsics.reshape(converted, out_dims)
end
253+
"""
    Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}

Whole-tile reinterpret in its `reshape` form, following
`reinterpret(reshape, T, ::AbstractArray)`. Rather than rescaling the leading
dimension, widening *drops* it (it must equal
`bitwidth(T) ÷ bitwidth(eltype(x))`) and narrowing *prepends* a new leading
dimension.
"""
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
    dims = size(x)
    # Validate and compute the result shape first, then flatten, convert the
    # element width, and reshape to the result shape.
    out_dims = reinterpret_reshape_shape(T, eltype(x), dims)
    linear = Intrinsics.reshape(x, (prod(dims),))
    converted = reinterpret_width(T, linear)
    return Intrinsics.reshape(converted, out_dims)
end
267+
56268"""
57269 Intrinsics.exti(x::Tile{<:Integer}, ::Type{T}, s::Signedness.T) -> Tile{T} where {T<:Integer}
58270
0 commit comments