Skip to content

Commit ba0efa9

Browse files
committed
Move Base.reinterpret methods to operations.jl
1 parent 5577add commit ba0efa9

2 files changed

Lines changed: 76 additions & 76 deletions

File tree

src/compiler/intrinsics/conversions.jl

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -152,82 +152,6 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
152152
CGVal(result_v, result_type_id, Tile{target_type, Tuple{new_n}}, new_shape)
153153
end
154154

155-
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
156-
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
157-
bs = bitwidth(S)
158-
bt = bitwidth(T)
159-
if bs == bt
160-
return Intrinsics.bitcast(flat, T) # same width
161-
elseif bt == 8
162-
return Intrinsics.bitcast(Intrinsics.pack(flat), T) # S → bytes → T8
163-
elseif bs == 8
164-
return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T) # S8 → bytes → T
165-
else
166-
return Intrinsics.unpack(Intrinsics.pack(flat), T) # S → bytes → T
167-
end
168-
end
169-
170-
# Result shape for `reinterpret(T, x)`: rescale the leading (column-major)
171-
# dimension by the element-width ratio, like `reinterpret(T, ::AbstractArray)`.
172-
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
173-
bs = bitwidth(S)
174-
bt = bitwidth(T)
175-
N == 0 && return () # 0-D: only equal-width is valid; cross-width caught at emit
176-
return (fld(sz[1] * bs, bt), Base.tail(sz)...)
177-
end
178-
179-
# Result shape for `reinterpret(reshape, T, x)`: drop the leading dim on widening
180-
# (it must equal the ratio), prepend one on narrowing, like the array version.
181-
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
182-
bs = bitwidth(S)
183-
bt = bitwidth(T)
184-
bs == bt && return sz
185-
N == 0 && return () # cross-width on a 0-D tile is invalid; caught at emit
186-
return bt > bs ? Base.tail(sz) : (div(bs, bt), sz...)
187-
end
188-
189-
"""
190-
Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}
191-
192-
Reinterpret the *whole tile* `x` as a tile of element type `T`, like
193-
`reinterpret(T, ::AbstractArray)`: the underlying bits are viewed as a contiguous
194-
(column-major) block and the leading dimension is rescaled by the ratio of
195-
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
196-
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.
197-
198-
This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
199-
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
200-
in a `UInt8` array. The total bit-width is preserved, so it must divide evenly.
201-
202-
Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
203-
which keeps the shape and requires `T` to be the same width as `eltype(x)`.
204-
205-
```julia
206-
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
207-
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
208-
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
209-
```
210-
"""
211-
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
212-
rshape = reinterpret_scaled_shape(T, eltype(x), size(x))
213-
flat = Intrinsics.reshape(x, (prod(size(x)),))
214-
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
215-
end
216-
217-
"""
218-
Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}
219-
220-
The `reshape`-form whole-tile reinterpret, mirroring
221-
`reinterpret(reshape, T, ::AbstractArray)`: instead of rescaling the leading
222-
dimension it *removes* it when widening (the leading dim must equal
223-
`bitwidth(T) ÷ bitwidth(eltype(x))`) and *prepends* one when narrowing.
224-
"""
225-
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
226-
rshape = reinterpret_reshape_shape(T, eltype(x), size(x))
227-
flat = Intrinsics.reshape(x, (prod(size(x)),))
228-
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
229-
end
230-
231155
"""
232156
Intrinsics.exti(x::Tile{<:Integer}, ::Type{T}, s::Signedness.T) -> Tile{T} where {T<:Integer}
233157

src/language/operations.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,82 @@ Equivalent to single-arg `permutedims`.
857857
end
858858
end
859859

860+
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
861+
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
862+
bs = bitwidth(S)
863+
bt = bitwidth(T)
864+
if bs == bt
865+
return Intrinsics.bitcast(flat, T) # same width
866+
elseif bt == 8
867+
return Intrinsics.bitcast(Intrinsics.pack(flat), T) # S → bytes → T8
868+
elseif bs == 8
869+
return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T) # S8 → bytes → T
870+
else
871+
return Intrinsics.unpack(Intrinsics.pack(flat), T) # S → bytes → T
872+
end
873+
end
874+
875+
# Result shape for `reinterpret(T, x)`: rescale the leading (column-major)
876+
# dimension by the element-width ratio, like `reinterpret(T, ::AbstractArray)`.
877+
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
878+
bs = bitwidth(S)
879+
bt = bitwidth(T)
880+
N == 0 && return () # 0-D: only equal-width is valid; cross-width caught at emit
881+
return (fld(sz[1] * bs, bt), Base.tail(sz)...)
882+
end
883+
884+
# Result shape for `reinterpret(reshape, T, x)`: drop the leading dim on widening
885+
# (it must equal the ratio), prepend one on narrowing, like the array version.
886+
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
887+
bs = bitwidth(S)
888+
bt = bitwidth(T)
889+
bs == bt && return sz
890+
N == 0 && return () # cross-width on a 0-D tile is invalid; caught at emit
891+
return bt > bs ? Base.tail(sz) : (div(bs, bt), sz...)
892+
end
893+
894+
"""
895+
Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}
896+
897+
Reinterpret the *whole tile* `x` as a tile of element type `T`, like
898+
`reinterpret(T, ::AbstractArray)`: the underlying bits are viewed as a contiguous
899+
(column-major) block and the leading dimension is rescaled by the ratio of
900+
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
901+
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.
902+
903+
This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
904+
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
905+
in a `UInt8` array. The total bit-width is preserved, so it must divide evenly.
906+
907+
Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
908+
which keeps the shape and requires `T` to be the same width as `eltype(x)`.
909+
910+
```julia
911+
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
912+
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
913+
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
914+
```
915+
"""
916+
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
917+
rshape = reinterpret_scaled_shape(T, eltype(x), size(x))
918+
flat = Intrinsics.reshape(x, (prod(size(x)),))
919+
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
920+
end
921+
922+
"""
923+
Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}
924+
925+
The `reshape`-form whole-tile reinterpret, mirroring
926+
`reinterpret(reshape, T, ::AbstractArray)`: instead of rescaling the leading
927+
dimension it *removes* it when widening (the leading dim must equal
928+
`bitwidth(T) ÷ bitwidth(eltype(x))`) and *prepends* one when narrowing.
929+
"""
930+
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
931+
rshape = reinterpret_reshape_shape(T, eltype(x), size(x))
932+
flat = Intrinsics.reshape(x, (prod(size(x)),))
933+
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
934+
end
935+
860936
@inline Base.convert(::Type{Tile{T}}, tile::Tile{T}) where {T} = tile
861937
@inline Base.convert(::Type{Tile{T2}}, tile::Tile{T1, Shape}) where {T1, T2, Shape} =
862938
map(T2, tile)

0 commit comments

Comments
 (0)