@@ -38,7 +38,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
3838
3939 # Build parameter list, handling ghost types, const args, and struct destructuring
4040 param_types = TypeId[]
41- param_mapping = Tuple{Int, Union{Nothing, Symbol }}[]
41+ param_mapping = Tuple{Int, Vector{ Union{Symbol, Int} }}[]
4242
4343 for (i, argtype) in enumerate (sci. argtypes)
4444 argtype_unwrapped = CC. widenconst (argtype)
@@ -47,28 +47,11 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
4747 elseif is_const_arg (i)
4848 continue # const arg: no kernel parameter
4949 elseif should_destructure (argtype_unwrapped)
50- # Destructure TileArray into flat parameters
51- params = argtype_unwrapped. parameters
52- ndims = params[2 ]:: Integer
53- for fi in 1 : fieldcount (argtype_unwrapped)
54- fname = fieldname (argtype_unwrapped, fi)
55- ftype = fieldtype (argtype_unwrapped, fi)
56- if fname === :sizes || fname === :strides
57- fcount = ndims
58- elem_type = Int32
59- else
60- fcount = flat_field_count (ftype)
61- elem_type = ftype <: Ptr ? Ptr{params[1 ]} : (ftype <: Tuple ? eltype (ftype) : ftype)
62- end
63- for _ in 1 : fcount
64- push! (param_types, tile_type_for_julia! (ctx, elem_type))
65- push! (param_mapping, (i, fname))
66- end
67- end
50+ flatten_struct_params! (ctx, param_types, param_mapping, i, argtype_unwrapped, Union{Symbol, Int}[])
6851 ctx. arg_types[i] = argtype_unwrapped
6952 else
7053 push! (param_types, tile_type_for_julia! (ctx, argtype_unwrapped))
71- push! (param_mapping, (i, nothing ))
54+ push! (param_mapping, (i, Union{Symbol, Int}[] ))
7255 end
7356 end
7457
@@ -90,7 +73,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
9073 arg_values = make_block_args! (cb, length (param_types))
9174
9275 # Build arg_flat_values map
93- field_values = Dict {Tuple{Int, Union{Nothing, Symbol }}, Vector{Value}} ()
76+ field_values = Dict{Tuple{Int, Vector{ Union{Symbol, Int} }}, Vector{Value}}()
9477 for (param_idx, val) in enumerate (arg_values)
9578 key = param_mapping[param_idx]
9679 if ! haskey (field_values, key)
@@ -100,13 +83,12 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
10083 end
10184
10285 # Store in context and set up slot/argument mappings
103- # arg_idx is the direct index into argtypes (2, 3, ...) which matches SlotNumber/Argument
10486 for (key, values) in field_values
105- arg_idx, field = key
87+ arg_idx, path = key
10688 ctx. arg_flat_values[key] = values
10789
108- if field === nothing
109- # Regular argument - create concrete CGVal
90+ if isempty (path) && ! haskey (ctx . arg_types, arg_idx)
91+ # Regular (non-destructured) argument - create concrete CGVal
11092 if length (values) != 1
11193 throw (IRError (" Expected exactly one value for argument $arg_idx , got $(length (values)) " ))
11294 end
@@ -148,8 +130,9 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
148130 end
149131
150132 # Create TensorViews for all TileArray arguments at kernel entry
151- for (arg_idx, _) in ctx. arg_types
152- cache_tensor_view! (ctx, arg_idx)
133+ # Walk the type tree to find all nested TileArrays
134+ for (arg_idx, argtype) in ctx. arg_types
135+ _create_tensor_views_recursive! (ctx, arg_idx, argtype, Union{Symbol, Int}[])
153136 end
154137
155138 # Create memory ordering token
@@ -166,6 +149,96 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
166149 finalize_function! (func_buf, cb, writer. debug_info)
167150end
168151
"""
    _create_tensor_views_recursive!(ctx, arg_idx, T, path)

Walk the type tree rooted at `T` and call `cache_tensor_view!` for every nested
`TileArray` that is reached.

# Arguments
- `ctx::CGCtx`: codegen context, forwarded unchanged to `cache_tensor_view!`.
- `arg_idx::Int`: index of the kernel argument whose type is being walked.
- `T`: the type at the current position in the tree.
- `path`: field names (`Symbol`) and tuple positions (`Int`) leading from the argument
  to `T`. It is never mutated; every level builds a fresh, extended copy.

Ghost-typed struct fields and ghost-typed tuple elements are skipped. Struct fields are
only descended into when they are `TileArray`s, tuples, or satisfy `should_destructure`.
"""
function _create_tensor_views_recursive!(ctx::CGCtx, arg_idx::Int, @nospecialize(T),
                                         path::Vector{Union{Symbol, Int}})
    if T <: TileArray
        return cache_tensor_view!(ctx, arg_idx, path, T)
    end

    if T <: Tuple
        for i in 1:length(T.parameters)
            ptype = T.parameters[i]
            # Skip ghost elements exactly like ghost struct fields below. Recursing into
            # one would fall through to `fieldcount`, which throws for types that have no
            # definite number of fields.
            is_ghost_type(ptype) && continue
            elem_path = Union{Symbol, Int}[path..., i]
            _create_tensor_views_recursive!(ctx, arg_idx, ptype, elem_path)
        end
        return nothing
    end

    for fi in 1:fieldcount(T)
        ftype = fieldtype(T, fi)
        is_ghost_type(ftype) && continue
        field_path = Union{Symbol, Int}[path..., fieldname(T, fi)]
        if ftype <: TileArray
            cache_tensor_view!(ctx, arg_idx, field_path, ftype)
        elseif should_destructure(ftype) || ftype <: Tuple
            _create_tensor_views_recursive!(ctx, arg_idx, ftype, field_path)
        end
    end
    return nothing
