Added Wrapper Function to Infer Kernel Qubit Count#765
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
☂️ Python Coverage
Overall Coverage
New Files
Modified Files
|
david-pl
left a comment
There was a problem hiding this comment.
Generally, I think the feature idea in the issue is not well thought out.
Right now, the boilerplate you have to write to wrap a kernel is
@squin.kernel
def wrapped_kernel():
q = squin.qalloc(10)
kernel(q)
return squin.broadcast.measure(q)This is 4 lines, it doesn't get much simpler than that while keeping things generic. In the above, it's very simple to provide the exact arguments we need for each specific case. If we want to provide a utility function for this it either gets quite complicated or we have to make strong assumptions about the signature of the kernel we want to wrap, limiting the usefulness. The code here seems to be opting for the second approach. This is the better of the two IMO.
To implement this, we'd need to do something like this (pseudo code):
from kirin import ir
from kirin.dialects import func
body_block = ir.Block(
stmts=[
func.Invoke((n_qubits,), callee=squin.qalloc), # allocate qubits
func.Invoke((qubits,), callee=kernel), # call the actual kernel
func.Invoke((qubits,), callee=squin.broadcast.measure), # measure
func.Return(measurement_results)
]
)
body = ir.Region(blocks=[body_block])
code = func.Function(
sym_name="wrapped",
signature=...,
body=ir.Region()
)
wrapped = ir.Method(
dialects=kernel.dialects,
code=code,
)The current implementation is writing the kernel by composing strings, which is not the way to go here.
Also, keep in mind that the suggested implementation here only works for the specific kernel structure takes the list of qubits as the first argument. This is quite a specific assumption to have. Furthermore, the complexity added here is far from trivial. And all that, to save 4 lines of boiler plate. To be honest, I'd vote we simply do not implement this feature at all. We'll just have to live with the 4 lines of extra code.
From the issue, it also seems like the main use case is simulation. We might want to think about how to improve the simulator API to handle arguments more gracefully instead. The fundamental problem is, however, that you can only allocate qubits inside a kernel.
| if call_args | ||
| else " __wrapped_method__()" | ||
| ) | ||
| source = ( |
There was a problem hiding this comment.
Sorry, but no: we're not writing a wrapper kernel function by composing python strings. This is genuinely not a good idea.
|
It should be possible to do this by just looking at the Method signature and then constructing a kernel that invokes the wrapped kernel, it would be the equivilant to the hard coded closure: def wraps(mt: ir.Method, *args):
code = mt.code
if (trait := code.get_trait(HasSignature)) is None:
raise ValueError("expecting a function that has a signature")
signature = trait.get_signature(code)
qubit_arg_map, other_arg_map, qubit_only_signature = analyze_signature(signature)
wrapper_body = ir.Region([body_block := ir.Block()))
# 1. build block with the qubit arguments
# 2. insert logicl to add `py.Constant` for the constant args
# 3. add invoke statement of `mt` with a mixture of the block arguments and the constant ssa values in the body
# 4. construct the function statement and new method. cc: @jasonhan3 @david-pl |
☂️ Code Coverage
Overall Coverage
New Files
Modified Files
|
There was a problem hiding this comment.
Pull request overview
This PR introduces bloqade.squin.wrap, a helper that turns an existing squin.kernel (ir.Method) into a simulation-ready entry point by allocating n_qubits, invoking the kernel (with optional keyword bindings), measuring all allocated qubits, and returning measurement results. This fits into the squin layer as a convenience utility for running kernels without repeatedly writing allocation/measurement boilerplate.
Changes:
- Added
squin.wrap(method, n_qubits=None, **kwargs)with argument/keyword validation and wrapper codegen via a generatedmain()kernel. - Exported
wrapfrombloqade.squin. - Added a new test module covering inference, keyword binding, explicit qubit counts, and a few validation failures.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
src/bloqade/squin/utils.py |
Implements squin.wrap and its validation + wrapper compilation logic. |
src/bloqade/squin/__init__.py |
Exposes wrap at the bloqade.squin package level. |
test/squin/test_wrap.py |
Adds tests for core wrap behavior and several invalid-call cases. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| filename = ( | ||
| f"<bloqade.squin.wrap:{wrapped_method.sym_name or 'anonymous'}:" | ||
| f"{next(_wrap_counter)}>" | ||
| ) | ||
| linecache.cache[filename] = ( | ||
| len(source), | ||
| None, | ||
| source.splitlines(keepends=True), | ||
| filename, | ||
| ) |
| if not isinstance(method, ir.Method): | ||
| raise TypeError(f"expected a Kirin Method, got {type(method).__name__}") | ||
|
|
||
| param_names = _method_param_names(method) | ||
| _validate_kwargs(param_names, kwargs) | ||
|
|
||
| if n_qubits is None: | ||
| n_qubits = len(param_names) - len(kwargs) | ||
|
|
||
| if n_qubits < 0: | ||
| raise ValueError("n_qubits must be non-negative") | ||
|
|
Summary
Closes #331.
** NOTE: this ONLY works for signatures that are of the form
def kernel(*qubits:Qubit, **kwargs)
and the qubits CANNOT be supplied as kwargs. Otherwise, this implementation will not work.**
This PR adds
squin.wrap, a small utility for turning a qubit-operating kernel into a simulation-ready entry point. The generated wrapper allocates qubits, invokes the provided kernel, measures all allocated qubits, and returns the measurement results.Example:
squin.wrapcan infer the number of qubits from the wrapped kernel's unbound parameters, or accept an explicitn_qubitsvalue when additional kernel parameters are bound by keyword:Changes
squin.wrap(method, n_qubits=None, **kwargs).wrapfrombloqade.squin.Notes
Kirin lowering does not currently support starred calls like
kernel(*qubits), sosquin.wrapgenerates a tiny wrapper function with explicit qubit indexing, for examplekernel(qubits[0], qubits[1]), and then compiles that function through the normalsquin.kernelpath.Testing