Commit df5a821

README update and ct.where removal (#180)
1 parent 33a314f commit df5a821

5 files changed

Lines changed: 124 additions & 63 deletions

File tree

README.md

Lines changed: 102 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
A Julia package for writing GPU kernels using NVIDIA's tile-based programming model.
44

5-
**This package is under active development.** Not all Tile IR features are implemented, and
6-
support for the Julia language is limited and only verified on the examples provided here.
7-
Interfaces and APIs may change without notice.
5+
**This package is in beta.** Most Tile IR features are implemented and the package has been
6+
verified on the benchmarks and tests included in the repository. Interfaces and APIs may
7+
still change without notice.
88

99

1010
## Installation
@@ -19,7 +19,8 @@ julia> Pkg.add("cuTile")
1919
Execution of cuTile kernels requires CUDA.jl to be installed and imported.
2020
cuTile generates kernels based on [Tile IR](https://docs.nvidia.com/cuda/tile-ir/), which requires an NVIDIA Driver that supports CUDA 13 (580 or later).
2121
CUDA.jl automatically downloads the appropriate CUDA toolkit artifacts, so no manual CUDA installation is needed.
22-
Only Ampere, Ada, and Blackwell GPUs are supported at this time, with Hopper support coming in a later release of CUDA 13.
22+
Only Ampere, Ada, and Blackwell GPUs are supported at this time, with Hopper support expected
23+
in a future release of CUDA.
2324

2425
## Quick Start
2526

@@ -124,6 +125,15 @@ cuTile.jl aims to expose as much functionality as possible through Julia-native
124125
prefixed with `ct.` are cuTile intrinsics with no direct Julia equivalent; everything else
125126
uses standard Julia syntax and is overlaid on `Base`.
126127

128+
### Supported Types
129+
130+
**Integers:** `Int8`, `UInt8`, `Int16`, `UInt16`, `Int32`, `UInt32`, `Int64`, `UInt64`
131+
**Floats:** `Float16`, `BFloat16`, `Float32`, `Float64`, `TFloat32`
132+
**Boolean:** `Bool`
133+
134+
`TFloat32` is a 32-bit floating-point type with reduced mantissa precision (10 bits),
135+
optimized for tensor core operations.
136+
127137
### Memory
128138
| Operation | Description |
129139
|-----------|-------------|
@@ -151,6 +161,18 @@ ct.scatter(arr, indices, tile; mask=active_mask)
151161
| `ct.num_blocks(axis)` | Grid size along axis |
152162
| `ct.num_tiles(arr, axis, shape)` | Number of tiles along axis |
153163

164+
### Control Flow
165+
| Construct | Description |
166+
|-----------|-------------|
167+
| `if`/`elseif`/`else` | Conditional branching |
168+
| `for i in start:stop` | Counted loops (compiled to Tile IR ForOp) |
169+
| `for i in start:step:stop` | Stepped loops |
170+
| `while cond ... end` | While loops |
171+
| `ifelse.(cond, x, y)` | Element-wise conditional selection |
172+
173+
Standard Julia control flow works inside kernels and is compiled to structured
174+
Tile IR operations.
175+
154176
### Arithmetic
155177
| Operation | Description |
156178
|-----------|-------------|
@@ -189,6 +211,7 @@ ct.scatter(arr, indices, tile; mask=active_mask)
189211
| `map(f, tiles...)` | Apply function element-wise (same shape) |
190212
| `f.(tiles...)`, `broadcast(f, tiles...)` | Apply function with shape broadcasting |
191213
| `reduce(f, tile; dims, init)` | Reduction with arbitrary function |
214+
| `mapreduce(f, op, tile; dims, init)` | Map then reduce |
192215
| `accumulate(f, tile; dims, init, rev)` | Scan/prefix-sum with arbitrary function |
193216

194217
### Reductions
@@ -215,10 +238,13 @@ ct.scatter(arr, indices, tile; mask=active_mask)
215238
| `exp2.(tile)` | Base-2 exponential |
216239
| `log.(tile)` | Natural logarithm |
217240
| `log2.(tile)` | Base-2 logarithm |
218-
| `sin.(tile)`, `cos.(tile)`, etc. | Trigonometric functions |
241+
| `sin.(tile)`, `cos.(tile)`, `tan.(tile)` | Trigonometric functions |
242+
| `sinh.(tile)`, `cosh.(tile)`, `tanh.(tile)` | Hyperbolic functions |
219243
| `fma.(a, b, c)` | Fused multiply-add |
220244
| `abs.(tile)` | Absolute value |
245+
| `isnan.(tile)` | NaN test |
221246
| `max(a, b)`, `min(a, b)` | Maximum/minimum (scalars) |
247+
| `ceil.(tile)`, `floor.(tile)` | Rounding |
222248
| `ct.@fpmode rounding_mode=ct.Rounding.Approx flush_to_zero=true begin ... end` | Scoped FP rounding mode and flush-to-zero |
223249

224250
### Comparison
@@ -242,6 +268,14 @@ ct.scatter(arr, indices, tile; mask=active_mask)
242268
| `div(a, b)` | Truncating division |
243269
| `mul_hi.(tile_a, tile_b)`, `mul_hi(x, y)` | High bits of integer multiply (use `Base.mul_hi` on Julia 1.13+) |
244270

271+
### Indexing
272+
| Operation | Description |
273+
|-----------|-------------|
274+
| `arr[i, j, ...]` | Load scalar element from `TileArray` |
275+
| `arr[i, j, ...] = val` | Store scalar element to `TileArray` |
276+
| `tile[i, j, ...]` | Extract scalar from `Tile` |
277+
| `setindex(tile, val, i, j, ...)` | Return new `Tile` with element replaced |
278+
245279
### Atomics
246280
| Operation | Description |
247281
|-----------|-------------|
@@ -257,11 +291,52 @@ ct.scatter(arr, indices, tile; mask=active_mask)
257291
All atomics accept `memory_order` (default: `ct.MemoryOrder.AcqRel`) and
258292
`memory_scope` (default: `ct.MemScope.Device`) keyword arguments.
259293

294+
### Performance Tuning
295+
296+
#### Kernel configuration
297+
298+
`ct.@compiler_options` sets optimization hints inside a kernel function body:
299+
300+
```julia
301+
function matmul(A, B, C, ...)
302+
ct.@compiler_options num_ctas=ct.ByTarget(v"10.0" => 2) occupancy=8
303+
...
304+
end
305+
```
306+
307+
| Option | Description | Valid values |
308+
|--------|-------------|--------------|
309+
| `num_ctas` | Number of CTAs in a CGA | Powers of 2 |
310+
| `occupancy` | Target concurrent CTAs per SM | 1–32 |
311+
| `opt_level` | Optimization level | 0–3 |
312+
313+
Values can be plain scalars or `ct.ByTarget(...)` for per-architecture dispatch.
314+
`ByTarget` maps compute capabilities to values, with an optional default:
315+
316+
```julia
317+
ct.@compiler_options num_ctas=ct.ByTarget(v"10.0" => 4, v"12.0" => 2; default=1)
318+
```
319+
320+
Hints can also be passed as keyword arguments to `ct.launch` or `ct.code_tiled`,
321+
which take precedence over `@compiler_options`.
322+
323+
#### Load/store hints
324+
325+
`ct.load` and `ct.store` accept optional keyword arguments that influence memory
326+
traffic scheduling:
327+
328+
| Hint | Description |
329+
|------|-------------|
330+
| `latency` | DRAM traffic weight hint, integer 1 (low) to 10 (high). Default: compiler-inferred. |
331+
| `allow_tma` | Whether to allow Tensor Memory Accelerator lowering. Default: allowed. |
332+
260333
### Debugging
334+
261335
| Operation | Description |
262336
|-----------|-------------|
263337
| `print(args...)` | Print values (Base overlay) |
264338
| `println(args...)` | Print values with newline (Base overlay) |
339+
| `ct.@assert cond [msg]` | Abort kernel if condition is false |
265340

266341
Standard Julia `print`/`println` work inside kernels. String constants and tiles
267342
can be mixed freely; format specifiers are inferred from element types at compile
@@ -270,9 +345,27 @@ time. String interpolation is supported.
270345
```julia
271346
println("Block ", ct.bid(1), ": tile=", tile)
272347
println("result=$result") # string interpolation
348+
ct.@assert idx <= n "index out of bounds"
273349
```
274350

275-
This is a debugging aid and is not optimized for performance.
351+
These are debugging aids and are not optimized for performance.
352+
353+
### Code Inspection
354+
355+
Beyond `ct.code_tiled` and `ct.@code_tiled` shown above, cuTile.jl provides
356+
`@device_code_*` macros that intercept compilation during `ct.launch`:
357+
358+
```julia
359+
ct.@device_code_tiled ct.launch(vadd, grid, a, b, c, ct.Constant(16))
360+
ct.@device_code_typed ct.launch(vadd, grid, a, b, c, ct.Constant(16))
361+
ct.@device_code_structured ct.launch(vadd, grid, a, b, c, ct.Constant(16))
362+
```
363+
364+
| Macro | Output |
365+
|-------|--------|
366+
| `ct.@device_code_tiled` | Final Tile IR (MLIR textual format) |
367+
| `ct.@device_code_typed` | Typed Julia IR after overlay resolution |
368+
| `ct.@device_code_structured` | Structured IR (after control-flow structurization) |
276369

277370

278371
## Differences from cuTile Python
@@ -312,7 +405,8 @@ end
312405
### Optimization hints
313406

314407
Python passes optimization hints as `@ct.kernel` decorator arguments. Julia uses
315-
`ct.@compiler_options` inside the function body (like `@inline`):
408+
`ct.@compiler_options` inside the function body (like `@inline`). See
409+
[Performance Tuning](#performance-tuning) for full details.
316410

317411
```python
318412
# Python
@@ -329,24 +423,6 @@ function matmul(A, B, C, ...)
329423
end
330424
```
331425

332-
Supported options:
333-
334-
| Option | Description | Valid values |
335-
|--------|-------------|--------------|
336-
| `num_ctas` | Number of CTAs in a CGA (cooperative group array) | Powers of 2 |
337-
| `occupancy` | Target occupancy (number of concurrent CTAs per SM) | 1–32 |
338-
| `opt_level` | Optimization level | 0–3 |
339-
340-
Values can be plain scalars or `ct.ByTarget(...)` for per-architecture dispatch.
341-
`ByTarget` maps compute capabilities to values, with an optional default:
342-
343-
```julia
344-
ct.@compiler_options num_ctas=ct.ByTarget(v"10.0" => 4, v"12.0" => 2; default=1)
345-
```
346-
347-
Hints can also be passed as keyword arguments to `ct.launch` or `ct.code_tiled`,
348-
which take precedence over `@compiler_options`.
349-
350426
### Launch Syntax
351427

352428
cuTile.jl implicitly uses the current task-bound stream from CUDA.jl:
@@ -515,7 +591,7 @@ standard Julia silently produce truncated or wrapped results instead:
515591
throwing `InexactError` for non-integer or out-of-range values. Use
516592
`unsafe_trunc` for the explicit non-throwing primitive.
517593

518-
Assertions may be added in the future for testing purposes.
594+
Use `ct.@assert` to add runtime checks in kernels (see Debugging above).
519595

520596

521597
## Host-level operations

src/language/operations.jl

Lines changed: 2 additions & 17 deletions
@@ -169,7 +169,7 @@ end
169169
idx = Intrinsics.iota((flat_len,), Int32)
170170
mask = idx .== Int32(linear)
171171
val_tile = broadcast_to(Tile(T(val)), (flat_len,))
172-
new_flat = where(mask, val_tile, flat)
172+
new_flat = ifelse.(mask, val_tile, flat)
173173
reshape(new_flat, S)
174174
end
175175

@@ -1038,22 +1038,7 @@ end
10381038
Selection
10391039
=============================================================================#
10401040

1041-
public where, extract
1042-
1043-
"""
1044-
where(cond::Tile{Bool}, x, y) -> Tile
1045-
1046-
Element-wise conditional selection: returns x where cond is true, y otherwise.
1047-
Similar to numpy.where() or torch.where(). Supports broadcasting and scalar arguments.
1048-
1049-
# Example
1050-
```julia
1051-
mask = tile_a .> tile_b # Boolean tile
1052-
result = ct.where(mask, tile_a, tile_b) # Element-wise max
1053-
result = ct.where(mask, tile_a, 0.0f0) # Zero out where mask is false
1054-
```
1055-
"""
1056-
where(cond, x, y) = ifelse.(cond, x, y)
1041+
public extract
10571042

10581043
"""
10591044
extract(tile::Tile{T, S}, index::NTuple{N, Int}, shape::NTuple{N, Int}) -> Tile{T, shape}

test/codegen/operations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ spec4d = ct.ArraySpec{4}(16, true)
366366
tile_b = ct.load(b, pid, (16,))
367367
@check "cmpf"
368368
mask = tile_a .< tile_b
369-
result = ct.where(mask, tile_a, tile_b)
369+
result = ifelse.(mask, tile_a, tile_b)
370370
ct.store(c, pid, result)
371371
return
372372
end
@@ -382,7 +382,7 @@ spec4d = ct.ArraySpec{4}(16, true)
382382
tile_b = ct.load(b, pid, (16,))
383383
@check "cmpi"
384384
mask = tile_a .< tile_b
385-
result = ct.where(mask, tile_a, tile_b)
385+
result = ifelse.(mask, tile_a, tile_b)
386386
ct.store(c, pid, result)
387387
return
388388
end
@@ -402,7 +402,7 @@ spec4d = ct.ArraySpec{4}(16, true)
402402
result = a .< b
403403
# Use same-typed operands for where to avoid Union type
404404
b_promoted = convert(ct.Tile{Int64}, b)
405-
selected = ct.where(result, a, b_promoted)
405+
selected = ifelse.(result, a, b_promoted)
406406
ct.store(out, Int32(0), selected)
407407
return
408408
end
@@ -820,7 +820,7 @@ spec4d = ct.ArraySpec{4}(16, true)
820820
tile_b = ct.load(b, pid, (16,))
821821
mask = tile_a .> tile_b
822822
@check "select"
823-
result = ct.where(mask, tile_a, tile_b)
823+
result = ifelse.(mask, tile_a, tile_b)
824824
ct.store(c, pid, result)
825825
return
826826
end

test/device/broadcast.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,8 @@ for (name, op1, op2) in [
227227
pid = ct.bid(1)
228228
ta = ct.load(a, pid, (16,))
229229
tb = ct.load(b, pid, (16,))
230-
ct.store(out1, pid, ct.where(broadcast($op1, ta, tb), 1.0f0, 0.0f0))
231-
ct.store(out2, pid, ct.where(broadcast($op2, ta, tb), 1.0f0, 0.0f0))
230+
ct.store(out1, pid, ifelse.(broadcast($op1, ta, tb), 1.0f0, 0.0f0))
231+
ct.store(out2, pid, ifelse.(broadcast($op2, ta, tb), 1.0f0, 0.0f0))
232232
return
233233
end
234234
n = 1024
@@ -248,8 +248,8 @@ end
248248
pid = ct.bid(1)
249249
ta = ct.load(a, pid, (16,))
250250
tb = ct.load(b, pid, (16,))
251-
ct.store(out_eq, pid, ct.where(ta .== tb, 1.0f0, 0.0f0))
252-
ct.store(out_ne, pid, ct.where(ta .!= tb, 1.0f0, 0.0f0))
251+
ct.store(out_eq, pid, ifelse.(ta .== tb, 1.0f0, 0.0f0))
252+
ct.store(out_ne, pid, ifelse.(ta .!= tb, 1.0f0, 0.0f0))
253253
return
254254
end
255255

@@ -281,7 +281,7 @@ end
281281
pid = ct.bid(1)
282282
ta = ct.load(a, pid, (16,))
283283
tb = ct.load(b, pid, (16,))
284-
ct.store(out, pid, ct.where(broadcast($op, ta, tb), 1.0f0, 0.0f0))
284+
ct.store(out, pid, ifelse.(broadcast($op, ta, tb), 1.0f0, 0.0f0))
285285
return
286286
end
287287
n = 1024
@@ -302,7 +302,7 @@ end
302302
out::ct.TileArray{Float32,1})
303303
pid = ct.bid(1)
304304
ta = ct.load(a, pid, (16,))
305-
ct.store(out, pid, ct.where(ta .> 0.5f0, 1.0f0, 0.0f0))
305+
ct.store(out, pid, ifelse.(ta .> 0.5f0, 1.0f0, 0.0f0))
306306
return
307307
end
308308

@@ -358,16 +358,16 @@ end
358358

359359
end
360360

361-
@testset "where / ifelse broadcasting" begin
361+
@testset "ifelse broadcasting" begin
362362

363-
@testset "where same-shape" begin
363+
@testset "ifelse same-shape" begin
364364
function where_same_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
365365
c::ct.TileArray{Float32,1})
366366
pid = ct.bid(1)
367367
ta = ct.load(a, pid, (16,))
368368
tb = ct.load(b, pid, (16,))
369369
mask = ta .> tb
370-
result = ct.where(mask, ta, tb)
370+
result = ifelse.(mask, ta, tb)
371371
ct.store(c, pid, result)
372372
return
373373
end
@@ -382,12 +382,12 @@ end
382382
@test Array(c) ≈ ifelse.(Array(a) .> Array(b), Array(a), Array(b)) rtol=1e-5
383383
end
384384

385-
@testset "where with scalar y" begin
385+
@testset "ifelse with scalar y" begin
386386
function where_scalar_y_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1})
387387
pid = ct.bid(1)
388388
ta = ct.load(a, pid, (16,))
389389
mask = ta .> 0.5f0
390-
result = ct.where(mask, ta, 0.0f0)
390+
result = ifelse.(mask, ta, 0.0f0)
391391
ct.store(b, pid, result)
392392
return
393393
end
@@ -401,12 +401,12 @@ end
401401
@test Array(b) ≈ ifelse.(Array(a) .> 0.5f0, Array(a), 0.0f0) rtol=1e-5
402402
end
403403

404-
@testset "where with scalar x" begin
404+
@testset "ifelse with scalar x" begin
405405
function where_scalar_x_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1})
406406
pid = ct.bid(1)
407407
ta = ct.load(a, pid, (16,))
408408
mask = ta .> 0.5f0
409-
result = ct.where(mask, 1.0f0, ta)
409+
result = ifelse.(mask, 1.0f0, ta)
410410
ct.store(b, pid, result)
411411
return
412412
end
@@ -420,11 +420,11 @@ end
420420
@test Array(b) ≈ ifelse.(Array(a) .> 0.5f0, 1.0f0, Array(a)) rtol=1e-5
421421
end
422422

423-
@testset "where with broadcasting" begin
423+
@testset "ifelse with broadcasting" begin
424424
function where_broadcast_kernel(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,2})
425425
mask = ct.load(a, (1, 1), (1, 128)) # (1, 128) mask
426426
tile = ct.load(a, (1, 1), (64, 128)) # (64, 128) tile
427-
result = ct.where(mask .> 0.5f0, tile, 0.0f0)
427+
result = ifelse.(mask .> 0.5f0, tile, 0.0f0)
428428
ct.store(b, (1, 1), result)
429429
return
430430
end
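
Migration note: the removed definition was the one-liner `where(cond, x, y) = ifelse.(cond, x, y)`, so every `ct.where(...)` call site in this commit maps directly onto a broadcast `ifelse.`. The sketch below illustrates the same selection semantics on ordinary CPU arrays in plain Julia. It does not use cuTile or a GPU, and the `where_compat` helper is a hypothetical name introduced only for illustration; it is not part of the commit or the package.

```julia
# Plain-Julia sketch of the selection semantics; runs on the CPU with ordinary Arrays.
a = Float32[0.2, 0.9, 0.5, 0.7]
b = Float32[0.4, 0.1, 0.5, 0.8]

mask = a .> b                            # element-wise comparison yields a Bool mask

# Formerly written as ct.where(mask, a, b); now the broadcast form of Base.ifelse.
elementwise_max = ifelse.(mask, a, b)    # a[i] where mask[i] is true, otherwise b[i]

# Scalars broadcast against arrays, as in the "ifelse with scalar y" test.
thresholded = ifelse.(a .> 0.5f0, a, 0.0f0)

# Shape broadcasting, as in the "ifelse with broadcasting" test: a 1x3 mask
# is stretched across the rows of a 2x3 array.
row  = Float32[0.1 0.6 0.9]
tile = Float32[1 2 3; 4 5 6]
masked_tile = ifelse.(row .> 0.5f0, tile, 0.0f0)

# Hypothetical compatibility shim (an assumption, not in the commit) for
# downstream code that still calls a where-style function.
where_compat(cond, x, y) = ifelse.(cond, x, y)
```

Because `ifelse` is a plain function, both value arguments are evaluated before the selection happens, so it chooses between values that have already been computed rather than skipping work.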
