@@ -38,37 +38,20 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
3838
3939 # Build parameter list, handling ghost types, const args, and struct destructuring
4040 param_types = TypeId[]
41- param_mapping = Tuple{Int, Union{Nothing, Symbol }}[]
41+ param_mapping = Tuple{Int, Vector{Int }}[]
4242
4343 for (i, argtype) in enumerate (sci. argtypes)
4444 argtype_unwrapped = CC. widenconst (argtype)
4545 if is_ghost_type (argtype_unwrapped)
4646 continue
4747 elseif is_const_arg (i)
4848 continue # const arg: no kernel parameter
49- elseif should_destructure (argtype_unwrapped)
50- # Destructure TileArray into flat parameters
51- params = argtype_unwrapped. parameters
52- ndims = params[2 ]:: Integer
53- for fi in 1 : fieldcount (argtype_unwrapped)
54- fname = fieldname (argtype_unwrapped, fi)
55- ftype = fieldtype (argtype_unwrapped, fi)
56- if fname === :sizes || fname === :strides
57- fcount = ndims
58- elem_type = Int32
59- else
60- fcount = flat_field_count (ftype)
61- elem_type = ftype <: Ptr ? Ptr{params[1 ]} : (ftype <: Tuple ? eltype (ftype) : ftype)
62- end
63- for _ in 1 : fcount
64- push! (param_types, tile_type_for_julia! (ctx, elem_type))
65- push! (param_mapping, (i, fname))
66- end
67- end
68- ctx. arg_types[i] = argtype_unwrapped
69- else
49+ elseif isprimitivetype (argtype_unwrapped)
7050 push! (param_types, tile_type_for_julia! (ctx, argtype_unwrapped))
71- push! (param_mapping, (i, nothing ))
51+ push! (param_mapping, (i, Int[]))
52+ else
53+ flatten_struct_params! (ctx, param_types, param_mapping, i, argtype_unwrapped, Int[])
54+ ctx. arg_types[i] = argtype_unwrapped
7255 end
7356 end
7457
@@ -90,7 +73,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
9073 arg_values = make_block_args! (cb, length (param_types))
9174
9275 # Build arg_flat_values map
93- field_values = Dict {Tuple{Int, Union{Nothing, Symbol }}, Vector{Value}} ()
76+ field_values = Dict {Tuple{Int, Vector{Int }}, Vector{Value}} ()
9477 for (param_idx, val) in enumerate (arg_values)
9578 key = param_mapping[param_idx]
9679 if ! haskey (field_values, key)
@@ -100,13 +83,12 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
10083 end
10184
10285 # Store in context and set up slot/argument mappings
103- # arg_idx is the direct index into argtypes (2, 3, ...) which matches SlotNumber/Argument
10486 for (key, values) in field_values
105- arg_idx, field = key
87+ arg_idx, path = key
10688 ctx. arg_flat_values[key] = values
10789
108- if field === nothing
109- # Regular argument - create concrete CGVal
90+ if isempty (path) && ! haskey (ctx . arg_types, arg_idx)
91+ # Regular (non-destructured) argument - create concrete CGVal
11092 if length (values) != 1
11193 throw (IRError (" Expected exactly one value for argument $arg_idx , got $(length (values)) " ))
11294 end
@@ -142,14 +124,14 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
142124
143125 # For destructured args, create lazy CGVals that track the argument index
144126 for (arg_idx, argtype) in ctx. arg_types
145- tv = arg_ref_value (arg_idx, Union{Symbol, Int} [], argtype)
127+ tv = arg_ref_value (arg_idx, Int[], argtype)
146128 ctx[SlotNumber (arg_idx)] = tv
147129 ctx[Argument (arg_idx)] = tv
148130 end
149131
150132 # Create TensorViews for all TileArray arguments at kernel entry
151- for (arg_idx, _ ) in ctx. arg_types
152- cache_tensor_view ! (ctx, arg_idx)
133+ for (arg_idx, argtype ) in ctx. arg_types
134+ create_tensor_views ! (ctx, arg_idx, argtype, Int[] )
153135 end
154136
155137 # Create memory ordering token
@@ -166,6 +148,46 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
166148 finalize_function! (func_buf, cb, writer. debug_info)
167149end
168150
151+ """
152+ create_tensor_views!(ctx, arg_idx, T, path)
153+
154+ Walk the type tree and create TensorViews for all nested TileArrays.
155+ """
156+ function create_tensor_views! (ctx:: CGCtx , arg_idx:: Int , @nospecialize (T), path:: Vector{Int} )
157+ if T <: TileArray
158+ cache_tensor_view! (ctx, arg_idx, path, T)
159+ else
160+ for fi in 1 : fieldcount (T)
161+ ftype = fieldtype (T, fi)
162+ (is_ghost_type (ftype) || isprimitivetype (ftype)) && continue
163+ field_path = [path... , fi]
164+ create_tensor_views! (ctx, arg_idx, ftype, field_path)
165+ end
166+ end
167+ end
168+
169+ """
170+ flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, T, path)
171+
172+ Recursively flatten a struct type into kernel parameters.
173+ """
174+ function flatten_struct_params! (ctx, param_types, param_mapping, arg_idx, @nospecialize (T), path:: Vector{Int} )
175+ for fi in 1 : fieldcount (T)
176+ ftype = fieldtype (T, fi)
177+ field_path = [path... , fi]
178+ if is_ghost_type (ftype)
179+ continue
180+ elseif isprimitivetype (ftype)
181+ type_id = tile_type_for_julia! (ctx, ftype; throw_error= false )
182+ type_id === nothing && continue
183+ push! (param_types, type_id)
184+ push! (param_mapping, (arg_idx, field_path))
185+ else
186+ flatten_struct_params! (ctx, param_types, param_mapping, arg_idx, ftype, field_path)
187+ end
188+ end
189+ end
190+
169191# getfield for destructured arguments (lazy chain extension)
170192function emit_getfield! (ctx:: CGCtx , args, @nospecialize (result_type))
171193 length (args) >= 2 || return nothing
@@ -192,31 +214,18 @@ function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
192214 if obj_tv != = nothing && is_arg_ref (obj_tv)
193215 arg_idx, chain = obj_tv. arg_ref
194216
195- if field isa Symbol
196- # Field access: extend chain with symbol
197- new_chain = Union{Symbol, Int}[chain... , field]
198- # Check if this resolves to a scalar field (auto-materialize leaf)
199- # Don't auto-materialize tuple types - they need indexing first
200- rt = CC. widenconst (result_type)
201- if ! (rt <: Tuple )
202- values = get_arg_flat_values (ctx, arg_idx, field)
203- if values != = nothing && length (values) == 1
204- # Scalar field - materialize immediately
205- type_id = tile_type_for_julia! (ctx, rt)
206- return CGVal (values[1 ], type_id, rt)
207- end
208- end
209- return arg_ref_value (arg_idx, new_chain, rt)
210- elseif field isa Integer && ! isempty (chain) && chain[end ] isa Symbol
211- # Tuple indexing: chain ends with field name, now indexing into it
212- # This is a leaf - materialize immediately
213- field_name = chain[end ]
214- values = get_arg_flat_values (ctx, arg_idx, field_name)
215- if values != = nothing && 1 <= field <= length (values)
216- type_id = tile_type_for_julia! (ctx, CC. widenconst (result_type))
217- return CGVal (values[field], type_id, CC. widenconst (result_type))
218- end
217+ # Convert field to integer index
218+ idx = if field isa Symbol
219+ obj_type = CC. widenconst (obj_tv. jltype)
220+ Base. fieldindex (obj_type, field)
221+ elseif field isa Integer
222+ Int (field)
223+ else
224+ nothing
219225 end
226+ idx === nothing && return nothing
227+
228+ return resolve_arg_ref (ctx, arg_idx, chain, idx, CC. widenconst (result_type))
220229 end
221230
222231 nothing
@@ -237,24 +246,10 @@ function emit_getindex!(ctx::CGCtx, args, @nospecialize(result_type))
237246 obj_tv = emit_value! (ctx, obj_arg)
238247 obj_tv === nothing && return nothing
239248
240- # If obj is a lazy arg_ref, try to materialize or extend the chain
249+ # If obj is a lazy arg_ref, extend the chain with the index
241250 if is_arg_ref (obj_tv)
242251 arg_idx, chain = obj_tv. arg_ref
243-
244- # If chain ends with a symbol (field name), we're indexing into a tuple field
245- # Try to materialize immediately
246- if ! isempty (chain) && chain[end ] isa Symbol
247- field_name = chain[end ]
248- values = get_arg_flat_values (ctx, arg_idx, field_name)
249- if values != = nothing && 1 <= index <= length (values)
250- type_id = tile_type_for_julia! (ctx, CC. widenconst (result_type))
251- return CGVal (values[index], type_id, CC. widenconst (result_type))
252- end
253- end
254-
255- # Otherwise extend the chain
256- new_chain = Union{Symbol, Int}[chain... , Int (index)]
257- return arg_ref_value (arg_idx, new_chain, CC. widenconst (result_type))
252+ return resolve_arg_ref (ctx, arg_idx, chain, Int (index), CC. widenconst (result_type))
258253 end
259254
260255 # Not an arg_ref - not handled here
0 commit comments