Skip to content

Commit ae92ccc

Browse files
Reject ndarray fields inside template-annotated struct params
Passing structs containing ndarrays through qd.template() bypasses argument pruning — every ndarray field gets registered regardless of whether the kernel uses it — and inflates the cached launch context, causing a measured 42% launch overhead on real workloads.

Raise a clear QuadrantsCompilationError guiding users to use a concrete dataclass type annotation instead.

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent d0f06b2 commit ae92ccc

1 file changed

Lines changed: 25 additions & 33 deletions

File tree

python/quadrants/lang/ast/ast_transformers/function_def_transformer.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ASTTransformerFuncContext,
2828
)
2929
from quadrants.lang.exception import (
30+
QuadrantsCompilationError,
3031
QuadrantsSyntaxError,
3132
)
3233
from quadrants.lang.matrix import MatrixType
@@ -189,18 +190,12 @@ def _transform_as_kernel(ctx: ASTTransformerFuncContext, node: ast.FunctionDef,
189190

190191
@staticmethod
191192
def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None:
192-
"""Walk template args that are structs and pre-declare any ``Ndarray`` attributes as kernel args (via
193-
``decl_ndarray_arg``) so they are registered before ``finalize_params``. The resulting ``AnyArray`` objects are
194-
cached on the global context for later lookup by ``build_Attribute``.
193+
"""Reject template-annotated struct args that contain ndarrays.
195194
196-
Also stores ``(arg_id, template_arg_idx, attr_chain)`` tuples in
197-
``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the
198-
launch context.
195+
Passing ndarrays through ``qd.template()`` structs bypasses argument pruning (every ndarray field is registered
196+
regardless of whether the kernel uses it) and inflates the cached launch context, causing measurable launch
197+
overhead. Users should annotate such parameters with a concrete ``@dataclass`` type instead.
199198
"""
200-
from quadrants.lang.util import cook_dtype # pylint: disable=C0415
201-
202-
cache = ctx.global_context.ndarray_to_any_array
203-
launch_info = ctx.global_context.struct_ndarray_launch_info
204199

205200
def _walk_obj(obj, arg_idx, path):
206201
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
@@ -209,36 +204,33 @@ def _walk_obj(obj, arg_idx, path):
209204
if isinstance(child, _TensorClass):
210205
child = child._unwrap()
211206
if isinstance(child, _ndarray.Ndarray):
212-
_register_ndarray(child, arg_idx, (*path, field.name))
207+
param_name = ctx.func.arg_metas[arg_idx].name
208+
attr_path = ".".join((*path, field.name))
209+
raise QuadrantsCompilationError(
210+
f"Kernel parameter '{param_name}' is annotated as qd.template(), but "
211+
f"'{param_name}.{attr_path}' is a qd.ndarray. Passing ndarrays through "
212+
f"template structs is not supported because it bypasses argument pruning "
213+
f"and degrades launch performance. Use a concrete struct annotation "
214+
f"(e.g. a @dataclass type hint) instead of qd.template() for struct "
215+
f"parameters that contain ndarrays."
216+
)
213217
elif dataclasses.is_dataclass(child) and not isinstance(child, type):
214218
_walk_obj(child, arg_idx, (*path, field.name))
215219
else:
216220
for attr_name, attr_val in vars(obj).items():
217221
if isinstance(attr_val, _TensorClass):
218222
attr_val = attr_val._unwrap()
219223
if isinstance(attr_val, _ndarray.Ndarray):
220-
_register_ndarray(attr_val, arg_idx, (*path, attr_name))
221-
222-
def _register_ndarray(nd, arg_idx, attr_chain):
223-
key = id(nd)
224-
if key in cache:
225-
return
226-
from quadrants._lib import core as _qd_core # pylint: disable=C0415
227-
228-
element_type = cook_dtype(nd.element_type)
229-
ndim = len(nd._physical_shape)
230-
needs_grad = nd.grad is not None
231-
layout = getattr(nd, "_qd_layout", None)
232-
name = f"__qd_struct_nd_{key}"
233-
arg_id_vec = impl.get_runtime().compiling_callable.insert_ndarray_param(
234-
element_type, ndim, name, needs_grad
235-
)
236-
arr = any_array.AnyArray(
237-
_qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE),
238-
_qd_layout=layout,
239-
)
240-
cache[key] = arr
241-
launch_info.append((arg_id_vec[0], arg_idx, attr_chain))
224+
param_name = ctx.func.arg_metas[arg_idx].name
225+
attr_path = ".".join((*path, attr_name))
226+
raise QuadrantsCompilationError(
227+
f"Kernel parameter '{param_name}' is annotated as qd.template(), but "
228+
f"'{param_name}.{attr_path}' is a qd.ndarray. Passing ndarrays through "
229+
f"template structs is not supported because it bypasses argument pruning "
230+
f"and degrades launch performance. Use a concrete struct annotation "
231+
f"(e.g. a @dataclass type hint) instead of qd.template() for struct "
232+
f"parameters that contain ndarrays."
233+
)
242234

243235
assert ctx.py_args is not None
244236
for i, arg_meta in enumerate(ctx.func.arg_metas):

0 commit comments

Comments
 (0)