Skip to content

Commit 973c8d8

Browse files
committed
Move Base.reinterpret methods to operations.jl
1 parent 7d0f0f5 commit 973c8d8

2 files changed

Lines changed: 76 additions & 76 deletions

File tree

src/compiler/intrinsics/conversions.jl

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -152,82 +152,6 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args)
152152
CGVal(result_v, result_type_id, Tile{target_type, Tuple{new_n}}, new_shape)
153153
end
154154

155-
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
156-
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
157-
bs = bitwidth(S)
158-
bt = bitwidth(T)
159-
if bs == bt
160-
return Intrinsics.bitcast(flat, T) # same width
161-
elseif bt == 8
162-
return Intrinsics.bitcast(Intrinsics.pack(flat), T) # S → bytes → T8
163-
elseif bs == 8
164-
return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T) # S8 → bytes → T
165-
else
166-
return Intrinsics.unpack(Intrinsics.pack(flat), T) # S → bytes → T
167-
end
168-
end
169-
170-
# Result shape for `reinterpret(T, x)`: rescale the leading (column-major)
171-
# dimension by the element-width ratio, like `reinterpret(T, ::AbstractArray)`.
172-
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
173-
bs = bitwidth(S)
174-
bt = bitwidth(T)
175-
N == 0 && return () # 0-D: only equal-width is valid; cross-width caught at emit
176-
return (fld(sz[1] * bs, bt), Base.tail(sz)...)
177-
end
178-
179-
# Result shape for `reinterpret(reshape, T, x)`: drop the leading dim on widening
180-
# (it must equal the ratio), prepend one on narrowing, like the array version.
181-
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
182-
bs = bitwidth(S)
183-
bt = bitwidth(T)
184-
bs == bt && return sz
185-
N == 0 && return () # cross-width on a 0-D tile is invalid; caught at emit
186-
return bt > bs ? Base.tail(sz) : (div(bs, bt), sz...)
187-
end
188-
189-
"""
190-
Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}
191-
192-
Reinterpret the *whole tile* `x` as a tile of element type `T`, like
193-
`reinterpret(T, ::AbstractArray)`: the underlying bits are viewed as a contiguous
194-
(column-major) block and the leading dimension is rescaled by the ratio of
195-
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
196-
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.
197-
198-
This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
199-
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
200-
in a `UInt8` array. The total bit-width is preserved, so it must divide evenly.
201-
202-
Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
203-
which keeps the shape and requires `T` to be the same width as `eltype(x)`.
204-
205-
```julia
206-
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
207-
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
208-
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
209-
```
210-
"""
211-
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
212-
rshape = reinterpret_scaled_shape(T, eltype(x), size(x))
213-
flat = Intrinsics.reshape(x, (prod(size(x)),))
214-
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
215-
end
216-
217-
"""
218-
Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}
219-
220-
The `reshape`-form whole-tile reinterpret, mirroring
221-
`reinterpret(reshape, T, ::AbstractArray)`: instead of rescaling the leading
222-
dimension it *removes* it when widening (the leading dim must equal
223-
`bitwidth(T) ÷ bitwidth(eltype(x))`) and *prepends* one when narrowing.
224-
"""
225-
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
226-
rshape = reinterpret_reshape_shape(T, eltype(x), size(x))
227-
flat = Intrinsics.reshape(x, (prod(size(x)),))
228-
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
229-
end
230-
231155
"""
232156
Intrinsics.exti(x::Tile{<:Integer}, ::Type{T}, s::Signedness.T) -> Tile{T} where {T<:Integer}
233157

src/language/operations.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,82 @@ Equivalent to single-arg `permutedims`.
932932
end
933933
end
934934

935+
# Width-convert a rank-1 tile to element type `T` (rank-1 in, rank-1 out).
936+
@inline function reinterpret_width(::Type{T}, flat::Tile{S}) where {T, S}
937+
bs = bitwidth(S)
938+
bt = bitwidth(T)
939+
if bs == bt
940+
return Intrinsics.bitcast(flat, T) # same width
941+
elseif bt == 8
942+
return Intrinsics.bitcast(Intrinsics.pack(flat), T) # S → bytes → T8
943+
elseif bs == 8
944+
return Intrinsics.unpack(Intrinsics.bitcast(flat, UInt8), T) # S8 → bytes → T
945+
else
946+
return Intrinsics.unpack(Intrinsics.pack(flat), T) # S → bytes → T
947+
end
948+
end
949+
950+
# Result shape for `reinterpret(T, x)`: rescale the leading (column-major)
951+
# dimension by the element-width ratio, like `reinterpret(T, ::AbstractArray)`.
952+
@inline function reinterpret_scaled_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
953+
bs = bitwidth(S)
954+
bt = bitwidth(T)
955+
N == 0 && return () # 0-D: only equal-width is valid; cross-width caught at emit
956+
return (fld(sz[1] * bs, bt), Base.tail(sz)...)
957+
end
958+
959+
# Result shape for `reinterpret(reshape, T, x)`: drop the leading dim on widening
960+
# (it must equal the ratio), prepend one on narrowing, like the array version.
961+
@inline function reinterpret_reshape_shape(::Type{T}, ::Type{S}, sz::NTuple{N, Int}) where {T, S, N}
962+
bs = bitwidth(S)
963+
bt = bitwidth(T)
964+
bs == bt && return sz
965+
N == 0 && return () # cross-width on a 0-D tile is invalid; caught at emit
966+
return bt > bs ? Base.tail(sz) : (div(bs, bt), sz...)
967+
end
968+
969+
"""
970+
Base.reinterpret(::Type{T}, x::Tile) -> Tile{T}
971+
972+
Reinterpret the *whole tile* `x` as a tile of element type `T`, like
973+
`reinterpret(T, ::AbstractArray)`: the underlying bits are viewed as a contiguous
974+
(column-major) block and the leading dimension is rescaled by the ratio of
975+
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
976+
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.
977+
978+
This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
979+
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
980+
in a `UInt8` array. The total bit-width is preserved, so it must divide evenly.
981+
982+
Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
983+
which keeps the shape and requires `T` to be the same width as `eltype(x)`.
984+
985+
```julia
986+
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
987+
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
988+
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
989+
```
990+
"""
991+
@inline function Base.reinterpret(::Type{T}, x::Tile) where {T}
992+
rshape = reinterpret_scaled_shape(T, eltype(x), size(x))
993+
flat = Intrinsics.reshape(x, (prod(size(x)),))
994+
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
995+
end
996+
997+
"""
998+
Base.reinterpret(reshape, ::Type{T}, x::Tile) -> Tile{T}
999+
1000+
The `reshape`-form whole-tile reinterpret, mirroring
1001+
`reinterpret(reshape, T, ::AbstractArray)`: instead of rescaling the leading
1002+
dimension it *removes* it when widening (the leading dim must equal
1003+
`bitwidth(T) ÷ bitwidth(eltype(x))`) and *prepends* one when narrowing.
1004+
"""
1005+
@inline function Base.reinterpret(::typeof(reshape), ::Type{T}, x::Tile) where {T}
1006+
rshape = reinterpret_reshape_shape(T, eltype(x), size(x))
1007+
flat = Intrinsics.reshape(x, (prod(size(x)),))
1008+
return Intrinsics.reshape(reinterpret_width(T, flat), rshape)
1009+
end
1010+
9351011
@inline Base.convert(::Type{Tile{T}}, tile::Tile{T}) where {T} = tile
9361012
@inline Base.convert(::Type{Tile{T2}}, tile::Tile{T1, Shape}) where {T1, T2, Shape} =
9371013
map(T2, tile)

0 commit comments

Comments
 (0)