Skip to content

Commit 1ca30bb

Browse files
AntonOresten and claude authored
Add support for Type arguments and fix static parameter codegen (#181)
Enable `Constant(T)` where `T` is a type (e.g., `Constant(Int)`) to produce `Constant{Type{T}, T}` instead of `Constant{DataType, T}`, so method dispatch correctly binds type parameters. Automatically wrap `Type` arguments with `Constant`. Handle `:static_parameter` expressions in codegen by looking up concrete values from the method's sptypes, unwrapping VarState and Const wrappers. Also fix `IndexedUseRef` qualification in `walk_uses!` and guard `scalar_elim_pass!` against TokenType annotations. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 126e201 commit 1ca30bb

9 files changed

Lines changed: 98 additions & 4 deletions

File tree

ext/CUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ Other values pass through unchanged.
257257
"""
258258
to_tile_arg(x) = x
259259
to_tile_arg(arr::AbstractArray) = TileArray(arr)
260+
to_tile_arg(t::Type) = Constant(t)
260261

261262
# Tiled Broadcast — TiledStyle wins over CuArrayStyle
262263
BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M} = cuTile.TiledStyle{max(N,M)}()

src/compiler/codegen/expressions.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type))
2020
# Bounds checking is always disabled in Tile IR kernels.
2121
# Emit false so IfOps referencing this SSA can resolve the condition.
2222
return emit_constant!(ctx, false, Bool)
23+
elseif expr.head === :static_parameter
24+
# Static type parameter reference (e.g., V in `f(::T{V}) where {V}`).
25+
# Look up the concrete value from the method's sptypes.
26+
idx = expr.args[1]::Int
27+
sp = ctx.sci.sptypes[idx]
28+
sptyp = sp isa CC.VarState ? sp.typ : sp
29+
val = sptyp isa CC.Const ? sptyp.val : CC.widenconst(sptyp)
30+
return emit_value!(ctx, val)
2331
elseif expr.head === :code_coverage_effect
2432
return nothing
2533
else

src/compiler/passes/canonicalize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,15 @@ function scalar_elim_block!(block::Block)
131131

132132
current_type = value_type(inst)
133133
current_type === nothing && continue
134+
is_token_type(current_type) && continue
134135
T = CC.widenconst(current_type)
135136
T <: Tile && continue # already tile-typed
136137
T <: Number || continue # only promote scalar number types
137138

138139
for op in ops
139140
op_type = value_type(block, op)
140141
op_type === nothing && continue
142+
is_token_type(op_type) && continue
141143
OT = CC.widenconst(op_type)
142144
OT <: Tile || continue
143145
S = OT.parameters[2]
@@ -167,6 +169,7 @@ function scalar_elim_block!(block::Block)
167169
for inst in instructions(block)
168170
current_type = value_type(inst)
169171
current_type === nothing && continue
172+
is_token_type(current_type) && continue
170173
new_type = promote_scalar_type(CC.widenconst(current_type))
171174
new_type === nothing && continue
172175
update_type!(block, inst, new_type)
@@ -175,6 +178,7 @@ function scalar_elim_block!(block::Block)
175178
# Phase 5: Promote block argument types (loop IVs, carries).
176179
# BlockArgument is immutable, so we create a new one and replace all uses.
177180
for (i, arg) in enumerate(block.args)
181+
is_token_type(arg.type) && continue
178182
T = CC.widenconst(arg.type)
179183
T <: Tile && continue
180184
T <: Number || continue

src/compiler/reflection.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,21 @@ Returns `(stripped, nothing)` when no Constant types are present.
8282
function process_const_argtypes(@nospecialize(f), @nospecialize(argtypes))
8383
params = argtypes isa DataType ? argtypes.parameters :
8484
argtypes isa Tuple ? argtypes : fieldtypes(argtypes)
85-
has_consts = any(T -> T <: Constant, params)
85+
has_consts = any(T -> T <: Constant || CC.isconstType(T), params)
8686
stripped_params = map(params) do T
8787
T <: Constant ? constant_eltype(T) : T
8888
end
8989
stripped = Tuple{stripped_params...}
9090
const_argtypes = if has_consts
9191
cats = Any[CC.Const(f)]
9292
for T in params
93-
push!(cats, T <: Constant ? CC.Const(constant_value(T)) : T)
93+
if T <: Constant
94+
push!(cats, CC.Const(constant_value(T)))
95+
elseif CC.isconstType(T)
96+
push!(cats, CC.Const(T.parameters[1]))
97+
else
98+
push!(cats, T)
99+
end
94100
end
95101
cats
96102
else

src/compiler/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ end
7777
# walk_uses! extensions so that IRStructurizer's uses()/replace_uses! see
7878
# operands inside cuTile-specific IR nodes.
7979
IRStructurizer.walk_uses!(f, node::JoinTokensNode) =
80-
for i in 1:length(node.tokens); f(IndexedUseRef(node.tokens, i)); end
80+
for i in 1:length(node.tokens); f(IRStructurizer.IndexedUseRef(node.tokens, i)); end
8181
IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing
8282
IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing
8383

src/language/types.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,9 @@ argtypes = Tuple{Ptr{Float32}, Constant{Int, 16}}
316316
"""
317317
struct Constant{T, V} end
318318

319-
# Convenience constructor that infers type from value
319+
# Convenience constructors that infer type from value
320320
Constant(val::T) where {T} = Constant{T, val}()
321+
Constant(val::Type{T}) where {T} = Constant{Type{T}, T}()
321322

322323
# Extract constant value - @inline ensures this folds to a constant in IR
323324
@inline Base.getindex(::Constant{T, V}) where {T, V} = V

test/codegen/integration.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,33 @@ end
12621262
end
12631263
end
12641264

1265+
#=============================================================================
1266+
Constant Type Arguments
1267+
=============================================================================#
1268+
1269+
@testset "Constant Type Arguments" begin
1270+
spec = ct.ArraySpec{1}(16, true)
1271+
1272+
function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T
1273+
pid = ct.bid(1)
1274+
tile = ct.load(a, pid, (tile_size,)) + zeros(T, (tile_size,))
1275+
ct.store(b, pid, tile)
1276+
return
1277+
end
1278+
1279+
@testset "Type parameter used in kernel body" begin
1280+
@test @filecheck begin
1281+
@check_label "entry"
1282+
@check "load_view_tko"
1283+
@check "addf"
1284+
@check "store_view_tko"
1285+
code_tiled(_type_param_kernel,
1286+
Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec},
1287+
ct.Constant{Int,16}, Type{Float32}})
1288+
end
1289+
end
1290+
end
1291+
12651292
#=============================================================================
12661293
For Loops
12671294
=============================================================================#

