From b61d4447f00abbbabd4b675c8f6b6f90232a1638 Mon Sep 17 00:00:00 2001 From: Jason Han Date: Wed, 29 Apr 2026 16:23:47 -0400 Subject: [PATCH 1/2] feat: added wrapper function --- src/bloqade/squin/__init__.py | 1 + src/bloqade/squin/utils.py | 119 ++++++++++++++++++++++++++++++++++ test/squin/test_wrap.py | 67 +++++++++++++++++++ 3 files changed, 187 insertions(+) create mode 100644 src/bloqade/squin/utils.py create mode 100644 test/squin/test_wrap.py diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index 4a5556c3..e571406f 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -10,6 +10,7 @@ analysis as analysis, ) from .. import qubit as qubit +from .utils import wrap as wrap from ..qubit import ( reset as reset, is_one as is_one, diff --git a/src/bloqade/squin/utils.py b/src/bloqade/squin/utils.py new file mode 100644 index 00000000..e1e0b60f --- /dev/null +++ b/src/bloqade/squin/utils.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import linecache +from typing import Any +from keyword import iskeyword +from itertools import count + +from kirin import ir + +from ..qubit import qalloc, broadcast +from .groups import kernel + +_wrap_counter = count() + + +def wrap( + method: ir.Method, + /, + n_qubits: int | None = None, + **kwargs: Any, +) -> ir.Method: + """Wrap a qubit kernel in an allocating and measuring entry-point. + + Args: + method: Kernel to invoke on the allocated qubits. + n_qubits: Number of qubits to allocate. When omitted, this is inferred + from the method's unsupplied parameters. + **kwargs: Kernel keyword arguments to bind in the wrapped entry-point. + + Returns: + A ``squin.kernel`` method that allocates qubits, invokes ``method``, and + returns measurements for all allocated qubits. + """ + if not isinstance(method, ir.Method): + raise TypeError(f"expected a Kirin Method, got {type(method).__name__}") + + param_names = _method_param_names(method) + _validate_kwargs(param_names, kwargs) + + if n_qubits is None: + n_qubits = len(param_names) - len(kwargs) + + if n_qubits < 0: + raise ValueError("n_qubits must be non-negative") + + if len(param_names) != n_qubits + len(kwargs): + raise ValueError( + f"cannot call {method.sym_name or 'method'} with {n_qubits} qubits " + f"and {len(kwargs)} keyword arguments; expected {len(param_names)} " + "total arguments" + ) + + return _compile_wrapper(method, n_qubits, kwargs) + + +def _method_param_names(method: ir.Method) -> list[str]: + if method.arg_names is None: + return [f"arg{i}" for i in range(method.nargs - 1)] + return list(method.arg_names[1:]) + + +def _validate_kwargs(param_names: list[str], kwargs: dict[str, Any]) -> None: + unexpected = set(kwargs).difference(param_names) + if unexpected: + unexpected_names = ", ".join(sorted(unexpected)) + raise TypeError(f"unexpected keyword argument(s): {unexpected_names}") + + for name in kwargs: + if not name.isidentifier() or iskeyword(name): + raise ValueError(f"invalid keyword argument name: {name!r}") + + +def _compile_wrapper( + wrapped_method: ir.Method, + n_qubits: int, + kwargs: dict[str, Any], +) -> ir.Method: + positional_args = [f"qubits[{idx}]" for idx in range(n_qubits)] + keyword_globals = { + f"__wrapped_kwarg_{idx}__": value for idx, value in enumerate(kwargs.values()) + } + keyword_args = [ + f"{name}={global_name}" + for name, global_name in zip(kwargs, keyword_globals, strict=True) + ] + call_args = ", ".join(positional_args + keyword_args) + + call_line = ( + f" __wrapped_method__({call_args})" + if call_args + else " __wrapped_method__()" + ) + source = ( + "def main():\n" + f" qubits = __qalloc__({n_qubits})\n" + f"{call_line}\n" + " return __broadcast__.measure(qubits)\n" + ) + + filename = ( + f"" + ) + linecache.cache[filename] = ( + len(source), + None, + source.splitlines(keepends=True), + filename, + ) + + globals_ = { + "__wrapped_method__": wrapped_method, + "__qalloc__": qalloc, + "__broadcast__": broadcast, + **keyword_globals, + } + locals_: dict[str, Any] = {} + exec(compile(source, filename, "exec"), globals_, locals_) + return kernel(locals_["main"]) diff --git a/test/squin/test_wrap.py b/test/squin/test_wrap.py new file mode 100644 index 00000000..94f2d295 --- /dev/null +++ b/test/squin/test_wrap.py @@ -0,0 +1,67 @@ +import pytest + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.pyqrack import StackMemorySimulator + + +def test_wrap_infers_qubit_count_and_measures(): + @squin.kernel + def bell(q0: Qubit, q1: Qubit): + squin.h(q0) + squin.cx(q0, q1) + + main = squin.wrap(bell) + + main.print() + sim = StackMemorySimulator(min_qubits=2) + result = sim.run(main) + + assert len(result) == 2 + + +def test_wrap_binds_keyword_arguments(): + @squin.kernel + def rotate(q: Qubit, theta: float): + squin.rx(theta, q) + + main = squin.wrap(rotate, theta=0.125) + + main.print() + sim = StackMemorySimulator(min_qubits=1) + result = sim.run(main) + + assert len(result) == 1 + + +def test_wrap_accepts_explicit_qubit_count_with_keywords(): + @squin.kernel + def ansatz(q0: Qubit, q1: Qubit, theta: float): + squin.rx(theta, q0) + squin.cx(q0, q1) + + main = squin.wrap(ansatz, 2, theta=0.125) + + main.print() + sim = StackMemorySimulator(min_qubits=2) + result = sim.run(main) + + assert len(result) == 2 + + +def test_wrap_rejects_mismatched_argument_count(): + @squin.kernel + def two_qubit(q0: Qubit, q1: Qubit): + squin.cx(q0, q1) + + with pytest.raises(ValueError, match="expected 2 total arguments"): + squin.wrap(two_qubit, 1) + + +def test_wrap_rejects_unexpected_keyword_argument(): + @squin.kernel + def one_qubit(q: Qubit): + squin.h(q) + + with pytest.raises(TypeError, match="unexpected keyword argument"): + squin.wrap(one_qubit, theta=0.125) From f2ab24bd27fb2b483115b7b7e9ef1d88ec642173 Mon Sep 17 00:00:00 2001 From: Jason Han Date: Wed, 29 Apr 2026 16:30:33 -0400 Subject: [PATCH 2/2] feat: added error function when qubit is passed as kwarg --- src/bloqade/squin/utils.py | 17 +++++++++++++++++ test/squin/test_wrap.py | 10 ++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/bloqade/squin/utils.py b/src/bloqade/squin/utils.py index e1e0b60f..fa63ef3d 100644 --- a/src/bloqade/squin/utils.py +++ b/src/bloqade/squin/utils.py @@ -43,6 +43,8 @@ def wrap( if n_qubits < 0: raise ValueError("n_qubits must be non-negative") + _validate_qubit_arguments(param_names, n_qubits, kwargs) + if len(param_names) != n_qubits + len(kwargs): raise ValueError( f"cannot call {method.sym_name or 'method'} with {n_qubits} qubits " @@ -70,6 +72,21 @@ def _validate_kwargs(param_names: list[str], kwargs: dict[str, Any]) -> None: raise ValueError(f"invalid keyword argument name: {name!r}") +def _validate_qubit_arguments( + param_names: list[str], + n_qubits: int, + kwargs: dict[str, Any], +) -> None: + qubit_params = set(param_names[:n_qubits]) + bound_qubits = qubit_params.intersection(kwargs) + if bound_qubits: + names = ", ".join(sorted(bound_qubits)) + raise TypeError( + "qubit arguments are allocated by squin.wrap and cannot be bound " + f"by keyword: {names}" + ) + + def _compile_wrapper( wrapped_method: ir.Method, n_qubits: int, diff --git a/test/squin/test_wrap.py b/test/squin/test_wrap.py index 94f2d295..4f86af40 100644 --- a/test/squin/test_wrap.py +++ b/test/squin/test_wrap.py @@ -65,3 +65,13 @@ def one_qubit(q: Qubit): with pytest.raises(TypeError, match="unexpected keyword argument"): squin.wrap(one_qubit, theta=0.125) + + +def test_wrap_rejects_qubit_bound_by_keyword(): + @squin.kernel + def ansatz(q0: Qubit, q1: Qubit, theta: float): + squin.rx(theta, q0) + squin.cx(q0, q1) + + with pytest.raises(TypeError, match="cannot be bound by keyword: q0"): + squin.wrap(ansatz, 2, q0=object(), theta=0.125)