end
180+
"""
    flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, T, path)

Expand the fields of struct type `T` into flat kernel parameters, recursing into nested
tuples and destructurable structs.

For every emitted parameter one type id is appended to `param_types` and one
`(arg_idx, field_path)` entry is appended to `param_mapping`, where `field_path` is
`path` extended with the field name. `path` itself is not mutated.

Per field:
- ghost-typed fields are skipped;
- non-empty tuple fields are handed to `_flatten_tuple_param!`;
- fields satisfying `should_destructure` are flattened recursively;
- any other field becomes a single parameter, or is dropped when
  `tile_type_for_julia!(...; throw_error=false)` returns `nothing`.
"""
function flatten_struct_params!(ctx, param_types, param_mapping, arg_idx,
                                @nospecialize(T), path::Vector{Union{Symbol, Int}})
    for idx in 1:fieldcount(T)
        field_type = fieldtype(T, idx)
        is_ghost_type(field_type) && continue

        sub_path = Union{Symbol, Int}[path..., fieldname(T, idx)]
        if field_type <: Tuple && field_type !== Tuple{}
            _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx,
                                  field_type, sub_path)
        elseif should_destructure(field_type)
            flatten_struct_params!(ctx, param_types, param_mapping, arg_idx,
                                   field_type, sub_path)
        else
            # Leaf field (scalar/Ptr): one flat parameter, unless it has no tile IR type.
            leaf_id = tile_type_for_julia!(ctx, field_type; throw_error=false)
            if leaf_id !== nothing
                push!(param_types, leaf_id)
                push!(param_mapping, (arg_idx, sub_path))
            end
        end
    end
end
206+
"""
    _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx, ftype, field_path)

Flatten the tuple type `ftype`, located at `field_path`, into kernel parameters.

- Homogeneous tuple of a simple leaf type (its `eltype` is concrete, not a ghost type,
  not destructurable and not itself a tuple, e.g. `NTuple{N, Int32}`): emit one parameter
  per element, all mapped to the same `field_path`.
- Anything else: visit each element under `field_path` extended with its 1-based
  position. Ghost elements are skipped, nested non-empty tuples and destructurable
  structs are flattened recursively, and remaining elements become a single parameter
  unless `tile_type_for_julia!(...; throw_error=false)` returns `nothing`.
"""
function _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx,
                               @nospecialize(ftype), field_path)
    nelems = length(ftype.parameters)
    shared = eltype(ftype)
    is_simple_ntuple = isconcretetype(shared) && !is_ghost_type(shared) &&
                       !should_destructure(shared) && !(shared <: Tuple)

    if is_simple_ntuple
        # Grouped form: every element is registered under the tuple's own path.
        for _ in 1:nelems
            push!(param_types, tile_type_for_julia!(ctx, shared))
            push!(param_mapping, (arg_idx, field_path))
        end
        return nothing
    end

    # Heterogeneous or complex tuple: each element gets its own indexed path.
    for (pos, part_type) in enumerate(ftype.parameters)
        part_path = Union{Symbol, Int}[field_path..., pos]
        is_ghost_type(part_type) && continue

        if part_type <: Tuple && part_type !== Tuple{}
            _flatten_tuple_param!(ctx, param_types, param_mapping, arg_idx,
                                  part_type, part_path)
        elseif should_destructure(part_type)
            flatten_struct_params!(ctx, param_types, param_mapping, arg_idx,
                                   part_type, part_path)
        else
            leaf_id = tile_type_for_julia!(ctx, part_type; throw_error=false)
            if leaf_id !== nothing
                push!(param_types, leaf_id)
                push!(param_mapping, (arg_idx, part_path))
            end
        end
    end
    return nothing
