From 309c0d61f94ead627542ece1f1d2c1493265ad10 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 7 Apr 2026 12:33:51 +0000 Subject: [PATCH 1/5] Add Constant(Type) support and fix static_parameter codegen Enable `Constant(T)` where `T` is a type (e.g., `Constant(Int)`) to produce `Constant{Type{T}, T}` instead of `Constant{DataType, T}`, so method dispatch correctly binds type parameters. Handle `:static_parameter` expressions in codegen by looking up concrete values from the method's sptypes, unwrapping VarState and Const wrappers. Also fix `IndexedUseRef` qualification in walk_uses! and guard scalar_elim_pass! against TokenType annotations. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/expressions.jl | 8 +++++ src/compiler/passes/canonicalize.jl | 4 +++ src/compiler/utils.jl | 2 +- src/language/types.jl | 3 +- test/codegen/integration.jl | 45 +++++++++++++++++++++++++++++ test/codegen/reflection.jl | 33 +++++++++++++++++++++ 6 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/compiler/codegen/expressions.jl b/src/compiler/codegen/expressions.jl index ab157072..ca9c521e 100644 --- a/src/compiler/codegen/expressions.jl +++ b/src/compiler/codegen/expressions.jl @@ -20,6 +20,14 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type)) # Bounds checking is always disabled in Tile IR kernels. # Emit false so IfOps referencing this SSA can resolve the condition. return emit_constant!(ctx, false, Bool) + elseif expr.head === :static_parameter + # Static type parameter reference (e.g., V in `f(::T{V}) where {V}`). + # Look up the concrete value from the method's sptypes. + idx = expr.args[1]::Int + sp = ctx.sci.sptypes[idx] + sptyp = sp isa CC.VarState ? sp.typ : sp + val = sptyp isa CC.Const ? sptyp.val : CC.widenconst(sptyp) + return emit_value!(ctx, val) elseif expr.head === :code_coverage_effect return nothing else diff --git a/src/compiler/passes/canonicalize.jl b/src/compiler/passes/canonicalize.jl index f6dedb03..dbf6fbda 100644 --- a/src/compiler/passes/canonicalize.jl +++ b/src/compiler/passes/canonicalize.jl @@ -126,6 +126,7 @@ function scalar_elim_block!(block::Block) current_type = value_type(inst) current_type === nothing && continue + is_token_type(current_type) && continue T = CC.widenconst(current_type) T <: Tile && continue # already tile-typed T <: Number || continue # only promote scalar number types @@ -133,6 +134,7 @@ function scalar_elim_block!(block::Block) for op in ops op_type = value_type(block, op) op_type === nothing && continue + is_token_type(op_type) && continue OT = CC.widenconst(op_type) OT <: Tile || continue S = OT.parameters[2] @@ -162,6 +164,7 @@ function scalar_elim_block!(block::Block) for inst in instructions(block) current_type = value_type(inst) current_type === nothing && continue + is_token_type(current_type) && continue new_type = promote_scalar_type(CC.widenconst(current_type)) new_type === nothing && continue update_type!(block, inst, new_type) @@ -170,6 +173,7 @@ function scalar_elim_block!(block::Block) # Phase 5: Promote block argument types (loop IVs, carries). # BlockArgument is immutable, so we create a new one and replace all uses. for (i, arg) in enumerate(block.args) + is_token_type(arg.type) && continue T = CC.widenconst(arg.type) T <: Tile && continue T <: Number || continue diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index e57b8204..c9b6e96a 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -77,7 +77,7 @@ end # walk_uses! extensions so that IRStructurizer's uses()/replace_uses! see # operands inside cuTile-specific IR nodes. IRStructurizer.walk_uses!(f, node::JoinTokensNode) = - for i in 1:length(node.tokens); f(IndexedUseRef(node.tokens, i)); end + for i in 1:length(node.tokens); f(IRStructurizer.IndexedUseRef(node.tokens, i)); end IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing diff --git a/src/language/types.jl b/src/language/types.jl index f246e54d..0523b25b 100644 --- a/src/language/types.jl +++ b/src/language/types.jl @@ -316,8 +316,9 @@ argtypes = Tuple{Ptr{Float32}, Constant{Int, 16}} """ struct Constant{T, V} end -# Convenience constructor that infers type from value +# Convenience constructors that infer type from value Constant(val::T) where {T} = Constant{T, val}() +Constant(val::Type{T}) where {T} = Constant{Type{T}, T}() # Extract constant value - @inline ensures this folds to a constant in IR @inline Base.getindex(::Constant{T, V}) where {T, V} = V diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index c25baf2f..b1c2dfce 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -1262,6 +1262,51 @@ end end end +#============================================================================= + Constant Type Arguments +=============================================================================# + +@testset "Constant Type Arguments" begin + spec = ct.ArraySpec{1}(16, true) + + function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile = ct.load(a, pid, (tile_size,)) + ct.store(b, pid, tile) + return + end + + @testset "Type parameter via static_parameter" begin + @test @filecheck begin + @check_label "entry" + @check "load_view_tko" + @check "store_view_tko" + code_tiled(_type_param_kernel, + Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, + ct.Constant{Int,16}, ct.Constant{Type{Nothing},Nothing}}) + end + end + + # Test that Constant(Type) constructor produces correct types + function _use_type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile = ct.load(a, pid, (tile_size,)) + ct.store(b, pid, tile) + return + end + + @testset "Constant(Type) via convenience constructor" begin + @test @filecheck begin + @check_label "entry" + @check "load_view_tko" + @check "store_view_tko" + code_tiled(_use_type_param_kernel, + Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, + ct.Constant{Int,16}, ct.Constant{Type{Float32},Float32}}) + end + end +end + #============================================================================= For Loops =============================================================================# diff --git a/test/codegen/reflection.jl b/test/codegen/reflection.jl index 51a73c3a..5e8cd5e5 100644 --- a/test/codegen/reflection.jl +++ b/test/codegen/reflection.jl @@ -104,3 +104,36 @@ end end end end + +@testset "Constant Type args" begin + const_spec = ct.ArraySpec{1}(128, true, (0,), (32,)) + + @testset "Constant(Type) constructor" begin + @test typeof(ct.Constant(Int)) === ct.Constant{Type{Int}, Int} + @test typeof(ct.Constant(Nothing)) === ct.Constant{Type{Nothing}, Nothing} + @test typeof(ct.Constant(Float32)) === ct.Constant{Type{Float32}, Float32} + # Non-type values still work as before + @test typeof(ct.Constant(42)) === ct.Constant{Int, 42} + end + + @testset "code_tiled with Constant Type parameter" begin + function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile_a = ct.load(a; index=pid, shape=(tile_size,)) + tile_b = ct.load(b; index=pid, shape=(tile_size,)) + ct.store(c; index=pid, tile=tile_a + tile_b) + return + end + + ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec}, + ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16}, + ct.Constant{Type{Nothing}, Nothing}} + + @test @filecheck begin + @check "load_view_tko" + @check "addf" + @check "store_view_tko" + ct.code_tiled(reflect_type_param, ConstTypeTT) + end + end +end From 7e27f2187cc759a4c9bbd186977392e6e95852ed Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 7 Apr 2026 13:05:55 +0000 Subject: [PATCH 2/5] update tests --- test/codegen/reflection.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/codegen/reflection.jl b/test/codegen/reflection.jl index 5e8cd5e5..247b6616 100644 --- a/test/codegen/reflection.jl +++ b/test/codegen/reflection.jl @@ -108,13 +108,7 @@ end @testset "Constant Type args" begin const_spec = ct.ArraySpec{1}(128, true, (0,), (32,)) - @testset "Constant(Type) constructor" begin - @test typeof(ct.Constant(Int)) === ct.Constant{Type{Int}, Int} - @test typeof(ct.Constant(Nothing)) === ct.Constant{Type{Nothing}, Nothing} - @test typeof(ct.Constant(Float32)) === ct.Constant{Type{Float32}, Float32} - # Non-type values still work as before - @test typeof(ct.Constant(42)) === ct.Constant{Int, 42} - end + @test ct.Constant(Int) isa ct.Constant{Type{Int}, Int} @testset "code_tiled with Constant Type parameter" begin function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T From 9b13bf7fe73c1188ec1cb5c115750e48ea75deec Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 7 Apr 2026 14:19:19 +0000 Subject: [PATCH 3/5] improve tests --- test/codegen/integration.jl | 24 +++--------------------- test/codegen/reflection.jl | 4 ++-- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index b1c2dfce..803baceb 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -1271,36 +1271,18 @@ end function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T pid = ct.bid(1) - tile = ct.load(a, pid, (tile_size,)) + tile = ct.load(a, pid, (tile_size,)) + zeros(T, (tile_size,)) ct.store(b, pid, tile) return end - @testset "Type parameter via static_parameter" begin + @testset "Type parameter used in kernel body" begin @test @filecheck begin @check_label "entry" @check "load_view_tko" + @check "addf" @check "store_view_tko" code_tiled(_type_param_kernel, - Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, - ct.Constant{Int,16}, ct.Constant{Type{Nothing},Nothing}}) - end - end - - # Test that Constant(Type) constructor produces correct types - function _use_type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T - pid = ct.bid(1) - tile = ct.load(a, pid, (tile_size,)) - ct.store(b, pid, tile) - return - end - - @testset "Constant(Type) via convenience constructor" begin - @test @filecheck begin - @check_label "entry" - @check "load_view_tko" - @check "store_view_tko" - code_tiled(_use_type_param_kernel, Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, ct.Constant{Int,16}, ct.Constant{Type{Float32},Float32}}) end diff --git a/test/codegen/reflection.jl b/test/codegen/reflection.jl index 247b6616..eeb86dd3 100644 --- a/test/codegen/reflection.jl +++ b/test/codegen/reflection.jl @@ -115,13 +115,13 @@ end pid = ct.bid(1) tile_a = ct.load(a; index=pid, shape=(tile_size,)) tile_b = ct.load(b; index=pid, shape=(tile_size,)) - ct.store(c; index=pid, tile=tile_a + tile_b) + ct.store(c; index=pid, tile=tile_a + tile_b + zeros(T, (tile_size,))) return end ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16}, - ct.Constant{Type{Nothing}, Nothing}} + ct.Constant{Type{Float32}, Float32}} @test @filecheck begin @check "load_view_tko" From f83501979ef0ad012d97dd74559af75e8fa629d0 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 7 Apr 2026 14:53:30 +0000 Subject: [PATCH 4/5] Automatically wrap `Type` arguments with `Constant` --- ext/CUDAExt.jl | 1 + src/compiler/reflection.jl | 10 ++++++++-- test/codegen/integration.jl | 2 +- test/codegen/reflection.jl | 6 +++--- test/device/core.jl | 20 ++++++++++++++++++++ 5 files changed, 33 insertions(+), 6 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 4708dd07..8006d5e2 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -257,6 +257,7 @@ Other values pass through unchanged. """ to_tile_arg(x) = x to_tile_arg(arr::AbstractArray) = TileArray(arr) +to_tile_arg(t::Type) = Constant(t) # Tiled Broadcast — TiledStyle wins over CuArrayStyle BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M} = cuTile.TiledStyle{max(N,M)}() diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index fed7afb8..3f991771 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -82,7 +82,7 @@ Returns `(stripped, nothing)` when no Constant types are present. function process_const_argtypes(@nospecialize(f), @nospecialize(argtypes)) params = argtypes isa DataType ? argtypes.parameters : argtypes isa Tuple ? argtypes : fieldtypes(argtypes) - has_consts = any(T -> T <: Constant, params) + has_consts = any(T -> T <: Constant || CC.isconstType(T), params) stripped_params = map(params) do T T <: Constant ? constant_eltype(T) : T end @@ -90,7 +90,13 @@ function process_const_argtypes(@nospecialize(f), @nospecialize(argtypes)) const_argtypes = if has_consts cats = Any[CC.Const(f)] for T in params - push!(cats, T <: Constant ? CC.Const(constant_value(T)) : T) + if T <: Constant + push!(cats, CC.Const(constant_value(T))) + elseif CC.isconstType(T) + push!(cats, CC.Const(T.parameters[1])) + else + push!(cats, T) + end end cats else diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index 803baceb..f9581b54 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -1284,7 +1284,7 @@ end @check "store_view_tko" code_tiled(_type_param_kernel, Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, - ct.Constant{Int,16}, ct.Constant{Type{Float32},Float32}}) + ct.Constant{Int,16}, Type{Float32}}) end end end diff --git a/test/codegen/reflection.jl b/test/codegen/reflection.jl index eeb86dd3..13e626b8 100644 --- a/test/codegen/reflection.jl +++ b/test/codegen/reflection.jl @@ -105,12 +105,12 @@ end end end -@testset "Constant Type args" begin +@testset "Type args" begin const_spec = ct.ArraySpec{1}(128, true, (0,), (32,)) @test ct.Constant(Int) isa ct.Constant{Type{Int}, Int} - @testset "code_tiled with Constant Type parameter" begin + @testset "code_tiled with Type parameter" begin function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T pid = ct.bid(1) tile_a = ct.load(a; index=pid, shape=(tile_size,)) @@ -121,7 +121,7 @@ end ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16}, - ct.Constant{Type{Float32}, Float32}} + Type{Float32}} @test @filecheck begin @check "load_view_tko" diff --git a/test/device/core.jl b/test/device/core.jl index 5674fcf2..991b50fe 100644 --- a/test/device/core.jl +++ b/test/device/core.jl @@ -387,6 +387,26 @@ end @test Array(b) ≈ Array(a) end +@testset "Type parameter (auto-wrapped)" begin + function vadd_type_param(a, b, c, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile_a = ct.load(a; index=pid, shape=(tile_size,)) + tile_b = ct.load(b; index=pid, shape=(tile_size,)) + ct.store(c; index=pid, tile=T.(tile_a) + T.(tile_b)) + return + end + + n = 1024 + tile_size = 32 + a = CUDA.rand(Float32, n) + b = CUDA.rand(Float32, n) + c = CUDA.zeros(Float32, n) + + ct.launch(vadd_type_param, cld(n, tile_size), a, b, c, ct.Constant(tile_size), Float32) + + @test Array(c) ≈ Array(a) + Array(b) +end + end @testset "TileArray auto-conversion" begin From c00cf40ae722b312de20bb242490e06b3176e5b9 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 7 Apr 2026 14:56:28 +0000 Subject: [PATCH 5/5] update execution test --- test/device/core.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/device/core.jl b/test/device/core.jl index 991b50fe..70868682 100644 --- a/test/device/core.jl +++ b/test/device/core.jl @@ -398,13 +398,13 @@ end n = 1024 tile_size = 32 - a = CUDA.rand(Float32, n) - b = CUDA.rand(Float32, n) + a = CUDA.rand(Float16, n) + b = CUDA.rand(Float16, n) c = CUDA.zeros(Float32, n) ct.launch(vadd_type_param, cld(n, tile_size), a, b, c, ct.Constant(tile_size), Float32) - @test Array(c) ≈ Array(a) + Array(b) + @test Array(c) ≈ Float32.(Array(a)) + Float32.(Array(b)) end end