Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
analysis as analysis,
)
from .. import qubit as qubit
from .utils import wrap as wrap
from ..qubit import (
reset as reset,
is_one as is_one,
Expand Down
136 changes: 136 additions & 0 deletions src/bloqade/squin/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

import linecache
from typing import Any
from keyword import iskeyword
from itertools import count

from kirin import ir

from ..qubit import qalloc, broadcast
from .groups import kernel

_wrap_counter = count()


def wrap(
method: ir.Method,
/,
n_qubits: int | None = None,
**kwargs: Any,
) -> ir.Method:
"""Wrap a qubit kernel in an allocating and measuring entry-point.

Args:
method: Kernel to invoke on the allocated qubits.
n_qubits: Number of qubits to allocate. When omitted, this is inferred
from the method's unsupplied parameters.
**kwargs: Kernel keyword arguments to bind in the wrapped entry-point.

Returns:
A ``squin.kernel`` method that allocates qubits, invokes ``method``, and
returns measurements for all allocated qubits.
"""
if not isinstance(method, ir.Method):
raise TypeError(f"expected a Kirin Method, got {type(method).__name__}")

param_names = _method_param_names(method)
_validate_kwargs(param_names, kwargs)

if n_qubits is None:
n_qubits = len(param_names) - len(kwargs)

if n_qubits < 0:
raise ValueError("n_qubits must be non-negative")

Comment on lines +34 to +45
_validate_qubit_arguments(param_names, n_qubits, kwargs)

if len(param_names) != n_qubits + len(kwargs):
raise ValueError(
f"cannot call {method.sym_name or 'method'} with {n_qubits} qubits "
f"and {len(kwargs)} keyword arguments; expected {len(param_names)} "
"total arguments"
)

return _compile_wrapper(method, n_qubits, kwargs)


def _method_param_names(method: ir.Method) -> list[str]:
if method.arg_names is None:
return [f"arg{i}" for i in range(method.nargs - 1)]
return list(method.arg_names[1:])


def _validate_kwargs(param_names: list[str], kwargs: dict[str, Any]) -> None:
unexpected = set(kwargs).difference(param_names)
if unexpected:
unexpected_names = ", ".join(sorted(unexpected))
raise TypeError(f"unexpected keyword argument(s): {unexpected_names}")

for name in kwargs:
if not name.isidentifier() or iskeyword(name):
raise ValueError(f"invalid keyword argument name: {name!r}")


def _validate_qubit_arguments(
param_names: list[str],
n_qubits: int,
kwargs: dict[str, Any],
) -> None:
qubit_params = set(param_names[:n_qubits])
bound_qubits = qubit_params.intersection(kwargs)
if bound_qubits:
names = ", ".join(sorted(bound_qubits))
raise TypeError(
"qubit arguments are allocated by squin.wrap and cannot be bound "
f"by keyword: {names}"
)


def _compile_wrapper(
wrapped_method: ir.Method,
n_qubits: int,
kwargs: dict[str, Any],
) -> ir.Method:
positional_args = [f"qubits[{idx}]" for idx in range(n_qubits)]
keyword_globals = {
f"__wrapped_kwarg_{idx}__": value for idx, value in enumerate(kwargs.values())
}
keyword_args = [
f"{name}={global_name}"
for name, global_name in zip(kwargs, keyword_globals, strict=True)
]
call_args = ", ".join(positional_args + keyword_args)

call_line = (
f" __wrapped_method__({call_args})"
if call_args
else " __wrapped_method__()"
)
source = (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, but no: we're not writing a wrapper kernel function by composing python strings. This is genuinely not a good idea.

"def main():\n"
f" qubits = __qalloc__({n_qubits})\n"
f"{call_line}\n"
" return __broadcast__.measure(qubits)\n"
)

filename = (
f"<bloqade.squin.wrap:{wrapped_method.sym_name or 'anonymous'}:"
f"{next(_wrap_counter)}>"
)
linecache.cache[filename] = (
len(source),
None,
source.splitlines(keepends=True),
filename,
)
Comment on lines +117 to +126

globals_ = {
"__wrapped_method__": wrapped_method,
"__qalloc__": qalloc,
"__broadcast__": broadcast,
**keyword_globals,
}
locals_: dict[str, Any] = {}
exec(compile(source, filename, "exec"), globals_, locals_)
return kernel(locals_["main"])
77 changes: 77 additions & 0 deletions test/squin/test_wrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest

from bloqade import squin
from bloqade.types import Qubit
from bloqade.pyqrack import StackMemorySimulator


def test_wrap_infers_qubit_count_and_measures():
@squin.kernel
def bell(q0: Qubit, q1: Qubit):
squin.h(q0)
squin.cx(q0, q1)

main = squin.wrap(bell)

main.print()
sim = StackMemorySimulator(min_qubits=2)
result = sim.run(main)

assert len(result) == 2


def test_wrap_binds_keyword_arguments():
@squin.kernel
def rotate(q: Qubit, theta: float):
squin.rx(theta, q)

main = squin.wrap(rotate, theta=0.125)

main.print()
sim = StackMemorySimulator(min_qubits=1)
result = sim.run(main)

assert len(result) == 1


def test_wrap_accepts_explicit_qubit_count_with_keywords():
@squin.kernel
def ansatz(q0: Qubit, q1: Qubit, theta: float):
squin.rx(theta, q0)
squin.cx(q0, q1)

main = squin.wrap(ansatz, 2, theta=0.125)

main.print()
sim = StackMemorySimulator(min_qubits=2)
result = sim.run(main)

assert len(result) == 2


def test_wrap_rejects_mismatched_argument_count():
@squin.kernel
def two_qubit(q0: Qubit, q1: Qubit):
squin.cx(q0, q1)

with pytest.raises(ValueError, match="expected 2 total arguments"):
squin.wrap(two_qubit, 1)


def test_wrap_rejects_unexpected_keyword_argument():
@squin.kernel
def one_qubit(q: Qubit):
squin.h(q)

with pytest.raises(TypeError, match="unexpected keyword argument"):
squin.wrap(one_qubit, theta=0.125)


def test_wrap_rejects_qubit_bound_by_keyword():
@squin.kernel
def ansatz(q0: Qubit, q1: Qubit, theta: float):
squin.rx(theta, q0)
squin.cx(q0, q1)

with pytest.raises(TypeError, match="cannot be bound by keyword: q0"):
squin.wrap(ansatz, 2, q0=object(), theta=0.125)
Loading