Skip to content

Commit 2b4d57d

Browse files
committed
Use Tuple{...} for Shape type parameter of Tile in comments and docs introduced in #238
1 parent fc5ede9 commit 2b4d57d

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

src/language/operations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -975,16 +975,16 @@ Reinterpret the *whole tile* `x` as a tile of element type `T`, like
975975
element widths. Lowers to `cuda_tile.bitcast` for equal widths and to
976976
`cuda_tile.pack`/`unpack` (via `reshape` to rank-1) when widths differ.
977977
978-
This is how sub-byte formats move through global memory: a `Tile{UInt8,(N,)}`
979-
reinterprets to a `Tile{Float4_E2M1FN,(2N,)}` and back, so FP4 data can be stored
978+
This is how sub-byte formats move through global memory: a `Tile{UInt8,Tuple{N}}`
979+
reinterprets to a `Tile{Float4_E2M1FN,Tuple{2N}}` and back, so FP4 data can be stored
980980
in a `UInt8` array. The total bit-width is preserved, so the tile's total bit count must be evenly divisible by the width of `T`.
981981
982982
Note `reinterpret.(T, x)` (with a dot) is the unrelated *element-wise* broadcast,
983983
which keeps the shape and requires `T` to be the same width as `eltype(x)`.
984984
985985
```julia
986-
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
987-
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
986+
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,Tuple{8}}
987+
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,Tuple{16}}
988988
vals = convert(ct.Tile{Float32}, fp4) # widen for compute
989989
```
990990
"""

test/extensions/Microfloats.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ let kernel = (a, b) -> begin
7272
end
7373

7474
# Whole-tile `reinterpret` between UInt8 and Float4_E2M1FN packs/unpacks two FP4
75-
# per byte: a `Tile{UInt8,(8,)}` unpacks to a `Tile{Float4_E2M1FN,(16,)}`,
75+
# per byte: a `Tile{UInt8,Tuple{8}}` unpacks to a `Tile{Float4_E2M1FN,Tuple{16}}`,
7676
# lowering to `cuda_tile.unpack` (bytecode version 13.3+).
7777
@test @filecheck begin
7878
@check_label "entry"
7979
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{Float32,1,spec1d}};
8080
bytecode_version=v"13.3") do a, b
8181
pid = ct.bid(1)
82-
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)}
82+
bytes = ct.load(a, pid, (8,)) # Tile{UInt8,Tuple{8}}
8383
@check "unpack"
84-
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)}
84+
fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,Tuple{16}}
8585
ct.store(b, pid, convert(ct.Tile{Float32}, fp4))
8686
return
8787
end
@@ -94,9 +94,9 @@ end
9494
bytecode_version=v"13.3") do a, b
9595
pid = ct.bid(1)
9696
vals = ct.load(a, pid, (16,))
97-
fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) # Tile{Float4_E2M1FN,(16,)}
97+
fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) # Tile{Float4_E2M1FN,Tuple{16}}
9898
@check "pack"
99-
ct.store(b, pid, reinterpret(UInt8, fp4)) # Tile{UInt8,(8,)}
99+
ct.store(b, pid, reinterpret(UInt8, fp4)) # Tile{UInt8,Tuple{8}}
100100
return
101101
end
102102
end

0 commit comments

Comments
 (0)