Skip to content

Commit e33c254

Browse files
maleadt and claude committed
Generalize struct destructuring to support arbitrary nested structs
The compiler's struct destructuring was TileArray-specific. This commit generalizes it so that any isbits struct with non-ghost fields (Ptr, Tuple, nested structs) is recursively flattened into kernel parameters.

Key changes:
- `should_destructure` accepts any qualifying struct, not just TileArray.
- `arg_flat_values` uses path-based keys (`Vector{Union{Symbol,Int}}`) instead of single field names.
- Recursive `flatten_struct_params!` / `_flatten_tuple_param!` handle nested structs and heterogeneous tuples.
- `cache_tensor_view!` accepts a path for nested TileArrays.
- `make_tensor_view` resolves lazy arg refs for nested TileArrays.
- Host-side `flatten()` recursively flattens any destructurable struct.

This enables passing structs containing TileArrays (e.g. Broadcasted) directly as kernel arguments.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 21ab724 commit e33c254

4 files changed

Lines changed: 246 additions & 80 deletions

File tree

ext/CUDAExt.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module CUDAExt
33
using cuTile
44
using cuTile: TileArray, Constant, ByTarget, CGOpts, CuTileResults, DEFAULT_BYTECODE_VERSION,
55
emit_code, sanitize_name, constant_eltype, constant_value, is_ghost_type,
6-
resolve_hint, format_sm_arch, validate_hint
6+
should_destructure, resolve_hint, format_sm_arch, validate_hint
77

88
using CompilerCaching: CacheView, method_instance, results
99

@@ -254,8 +254,26 @@ return their fields in order.
254254
255255
This is used by the launch helper to splat arguments to cudacall.
256256
"""
257-
flatten(x) = is_ghost_type(typeof(x)) ? () : (x,)
258-
flatten(arr::TileArray{T, N}) where {T, N} = (arr.ptr, arr.sizes..., arr.strides...)
257+
function flatten(x)
    T = typeof(x)
    # Ghost values never become kernel arguments.
    is_ghost_type(T) && return ()
    # Values that are not destructured are passed through as a single argument.
    cuTile.should_destructure(T) || return (x,)

    result = Any[]
    for fi in 1:fieldcount(T)
        fval = getfield(x, fi)
        fT = typeof(fval)
        # Fields that contribute no kernel parameters are dropped entirely.
        (is_ghost_type(fT) || !cuTile._is_kernel_param_type(fT)) && continue
        if fval isa Tuple
            _flatten_tuple_elems!(result, fval)
        else
            # `append!` (rather than splatting into `push!`) also copes with an
            # empty flattening result.
            append!(result, flatten(fval))
        end
    end
    return Tuple(result)
end

# Append the flattened form of every element of the tuple `tup` to `result`.
#
# Nested tuples are walked element by element so that the host-side argument list
# lines up with the compiler's `_flatten_tuple_param!`, which recurses into tuple
# elements instead of emitting one parameter for the whole inner tuple. Ghost
# elements flatten to `()` and therefore add nothing.
function _flatten_tuple_elems!(result::Vector{Any}, tup::Tuple)
    for elem in tup
        if elem isa Tuple
            _flatten_tuple_elems!(result, elem)
        else
            append!(result, flatten(elem))
        end
    end
    return result
end
259277

260278
"""
261279
to_tile_arg(x)

src/compiler/codegen/kernel.jl

Lines changed: 114 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
3838

3939
# Build parameter list, handling ghost types, const args, and struct destructuring
4040
param_types = TypeId[]
41-
param_mapping = Tuple{Int, Union{Nothing, Symbol}}[]
41+
param_mapping = Tuple{Int, Vector{Union{Symbol, Int}}}[]
4242

4343
for (i, argtype) in enumerate(sci.argtypes)
4444
argtype_unwrapped = CC.widenconst(argtype)
@@ -47,28 +47,11 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
4747
elseif is_const_arg(i)
4848
continue # const arg: no kernel parameter
4949
elseif should_destructure(argtype_unwrapped)
50-
# Destructure TileArray into flat parameters
51-
params = argtype_unwrapped.parameters
52-
ndims = params[2]::Integer
53-
for fi in 1:fieldcount(argtype_unwrapped)
54-
fname = fieldname(argtype_unwrapped, fi)
55-
ftype = fieldtype(argtype_unwrapped, fi)
56-
if fname === :sizes || fname === :strides
57-
fcount = ndims
58-
elem_type = Int32
59-
else
60-
fcount = flat_field_count(ftype)
61-
elem_type = ftype <: Ptr ? Ptr{params[1]} : (ftype <: Tuple ? eltype(ftype) : ftype)
62-
end
63-
for _ in 1:fcount
64-
push!(param_types, tile_type_for_julia!(ctx, elem_type))
65-
push!(param_mapping, (i, fname))
66-
end
67-
end
50+
flatten_struct_params!(ctx, param_types, param_mapping, i, argtype_unwrapped, Union{Symbol, Int}[])
6851
ctx.arg_types[i] = argtype_unwrapped
6952
else
7053
push!(param_types, tile_type_for_julia!(ctx, argtype_unwrapped))
71-
push!(param_mapping, (i, nothing))
54+
push!(param_mapping, (i, Union{Symbol, Int}[]))
7255
end
7356
end
7457

@@ -90,7 +73,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
9073
arg_values = make_block_args!(cb, length(param_types))
9174

9275
# Build arg_flat_values map
93-
field_values = Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}()
76+
field_values = Dict{Tuple{Int, Vector{Union{Symbol, Int}}}, Vector{Value}}()
9477
for (param_idx, val) in enumerate(arg_values)
9578
key = param_mapping[param_idx]
9679
if !haskey(field_values, key)
@@ -100,13 +83,12 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
10083
end
10184

10285
# Store in context and set up slot/argument mappings
103-
# arg_idx is the direct index into argtypes (2, 3, ...) which matches SlotNumber/Argument
10486
for (key, values) in field_values
105-
arg_idx, field = key
87+
arg_idx, path = key
10688
ctx.arg_flat_values[key] = values
10789

108-
if field === nothing
109-
# Regular argument - create concrete CGVal
90+
if isempty(path) && !haskey(ctx.arg_types, arg_idx)
91+
# Regular (non-destructured) argument - create concrete CGVal
11092
if length(values) != 1
11193
throw(IRError("Expected exactly one value for argument $arg_idx, got $(length(values))"))
11294
end
@@ -148,8 +130,9 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
148130
end
149131

150132
# Create TensorViews for all TileArray arguments at kernel entry
151-
for (arg_idx, _) in ctx.arg_types
152-
cache_tensor_view!(ctx, arg_idx)
133+
# Walk the type tree to find all nested TileArrays
134+
for (arg_idx, argtype) in ctx.arg_types
135+
_create_tensor_views_recursive!(ctx, arg_idx, argtype, Union{Symbol, Int}[])
153136
end
154137

155138
# Create memory ordering token
@@ -166,6 +149,96 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
166149
finalize_function!(func_buf, cb, writer.debug_info)
167150
end
168151

152+
"""
    _create_tensor_views_recursive!(ctx, arg_idx, T, path)

Create a cached TensorView for every `TileArray` reachable from type `T`.

`path` is the chain of field names (`Symbol`) and tuple positions (`Int`) leading from
argument `arg_idx` to `T`. Tuple types are descended element by element; for any other
type the non-ghost fields are visited, descending only into fields that are tuples or
that `should_destructure` accepts.
"""
function _create_tensor_views_recursive!(ctx::CGCtx, arg_idx::Int, @nospecialize(T),
                                         path::Vector{Union{Symbol, Int}})
    # A TileArray is a leaf of the walk: cache its view under the current path.
    T <: TileArray && return cache_tensor_view!(ctx, arg_idx, path, T)

    if T <: Tuple
        for (pos, elem_type) in enumerate(T.parameters)
            _create_tensor_views_recursive!(ctx, arg_idx, elem_type,
                                            Union{Symbol, Int}[path..., pos])
        end
        return nothing
    end

    for idx in 1:fieldcount(T)
        field_type = fieldtype(T, idx)
        is_ghost_type(field_type) && continue
        sub_path = Union{Symbol, Int}[path..., fieldname(T, idx)]
        if field_type <: TileArray
            cache_tensor_view!(ctx, arg_idx, sub_path, field_type)
        elseif should_destructure(field_type) || field_type <: Tuple
            _create_tensor_views_recursive!(ctx, arg_idx, field_type, sub_path)
        end
    end
    return nothing
end
180+
181+
"""
    flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, T, path)

Append the flat kernel parameters contributed by struct type `T` to `param_types`, and
record for each one the `(arg_idx, path)` key it belongs to in `param_mapping`.

Fields are visited in declaration order. Ghost fields are skipped, non-empty tuple
fields are handed to `_flatten_tuple_param!`, destructurable struct fields are
flattened recursively, and every remaining field becomes a single parameter unless
`tile_type_for_julia!` reports no type for it.
"""
function flatten_struct_params!(ctx, param_types, param_mapping, arg_idx,
                                @nospecialize(T), path::Vector{Union{Symbol, Int}})
    for idx in 1:fieldcount(T)
        field_type = fieldtype(T, idx)
        is_ghost_type(field_type) && continue
        sub_path = Union{Symbol, Int}[path..., fieldname(T, idx)]

        if field_type <: Tuple && field_type !== Tuple{}
            _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx,
                                  field_type, sub_path)
        elseif should_destructure(field_type)
            flatten_struct_params!(ctx, param_types, param_mapping, arg_idx,
                                   field_type, sub_path)
        else
            # Leaf field (scalar or pointer): one flat parameter, unless the type has
            # no tile IR representation.
            leaf_id = tile_type_for_julia!(ctx, field_type; throw_error=false)
            if leaf_id !== nothing
                push!(param_types, leaf_id)
                push!(param_mapping, (arg_idx, sub_path))
            end
        end
    end
    return nothing
end
206+
207+
"""
    _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx, ftype, field_path)

Append the flat kernel parameters contributed by the tuple type `ftype`.

When the tuple's element type is a single concrete leaf type (not ghost, not
destructurable, not itself a tuple), one parameter per element is emitted and all of
them share the key `(arg_idx, field_path)`. Otherwise each element gets its own path
(`field_path` extended with the element position) and is skipped, recursed into, or
emitted as a single parameter depending on its type.
"""
function _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx,
                               @nospecialize(ftype), field_path)
    elem_types = ftype.parameters
    shared_type = eltype(ftype)
    homogeneous_leaf = isconcretetype(shared_type) &&
        !(is_ghost_type(shared_type) || should_destructure(shared_type) ||
          shared_type <: Tuple)

    if homogeneous_leaf
        # e.g. NTuple{N, Int32}: N parameters grouped under the tuple's own path.
        for _ in 1:length(elem_types)
            push!(param_types, tile_type_for_julia!(ctx, shared_type))
            push!(param_mapping, (arg_idx, field_path))
        end
        return nothing
    end

    # Heterogeneous or complex tuple: handle each element under an indexed path.
    for (pos, elem_type) in enumerate(elem_types)
        is_ghost_type(elem_type) && continue
        elem_path = Union{Symbol, Int}[field_path..., pos]
        if elem_type <: Tuple && elem_type !== Tuple{}
            _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx,
                                  elem_type, elem_path)
        elseif should_destructure(elem_type)
            flatten_struct_params!(ctx, param_types, param_mapping, arg_idx,
                                   elem_type, elem_path)
        else
            leaf_id = tile_type_for_julia!(ctx, elem_type; throw_error=false)
            leaf_id === nothing && continue
            push!(param_types, leaf_id)
            push!(param_mapping, (arg_idx, elem_path))
        end
    end
    return nothing
end
241+
169242
# getfield for destructured arguments (lazy chain extension)
170243
function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
171244
length(args) >= 2 || return nothing
@@ -195,27 +268,28 @@ function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
195268
if field isa Symbol
196269
# Field access: extend chain with symbol
197270
new_chain = Union{Symbol, Int}[chain..., field]
198-
# Check if this resolves to a scalar field (auto-materialize leaf)
199-
# Don't auto-materialize tuple types - they need indexing first
200271
rt = CC.widenconst(result_type)
201272
if !(rt <: Tuple)
202-
values = get_arg_flat_values(ctx, arg_idx, field)
273+
# Check if this path resolves to flat values
274+
values = get_arg_flat_values(ctx, arg_idx, new_chain)
203275
if values !== nothing && length(values) == 1
204276
# Scalar field - materialize immediately
205277
type_id = tile_type_for_julia!(ctx, rt)
206278
return CGVal(values[1], type_id, rt)
207279
end
208280
end
209281
return arg_ref_value(arg_idx, new_chain, rt)
210-
elseif field isa Integer && !isempty(chain) && chain[end] isa Symbol
211-
# Tuple indexing: chain ends with field name, now indexing into it
212-
# This is a leaf - materialize immediately
213-
field_name = chain[end]
214-
values = get_arg_flat_values(ctx, arg_idx, field_name)
282+
elseif field isa Integer
283+
# Tuple indexing into a field
284+
# Look up the parent path (which should be a tuple field)
285+
values = get_arg_flat_values(ctx, arg_idx, chain)
215286
if values !== nothing && 1 <= field <= length(values)
216287
type_id = tile_type_for_julia!(ctx, CC.widenconst(result_type))
217288
return CGVal(values[field], type_id, CC.widenconst(result_type))
218289
end
290+
# Not a simple flat tuple — extend chain (element may be destructured)
291+
new_chain = Union{Symbol, Int}[chain..., Int(field)]
292+
return arg_ref_value(arg_idx, new_chain, CC.widenconst(result_type))
219293
end
220294
end
221295

@@ -241,15 +315,11 @@ function emit_getindex!(ctx::CGCtx, args, @nospecialize(result_type))
241315
if is_arg_ref(obj_tv)
242316
arg_idx, chain = obj_tv.arg_ref
243317

244-
# If chain ends with a symbol (field name), we're indexing into a tuple field
245-
# Try to materialize immediately
246-
if !isempty(chain) && chain[end] isa Symbol
247-
field_name = chain[end]
248-
values = get_arg_flat_values(ctx, arg_idx, field_name)
249-
if values !== nothing && 1 <= index <= length(values)
250-
type_id = tile_type_for_julia!(ctx, CC.widenconst(result_type))
251-
return CGVal(values[index], type_id, CC.widenconst(result_type))
252-
end
318+
# Try to materialize: the chain should point to a tuple field
319+
values = get_arg_flat_values(ctx, arg_idx, chain)
320+
if values !== nothing && 1 <= index <= length(values)
321+
type_id = tile_type_for_julia!(ctx, CC.widenconst(result_type))
322+
return CGVal(values[index], type_id, CC.widenconst(result_type))
253323
end
254324

255325
# Otherwise extend the chain

src/compiler/codegen/utils.jl

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,14 @@ mutable struct CGCtx
151151
slots::Dict{Int, CGVal} # Slot number -> CGVal
152152
block_args::Dict{Int, CGVal} # BlockArg id -> CGVal (for control flow)
153153

154-
# Destructured argument handling (for TileArray fields)
155-
arg_flat_values::Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}
154+
# Destructured argument handling: path-keyed flat values
155+
# Key: (arg_idx, path) where path is e.g. [:ptr] or [:a, :sizes]
156+
arg_flat_values::Dict{Tuple{Int, Vector{Union{Symbol, Int}}}, Vector{Value}}
156157
arg_types::Dict{Int, Type}
157158

158-
# Cached TensorViews for TileArray arguments (arg_idx -> (Value, TypeId))
159-
tensor_views::Dict{Int, Tuple{Value, TypeId}}
159+
# Cached TensorViews for TileArray arguments
160+
# Key: arg_idx::Int for top-level, or (arg_idx, path) for nested
161+
tensor_views::Dict{Any, Tuple{Value, TypeId}}
160162

161163
# Bytecode infrastructure
162164
cb::CodeBuilder
@@ -188,9 +190,9 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode,
188190
Dict{Int, CGVal}(),
189191
Dict{Int, CGVal}(),
190192
Dict{Int, CGVal}(),
191-
Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}(),
193+
Dict{Tuple{Int, Vector{Union{Symbol, Int}}}, Vector{Value}}(),
192194
Dict{Int, Type}(),
193-
Dict{Int, Tuple{Value, TypeId}}(),
195+
Dict{Any, Tuple{Value, TypeId}}(),
194196
cb, tt, sci, token, token_type, type_cache, sm_arch, cache,
195197
)
196198
end
@@ -238,12 +240,27 @@ end
238240
=============================================================================#
239241

240242
"""
241-
get_arg_flat_values(ctx, arg_idx, field=nothing) -> Union{Vector{Value}, Nothing}
243+
get_arg_flat_values(ctx, arg_idx, path) -> Union{Vector{Value}, Nothing}
242244
243-
Get the flat Tile IR values for an argument or its field.
245+
Get the flat Tile IR values for a destructured argument at the given path.
244246
"""
245-
function get_arg_flat_values(ctx::CGCtx, arg_idx::Int, field::Union{Nothing, Symbol}=nothing)
246-
get(ctx.arg_flat_values, (arg_idx, field), nothing)
247+
function get_arg_flat_values(ctx::CGCtx, arg_idx::Int,
                             path::Vector{Union{Symbol, Int}})
    # Exact lookup of the `(arg_idx, path)` key; `nothing` when it is not present.
    return get(ctx.arg_flat_values, (arg_idx, path), nothing)
end

# Convenience method: look up a single top-level field by name.
function get_arg_flat_values(ctx::CGCtx, arg_idx::Int, field::Symbol)
    return get_arg_flat_values(ctx, arg_idx, Union{Symbol, Int}[field])
end

# Convenience method: gather the values stored under every path of `arg_idx`.
# The values are concatenated in the iteration order of the `arg_flat_values`
# dictionary; `nothing` is returned when the argument has no entries.
function get_arg_flat_values(ctx::CGCtx, arg_idx::Int)
    collected = Value[]
    for (key, vals) in ctx.arg_flat_values
        first(key) == arg_idx || continue
        append!(collected, vals)
    end
    return isempty(collected) ? nothing : collected
end
248265

249266
"""
@@ -385,23 +402,61 @@ end
385402
should_destructure(T) -> Bool
386403
387404
Check if a type should be destructured into flat parameters.
405+
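Any struct type other than a tuple qualifies when at least one of its fields contributes kernel parameters (a primitive, a `Ptr`, a tuple containing such a type, or a nested isbits struct that itself qualifies).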
388406
"""
389407
function should_destructure(@nospecialize(T))
    widened = CC.widenconst(T)

    # Only non-ghost, non-primitive struct types are candidates. Tuples are excluded
    # here because they are flattened by dedicated tuple handling instead.
    if !isstructtype(widened) || is_ghost_type(widened) ||
       isprimitivetype(widened) || widened <: Tuple
        return false
    end

    # `fieldcount` throws for types without a definite field layout; such types are
    # never destructured.
    nf = try
        fieldcount(widened)
    catch
        return false
    end

    # Destructure as soon as one field contributes kernel parameters.
    return any(fi -> _is_kernel_param_type(fieldtype(widened, fi)), 1:nf)
end
421+
422+
# Decide whether a field of type `ft` contributes at least one kernel parameter.
# Ghost types never do; primitives and pointers always do; a non-empty tuple does when
# any of its element types does; any other type does only when it is an isbits struct
# that `should_destructure` accepts.
function _is_kernel_param_type(@nospecialize(ft))
    is_ghost_type(ft) && return false
    (isprimitivetype(ft) || ft <: Ptr) && return true
    if ft <: Tuple && ft !== Tuple{}
        return any(_is_kernel_param_type, ft.parameters)
    end
    # The isbits requirement keeps the mutual recursion with `should_destructure`
    # away from types such as DataType, where it would not terminate.
    return isstructtype(ft) && isbitstype(ft) && should_destructure(ft)
end
397439

398440
"""
399441
flat_field_count(T) -> Int
400442
401-
Count flat parameters a type expands to.
443+
Count flat parameters a type expands to (recursive).
402444
"""
403-
flat_field_count(::Type{<:NTuple{N, T}}) where {N, T} = N
404-
flat_field_count(::Type) = 1
445+
function flat_field_count(@nospecialize(T))
    # Ghost types contribute no parameters at all.
    is_ghost_type(T) && return 0

    if should_destructure(T)
        # Destructured struct: sum the contributions of its fields.
        total = 0
        for fi in 1:fieldcount(T)
            total += flat_field_count(fieldtype(T, fi))
        end
        return total
    elseif T <: Tuple
        # Count tuple elements recursively rather than by tuple length, so that ghost
        # elements count as zero and struct or tuple elements count as their own
        # expansion. A tuple of plain leaf types still yields its length.
        total = 0
        for elem_type in T.parameters
            total += flat_field_count(elem_type)
        end
        return total
    else
        # Any other type is a single flat parameter.
        return 1
    end
end
405460

406461
#-----------------------------------------------------------------------------
407462
# Argument helpers

0 commit comments

Comments
 (0)