Skip to content

Commit c323ae7

Browse files
maleadtclaude
andauthored
Support destructuring arbitrary arguments (#128)
The compiler's struct destructuring was TileArray-specific. This generalizes it so any isbits struct is recursively flattened. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2f2d75f commit c323ae7

5 files changed

Lines changed: 312 additions & 133 deletions

File tree

ext/CUDAExt.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module CUDAExt
22

33
using cuTile
4-
using cuTile: TileArray, Constant, ByTarget, CGOpts, CuTileResults, DEFAULT_BYTECODE_VERSION,
5-
emit_code, sanitize_name, constant_eltype, constant_value, is_ghost_type,
6-
resolve_hint, format_sm_arch, validate_hint
4+
using cuTile: TileArray, Constant, CuTileResults,
5+
emit_code, sanitize_name, constant_eltype, flatten,
6+
resolve_hint, format_sm_arch
77

88
using CompilerCaching: CacheView, method_instance, results
99

@@ -245,18 +245,6 @@ Returns e.g. `v"12.0"` for compute capability 12.0.
245245
"""
246246
default_sm_arch() = capability(device())
247247

248-
"""
249-
flatten(x)
250-
251-
Flatten a value into a tuple of its leaf fields for kernel launch.
252-
Scalars return themselves wrapped in a tuple. Structs like TileArray
253-
return their fields in order.
254-
255-
This is used by the launch helper to splat arguments to cudacall.
256-
"""
257-
flatten(x) = is_ghost_type(typeof(x)) ? () : (x,)
258-
flatten(arr::TileArray{T, N}) where {T, N} = (arr.ptr, arr.sizes..., arr.strides...)
259-
260248
"""
261249
to_tile_arg(x)
262250

src/compiler/codegen/kernel.jl

Lines changed: 66 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -38,37 +38,20 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
3838

3939
# Build parameter list, handling ghost types, const args, and struct destructuring
4040
param_types = TypeId[]
41-
param_mapping = Tuple{Int, Union{Nothing, Symbol}}[]
41+
param_mapping = Tuple{Int, Vector{Int}}[]
4242

4343
for (i, argtype) in enumerate(sci.argtypes)
4444
argtype_unwrapped = CC.widenconst(argtype)
4545
if is_ghost_type(argtype_unwrapped)
4646
continue
4747
elseif is_const_arg(i)
4848
continue # const arg: no kernel parameter
49-
elseif should_destructure(argtype_unwrapped)
50-
# Destructure TileArray into flat parameters
51-
params = argtype_unwrapped.parameters
52-
ndims = params[2]::Integer
53-
for fi in 1:fieldcount(argtype_unwrapped)
54-
fname = fieldname(argtype_unwrapped, fi)
55-
ftype = fieldtype(argtype_unwrapped, fi)
56-
if fname === :sizes || fname === :strides
57-
fcount = ndims
58-
elem_type = Int32
59-
else
60-
fcount = flat_field_count(ftype)
61-
elem_type = ftype <: Ptr ? Ptr{params[1]} : (ftype <: Tuple ? eltype(ftype) : ftype)
62-
end
63-
for _ in 1:fcount
64-
push!(param_types, tile_type_for_julia!(ctx, elem_type))
65-
push!(param_mapping, (i, fname))
66-
end
67-
end
68-
ctx.arg_types[i] = argtype_unwrapped
69-
else
49+
elseif isprimitivetype(argtype_unwrapped)
7050
push!(param_types, tile_type_for_julia!(ctx, argtype_unwrapped))
71-
push!(param_mapping, (i, nothing))
51+
push!(param_mapping, (i, Int[]))
52+
else
53+
flatten_struct_params!(ctx, param_types, param_mapping, i, argtype_unwrapped, Int[])
54+
ctx.arg_types[i] = argtype_unwrapped
7255
end
7356
end
7457

@@ -90,7 +73,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
9073
arg_values = make_block_args!(cb, length(param_types))
9174

9275
# Build arg_flat_values map
93-
field_values = Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}()
76+
field_values = Dict{Tuple{Int, Vector{Int}}, Vector{Value}}()
9477
for (param_idx, val) in enumerate(arg_values)
9578
key = param_mapping[param_idx]
9679
if !haskey(field_values, key)
@@ -100,13 +83,12 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
10083
end
10184

10285
# Store in context and set up slot/argument mappings
103-
# arg_idx is the direct index into argtypes (2, 3, ...) which matches SlotNumber/Argument
10486
for (key, values) in field_values
105-
arg_idx, field = key
87+
arg_idx, path = key
10688
ctx.arg_flat_values[key] = values
10789

108-
if field === nothing
109-
# Regular argument - create concrete CGVal
90+
if isempty(path) && !haskey(ctx.arg_types, arg_idx)
91+
# Regular (non-destructured) argument - create concrete CGVal
11092
if length(values) != 1
11193
throw(IRError("Expected exactly one value for argument $arg_idx, got $(length(values))"))
11294
end
@@ -142,14 +124,14 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
142124

143125
# For destructured args, create lazy CGVals that track the argument index
144126
for (arg_idx, argtype) in ctx.arg_types
145-
tv = arg_ref_value(arg_idx, Union{Symbol, Int}[], argtype)
127+
tv = arg_ref_value(arg_idx, Int[], argtype)
146128
ctx[SlotNumber(arg_idx)] = tv
147129
ctx[Argument(arg_idx)] = tv
148130
end
149131

