22
33A Julia package for writing GPU kernels using NVIDIA's tile-based programming model.
44
5- ** This package is under active development .** Not all Tile IR features are implemented, and
6- support for the Julia language is limited and only verified on the examples provided here.
7- Interfaces and APIs may change without notice.
5+ ** This package is in beta .** Most Tile IR features are implemented and the package has been
6+ verified on the benchmarks and tests included in the repository. Interfaces and APIs may
7+ still change without notice.
88
99
1010## Installation
@@ -19,7 +19,8 @@ julia> Pkg.add("cuTile")
1919Execution of cuTile kernels requires CUDA.jl to be installed and imported.
2020cuTile generates kernels based on [ Tile IR] ( https://docs.nvidia.com/cuda/tile-ir/ ) , which requires an NVIDIA Driver that supports CUDA 13 (580 or later).
2121CUDA.jl automatically downloads the appropriate CUDA toolkit artifacts, so no manual CUDA installation is needed.
22- Only Ampere, Ada, and Blackwell GPUs are supported at this time, with Hopper support coming in a later release of CUDA 13.
22+ Only Ampere, Ada, and Blackwell GPUs are supported at this time, with Hopper support expected
23+ in a future release of CUDA.
2324
2425## Quick Start
2526
@@ -124,6 +125,15 @@ cuTile.jl aims to expose as much functionality as possible through Julia-native
124125prefixed with ` ct. ` are cuTile intrinsics with no direct Julia equivalent; everything else
125126uses standard Julia syntax and is overlaid on ` Base ` .
126127
128+ ### Supported Types
129+
130+ ** Integers:** ` Int8 ` , ` UInt8 ` , ` Int16 ` , ` UInt16 ` , ` Int32 ` , ` UInt32 ` , ` Int64 ` , ` UInt64 `
131+ ** Floats:** ` Float16 ` , ` BFloat16 ` , ` Float32 ` , ` Float64 ` , ` TFloat32 `
132+ ** Boolean:** ` Bool `
133+
134+ ` TFloat32 ` is a 32-bit floating-point type with reduced mantissa precision (10 bits),
135+ optimized for tensor core operations.
136+
127137### Memory
128138| Operation | Description |
129139| -----------| -------------|
@@ -151,6 +161,18 @@ ct.scatter(arr, indices, tile; mask=active_mask)
151161| ` ct.num_blocks(axis) ` | Grid size along axis |
152162| ` ct.num_tiles(arr, axis, shape) ` | Number of tiles along axis |
153163
164+ ### Control Flow
165+ | Construct | Description |
166+ | -----------| -------------|
167+ | ` if ` /` elseif ` /` else ` | Conditional branching |
168+ | ` for i in start:stop ` | Counted loops (compiled to Tile IR ForOp) |
169+ | ` for i in start:step:stop ` | Stepped loops |
170+ | ` while cond ... end ` | While loops |
171+ | ` ifelse.(cond, x, y) ` | Element-wise conditional selection |
172+
173+ Standard Julia control flow works inside kernels and is compiled to structured
174+ Tile IR operations.
175+
154176### Arithmetic
155177| Operation | Description |
156178| -----------| -------------|
@@ -189,6 +211,7 @@ ct.scatter(arr, indices, tile; mask=active_mask)
189211| ` map(f, tiles...) ` | Apply function element-wise (same shape) |
190212| ` f.(tiles...) ` , ` broadcast(f, tiles...) ` | Apply function with shape broadcasting |
191213| ` reduce(f, tile; dims, init) ` | Reduction with arbitrary function |
214+ | ` mapreduce(f, op, tile; dims, init) ` | Map then reduce |
192215| ` accumulate(f, tile; dims, init, rev) ` | Scan/prefix-sum with arbitrary function |
193216
194217### Reductions
@@ -215,10 +238,13 @@ ct.scatter(arr, indices, tile; mask=active_mask)
215238| ` exp2.(tile) ` | Base-2 exponential |
216239| ` log.(tile) ` | Natural logarithm |
217240| ` log2.(tile) ` | Base-2 logarithm |
218- | ` sin.(tile) ` , ` cos.(tile) ` , etc. | Trigonometric functions |
241+ | ` sin.(tile) ` , ` cos.(tile) ` , ` tan.(tile) ` | Trigonometric functions |
242+ | ` sinh.(tile) ` , ` cosh.(tile) ` , ` tanh.(tile) ` | Hyperbolic functions |
219243| ` fma.(a, b, c) ` | Fused multiply-add |
220244| ` abs.(tile) ` | Absolute value |
245+ | ` isnan.(tile) ` | NaN test |
221246| ` max(a, b) ` , ` min(a, b) ` | Maximum/minimum (scalars) |
247+ | ` ceil.(tile) ` , ` floor.(tile) ` | Rounding |
222248| ` ct.@fpmode rounding_mode=ct.Rounding.Approx flush_to_zero=true begin ... end ` | Scoped FP rounding mode and flush-to-zero |
223249
224250### Comparison
@@ -242,6 +268,14 @@ ct.scatter(arr, indices, tile; mask=active_mask)
242268| ` div(a, b) ` | Truncating division |
243269| ` mul_hi.(tile_a, tile_b) ` , ` mul_hi(x, y) ` | High bits of integer multiply (use ` Base.mul_hi ` on Julia 1.13+) |
244270
271+ ### Indexing
272+ | Operation | Description |
273+ | -----------| -------------|
274+ | ` arr[i, j, ...] ` | Load scalar element from ` TileArray ` |
275+ | ` arr[i, j, ...] = val ` | Store scalar element to ` TileArray ` |
276+ | ` tile[i, j, ...] ` | Extract scalar from ` Tile ` |
277+ | ` setindex(tile, val, i, j, ...) ` | Return new ` Tile ` with element replaced |
278+
245279### Atomics
246280| Operation | Description |
247281| -----------| -------------|
@@ -257,11 +291,52 @@ ct.scatter(arr, indices, tile; mask=active_mask)
257291All atomics accept ` memory_order ` (default: ` ct.MemoryOrder.AcqRel ` ) and
258292` memory_scope ` (default: ` ct.MemScope.Device ` ) keyword arguments.
259293
294+ ### Performance Tuning
295+
296+ #### Kernel configuration
297+
298+ ` ct.@compiler_options ` sets optimization hints inside a kernel function body:
299+
300+ ``` julia
301+ function matmul (A, B, C, ... )
302+ ct. @compiler_options num_ctas= ct. ByTarget (v " 10.0" => 2 ) occupancy= 8
303+ ...
304+ end
305+ ```
306+
307+ | Option | Description | Valid values |
308+ | --------| -------------| --------------|
309+ | ` num_ctas ` | Number of CTAs in a CGA | Powers of 2 |
310+ | ` occupancy ` | Target concurrent CTAs per SM | 1–32 |
311+ | ` opt_level ` | Optimization level | 0–3 |
312+
313+ Values can be plain scalars or ` ct.ByTarget(...) ` for per-architecture dispatch.
314+ ` ByTarget ` maps compute capabilities to values, with an optional default:
315+
316+ ``` julia
317+ ct. @compiler_options num_ctas= ct. ByTarget (v " 10.0" => 4 , v " 12.0" => 2 ; default= 1 )
318+ ```
319+
320+ Hints can also be passed as keyword arguments to ` ct.launch ` or ` ct.code_tiled ` ,
321+ which take precedence over ` @compiler_options ` .
322+
323+ #### Load/store hints
324+
325+ ` ct.load ` and ` ct.store ` accept optional keyword arguments that influence memory
326+ traffic scheduling:
327+
328+ | Hint | Description |
329+ | ------| -------------|
330+ | ` latency ` | DRAM traffic weight hint, integer 1 (low) to 10 (high). Default: compiler-inferred. |
331+ | ` allow_tma ` | Whether to allow Tensor Memory Accelerator lowering. Default: allowed. |
332+
260333### Debugging
334+
261335| Operation | Description |
262336| -----------| -------------|
263337| ` print(args...) ` | Print values (Base overlay) |
264338| ` println(args...) ` | Print values with newline (Base overlay) |
339+ | ` ct.@assert cond [msg] ` | Abort kernel if condition is false |
265340
266341Standard Julia ` print ` /` println ` work inside kernels. String constants and tiles
267342can be mixed freely; format specifiers are inferred from element types at compile
@@ -270,9 +345,27 @@ time. String interpolation is supported.
270345``` julia
271346println (" Block " , ct. bid (1 ), " : tile=" , tile)
272347println (" result=$result " ) # string interpolation
348+ ct. @assert idx <= n " index out of bounds"
273349```
274350
275- This is a debugging aid and is not optimized for performance.
351+ These are debugging aids and are not optimized for performance.
352+
353+ ### Code Inspection
354+
355+ Beyond ` ct.code_tiled ` and ` ct.@code_tiled ` shown above, cuTile.jl provides
356+ ` @device_code_* ` macros that intercept compilation during ` ct.launch ` :
357+
358+ ``` julia
359+ ct. @device_code_tiled ct. launch (vadd, grid, a, b, c, ct. Constant (16 ))
360+ ct. @device_code_typed ct. launch (vadd, grid, a, b, c, ct. Constant (16 ))
361+ ct. @device_code_structured ct. launch (vadd, grid, a, b, c, ct. Constant (16 ))
362+ ```
363+
364+ | Macro | Output |
365+ | -------| --------|
366+ | ` ct.@device_code_tiled ` | Final Tile IR (MLIR textual format) |
367+ | ` ct.@device_code_typed ` | Typed Julia IR after overlay resolution |
368+ | ` ct.@device_code_structured ` | Structured IR (after control-flow structurization) |
276369
277370
278371## Differences from cuTile Python
312405### Optimization hints
313406
314407Python passes optimization hints as ` @ct.kernel ` decorator arguments. Julia uses
315- ` ct.@compiler_options ` inside the function body (like ` @inline ` ):
408+ ` ct.@compiler_options ` inside the function body (like ` @inline ` ). See
409+ [ Performance Tuning] ( #performance-tuning ) for full details.
316410
317411``` python
318412# Python
@@ -329,24 +423,6 @@ function matmul(A, B, C, ...)
329423end
330424```
331425
332- Supported options:
333-
334- | Option | Description | Valid values |
335- | --------| -------------| --------------|
336- | ` num_ctas ` | Number of CTAs in a CGA (cooperative group array) | Powers of 2 |
337- | ` occupancy ` | Target occupancy (number of concurrent CTAs per SM) | 1–32 |
338- | ` opt_level ` | Optimization level | 0–3 |
339-
340- Values can be plain scalars or ` ct.ByTarget(...) ` for per-architecture dispatch.
341- ` ByTarget ` maps compute capabilities to values, with an optional default:
342-
343- ``` julia
344- ct. @compiler_options num_ctas= ct. ByTarget (v " 10.0" => 4 , v " 12.0" => 2 ; default= 1 )
345- ```
346-
347- Hints can also be passed as keyword arguments to ` ct.launch ` or ` ct.code_tiled ` ,
348- which take precedence over ` @compiler_options ` .
349-
350426### Launch Syntax
351427
352428cuTile.jl implicitly uses the current task-bound stream from CUDA.jl:
@@ -515,7 +591,7 @@ standard Julia silently produce truncated or wrapped results instead:
515591 throwing ` InexactError ` for non-integer or out-of-range values. Use
516592 ` unsafe_trunc ` for the explicit non-throwing primitive.
517593
518- Assertions may be added in the future for testing purposes .
594+ Use ` ct.@assert ` to add runtime checks in kernels (see Debugging above) .
519595
520596
521597## Host-level operations
0 commit comments