test/codegen/reflection.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,30 @@ end
104104
end
105105
end
106106
end
107+
108+
@testset "Type args" begin
109+
const_spec = ct.ArraySpec{1}(128, true, (0,), (32,))
110+
111+
@test ct.Constant(Int) isa ct.Constant{Type{Int}, Int}
112+
113+
@testset "code_tiled with Type parameter" begin
114+
function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T
115+
pid = ct.bid(1)
116+
tile_a = ct.load(a; index=pid, shape=(tile_size,))
117+
tile_b = ct.load(b; index=pid, shape=(tile_size,))
118+
ct.store(c; index=pid, tile=tile_a + tile_b + zeros(T, (tile_size,)))
119+
return
120+
end
121+
122+
ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec},
123+
ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16},
124+
Type{Float32}}
125+
126+
@test @filecheck begin
127+
@check "load_view_tko"
128+
@check "addf"
129+
@check "store_view_tko"
130+
ct.code_tiled(reflect_type_param, ConstTypeTT)
131+
end
132+
end
133+
end

test/device/core.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,26 @@ end
387387
@test Array(b) ≈ Array(a)
388388
end
389389

390+
@testset "Type parameter (auto-wrapped)" begin
391+
function vadd_type_param(a, b, c, tile_size::Int, ::Type{T}) where T
392+
pid = ct.bid(1)
393+
tile_a = ct.load(a; index=pid, shape=(tile_size,))
394+
tile_b = ct.load(b; index=pid, shape=(tile_size,))
395+
ct.store(c; index=pid, tile=T.(tile_a) + T.(tile_b))
396+
return
397+
end
398+
399+
n = 1024
400+
tile_size = 32
401+
a = CUDA.rand(Float16, n)
402+
b = CUDA.rand(Float16, n)
403+
c = CUDA.zeros(Float32, n)
404+
405+
ct.launch(vadd_type_param, cld(n, tile_size), a, b, c, ct.Constant(tile_size), Float32)
406+
407+
@test Array(c) ≈ Float32.(Array(a)) + Float32.(Array(b))
408+
end
409+
390410
end
391411

392412
@testset "TileArray auto-conversion" begin

0 commit comments

Comments
 (0)