150132
# Create TensorViews for all TileArray arguments at kernel entry
151-
for (arg_idx, _) in ctx.arg_types
152-
cache_tensor_view!(ctx, arg_idx)
133+
for (arg_idx, argtype) in ctx.arg_types
134+
create_tensor_views!(ctx, arg_idx, argtype, Int[])
153135
end
154136

155137
# Create memory ordering token
@@ -166,6 +148,46 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
166148
finalize_function!(func_buf, cb, writer.debug_info)
167149
end
168150

151+
"""
152+
create_tensor_views!(ctx, arg_idx, T, path)
153+
154+
Walk the type tree and create TensorViews for all nested TileArrays.
155+
"""
156+
function create_tensor_views!(ctx::CGCtx, arg_idx::Int, @nospecialize(T), path::Vector{Int})
157+
if T <: TileArray
158+
cache_tensor_view!(ctx, arg_idx, path, T)
159+
else
160+
for fi in 1:fieldcount(T)
161+
ftype = fieldtype(T, fi)
162+
(is_ghost_type(ftype) || isprimitivetype(ftype)) && continue
163+
field_path = [path..., fi]
164+
create_tensor_views!(ctx, arg_idx, ftype, field_path)
165+
end
166+
end
167+
end
168+
169+
"""
170+
flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, T, path)
171+
172+
Recursively flatten a struct type into kernel parameters.
173+
"""
174+
function flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, @nospecialize(T), path::Vector{Int})
175+
for fi in 1:fieldcount(T)
176+
ftype = fieldtype(T, fi)
177+
field_path = [path..., fi]
178+
if is_ghost_type(ftype)
179+
continue
180+
elseif isprimitivetype(ftype)
181+
type_id = tile_type_for_julia!(ctx, ftype; throw_error=false)
182+
type_id === nothing && continue
183+
push!(param_types, type_id)
184+
push!(param_mapping, (arg_idx, field_path))
185+
else
186+
flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, ftype, field_path)
187+
end
188+
end
189+
end
190+
169191
# getfield for destructured arguments (lazy chain extension)
170192
function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
171193
length(args) >= 2 || return nothing
@@ -192,31 +214,18 @@ function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
192214
if obj_tv !== nothing && is_arg_ref(obj_tv)
193215
arg_idx, chain = obj_tv.arg_ref
194216

195-
if field isa Symbol
196-
# Field access: extend chain with symbol
197-
new_chain = Union{Symbol, Int}[chain..., field]
198-
# Check if this resolves to a scalar field (auto-materialize leaf)
199-
# Don't auto-materialize tuple types - they need indexing first
200-
rt = CC.widenconst(result_type)
201-
if !(rt <: Tuple)
202-
values = get_arg_flat_values(ctx, arg_idx, field)
203-
if values !== nothing && length(values) == 1
204-
# Scalar field - materialize immediately
205-
type_id = tile_type_for_julia!(ctx, rt)
206-
return CGVal(values[1], type_id, rt)
207-
end
208-
end
209-
return arg_ref_value(arg_idx, new_chain, rt)
210-
elseif field isa Integer && !isempty(chain) && chain[end] isa Symbol
211-
# Tuple indexing: chain ends with field name, now indexing into it
212-
# This is a leaf - materialize immediately
213-
field_name = chain[end]
214-
values = get_arg_flat_values(ctx, arg_idx, field_name)
215-
if values !== nothing && 1 <= field <= length(values)
216-
type_id = tile_type_for_julia!(ctx, CC.widenconst(result_type))
217-
return CGVal(values[field], type_id, CC.widenconst(result_type))
218-
end
217+
# Convert field to integer index
218+
idx = if field isa Symbol
219+
obj_type = CC.widenconst(obj_tv.jltype)
220+
Base.fieldindex(obj_type, field)
221+
elseif field isa Integer
222+
Int(field)
223+
else
224+
nothing
219225
end
226+
idx === nothing && return nothing
227+
228+
return resolve_arg_ref(ctx, arg_idx, chain, idx, CC.widenconst(result_type))
220229
end
221230

222231
nothing
@@ -237,24 +246,10 @@ function emit_getindex!(ctx::CGCtx, args, @nospecialize(result_type))
237246
obj_tv = emit_value!(ctx, obj_arg)
238247
obj_tv === nothing && return nothing
239248

240-
# If obj is a lazy arg_ref, try to materialize or extend the chain
249+
# If obj is a lazy arg_ref, extend the chain with the index
241250
if is_arg_ref(obj_tv)
242251
arg_idx, chain = obj_tv.arg_ref
243-
244-
# If chain ends with a symbol (field name), we're indexing into a tuple field
245-
# Try to materialize immediately
246-
if !isempty(chain) && chain[end] isa Symbol
247-
field_name = chain[end]
248-
values = get_arg_flat_values(ctx, arg_idx, field_name)
249-
if values !== nothing && 1 <= index <= length(values)
250-
type_id = tile_type_for_julia!(ctx, CC.widenconst(result_type))
251-
return CGVal(values[index], type_id, CC.widenconst(result_type))
252-
end
253-
end
254-
255-
# Otherwise extend the chain
256-
new_chain = Union{Symbol, Int}[chain..., Int(index)]
257-
return arg_ref_value(arg_idx, new_chain, CC.widenconst(result_type))
252+
return resolve_arg_ref(ctx, arg_idx, chain, Int(index), CC.widenconst(result_type))
258253
end
259254

260255
# Not an arg_ref - not handled here

0 commit comments

Comments
 (0)