2727 ASTTransformerFuncContext ,
2828)
2929from quadrants .lang .exception import (
30+ QuadrantsCompilationError ,
3031 QuadrantsSyntaxError ,
3132)
3233from quadrants .lang .matrix import MatrixType
@@ -189,18 +190,12 @@ def _transform_as_kernel(ctx: ASTTransformerFuncContext, node: ast.FunctionDef,
189190
190191 @staticmethod
191192 def _predeclare_struct_ndarrays (ctx : ASTTransformerFuncContext ) -> None :
192- """Walk template args that are structs and pre-declare any ``Ndarray`` attributes as kernel args (via
193- ``decl_ndarray_arg``) so they are registered before ``finalize_params``. The resulting ``AnyArray`` objects are
194- cached on the global context for later lookup by ``build_Attribute``.
193+ """Reject template-annotated struct args that contain ndarrays.
195194
196- Also stores ``(arg_id, template_arg_idx, attr_chain )`` tuples in
197- ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the
198- launch context .
195+ Passing ndarrays through ``qd.template()`` structs bypasses argument pruning (every ndarray field is registered
196+ regardless of whether the kernel uses it) and inflates the cached launch context, causing measurable launch
197+ overhead. Users should annotate such parameters with a concrete ``@dataclass`` type instead.
199198 """
200- from quadrants .lang .util import cook_dtype # pylint: disable=C0415
201-
202- cache = ctx .global_context .ndarray_to_any_array
203- launch_info = ctx .global_context .struct_ndarray_launch_info
204199
205200 def _walk_obj (obj , arg_idx , path ):
206201 if dataclasses .is_dataclass (obj ) and not isinstance (obj , type ):
@@ -209,36 +204,33 @@ def _walk_obj(obj, arg_idx, path):
209204 if isinstance (child , _TensorClass ):
210205 child = child ._unwrap ()
211206 if isinstance (child , _ndarray .Ndarray ):
212- _register_ndarray (child , arg_idx , (* path , field .name ))
207+ param_name = ctx .func .arg_metas [arg_idx ].name
208+ attr_path = "." .join ((* path , field .name ))
209+ raise QuadrantsCompilationError (
210+ f"Kernel parameter '{ param_name } ' is annotated as qd.template(), but "
211+ f"'{ param_name } .{ attr_path } ' is a qd.ndarray. Passing ndarrays through "
212+ f"template structs is not supported because it bypasses argument pruning "
213+ f"and degrades launch performance. Use a concrete struct annotation "
214+ f"(e.g. a @dataclass type hint) instead of qd.template() for struct "
215+ f"parameters that contain ndarrays."
216+ )
213217 elif dataclasses .is_dataclass (child ) and not isinstance (child , type ):
214218 _walk_obj (child , arg_idx , (* path , field .name ))
215219 else :
216220 for attr_name , attr_val in vars (obj ).items ():
217221 if isinstance (attr_val , _TensorClass ):
218222 attr_val = attr_val ._unwrap ()
219223 if isinstance (attr_val , _ndarray .Ndarray ):
220- _register_ndarray (attr_val , arg_idx , (* path , attr_name ))
221-
222- def _register_ndarray (nd , arg_idx , attr_chain ):
223- key = id (nd )
224- if key in cache :
225- return
226- from quadrants ._lib import core as _qd_core # pylint: disable=C0415
227-
228- element_type = cook_dtype (nd .element_type )
229- ndim = len (nd ._physical_shape )
230- needs_grad = nd .grad is not None
231- layout = getattr (nd , "_qd_layout" , None )
232- name = f"__qd_struct_nd_{ key } "
233- arg_id_vec = impl .get_runtime ().compiling_callable .insert_ndarray_param (
234- element_type , ndim , name , needs_grad
235- )
236- arr = any_array .AnyArray (
237- _qd_core .make_external_tensor_expr (element_type , ndim , arg_id_vec , needs_grad , BoundaryMode .UNSAFE ),
238- _qd_layout = layout ,
239- )
240- cache [key ] = arr
241- launch_info .append ((arg_id_vec [0 ], arg_idx , attr_chain ))
224+ param_name = ctx .func .arg_metas [arg_idx ].name
225+ attr_path = "." .join ((* path , attr_name ))
226+ raise QuadrantsCompilationError (
227+ f"Kernel parameter '{ param_name } ' is annotated as qd.template(), but "
228+ f"'{ param_name } .{ attr_path } ' is a qd.ndarray. Passing ndarrays through "
229+ f"template structs is not supported because it bypasses argument pruning "
230+ f"and degrades launch performance. Use a concrete struct annotation "
231+ f"(e.g. a @dataclass type hint) instead of qd.template() for struct "
232+ f"parameters that contain ndarrays."
233+ )
242234
243235 assert ctx .py_args is not None
244236 for i , arg_meta in enumerate (ctx .func .arg_metas ):
0 commit comments