end
241+
169242# getfield for destructured arguments (lazy chain extension)
170243function emit_getfield! (ctx:: CGCtx , args, @nospecialize (result_type))
171244 length (args) >= 2 || return nothing
@@ -195,27 +268,28 @@ function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
195268 if field isa Symbol
196269 # Field access: extend chain with symbol
197270 new_chain = Union{Symbol, Int}[chain... , field]
198- # Check if this resolves to a scalar field (auto-materialize leaf)
199- # Don't auto-materialize tuple types - they need indexing first
200271 rt = CC. widenconst (result_type)
201272 if ! (rt <: Tuple )
202- values = get_arg_flat_values (ctx, arg_idx, field)
273+ # Check if this path resolves to flat values
274+ values = get_arg_flat_values (ctx, arg_idx, new_chain)
203275 if values != = nothing && length (values) == 1
204276 # Scalar field - materialize immediately
205277 type_id = tile_type_for_julia! (ctx, rt)
206278 return CGVal (values[1 ], type_id, rt)
207279 end
208280 end
209281 return arg_ref_value (arg_idx, new_chain, rt)
210- elseif field isa Integer && ! isempty (chain) && chain[end ] isa Symbol
211- # Tuple indexing: chain ends with field name, now indexing into it
212- # This is a leaf - materialize immediately
213- field_name = chain[end ]
214- values = get_arg_flat_values (ctx, arg_idx, field_name)
282+ elseif field isa Integer
283+ # Tuple indexing into a field
284+ # Look up the parent path (which should be a tuple field)
285+ values = get_arg_flat_values (ctx, arg_idx, chain)
215286 if values != = nothing && 1 <= field <= length (values)
216287 type_id = tile_type_for_julia! (ctx, CC. widenconst (result_type))
217288 return CGVal (values[field], type_id, CC. widenconst (result_type))
218289 end
290+ # Not a simple flat tuple — extend chain (element may be destructured)
291+ new_chain = Union{Symbol, Int}[chain... , Int (field)]
292+ return arg_ref_value (arg_idx, new_chain, CC. widenconst (result_type))
219293 end
220294 end
221295
@@ -241,15 +315,11 @@ function emit_getindex!(ctx::CGCtx, args, @nospecialize(result_type))
241315 if is_arg_ref (obj_tv)
242316 arg_idx, chain = obj_tv. arg_ref
243317
244- # If chain ends with a symbol (field name), we're indexing into a tuple field
245- # Try to materialize immediately
246- if ! isempty (chain) && chain[end ] isa Symbol
247- field_name = chain[end ]
248- values = get_arg_flat_values (ctx, arg_idx, field_name)
249- if values != = nothing && 1 <= index <= length (values)
250- type_id = tile_type_for_julia! (ctx, CC. widenconst (result_type))
251- return CGVal (values[index], type_id, CC. widenconst (result_type))
252- end
318+ # Try to materialize: the chain should point to a tuple field
319+ values = get_arg_flat_values (ctx, arg_idx, chain)
320+ if values != = nothing && 1 <= index <= length (values)
321+ type_id = tile_type_for_julia! (ctx, CC. widenconst (result_type))
322+ return CGVal (values[index], type_id, CC. widenconst (result_type))
253323 end
254324
255325 # Otherwise extend the chain
0 commit comments