Skip to content

Commit 0427df3

Browse files
authored
Implement simple no-cloning validation for flattened kernels (#698)
Simplified version of #607 until we have a proper solution. This is enough to support gemini logical kernels for now. Closes #695 .
1 parent e57ebac commit 0427df3

9 files changed

Lines changed: 301 additions & 2 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .validation import (
2+
FlatKernelNoCloningValidation as FlatKernelNoCloningValidation,
3+
_FlatKernelNoCloningAnalysis as _FlatKernelNoCloningAnalysis,
4+
)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Any
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.lattice import EmptyLattice
6+
from kirin.analysis import Forward
7+
from kirin.validation import ValidationPass
8+
from kirin.analysis.forward import ForwardFrame
9+
10+
from bloqade.analysis.address import Address, AddressAnalysis
11+
12+
13+
class _FlatKernelNoCloningAnalysis(Forward[EmptyLattice]):
14+
"""Simple no-cloning validation for kernels that have been aggressively inlined."""
15+
16+
keys = ("validate.nocloning.flatkernel",)
17+
lattice = EmptyLattice
18+
_address_frame: ForwardFrame[Address] | None = None
19+
20+
def eval_fallback(
21+
self, frame: ForwardFrame[EmptyLattice], node: ir.Statement
22+
) -> None:
23+
pass
24+
25+
def run(self, method: ir.Method, *args: EmptyLattice, **kwargs: EmptyLattice):
26+
if self._address_frame is None:
27+
address_analysis = AddressAnalysis(method.dialects)
28+
address_frame, _ = address_analysis.run(method)
29+
self._address_frame = address_frame
30+
return super().run(method, *args, **kwargs)
31+
32+
def method_self(self, method: ir.Method) -> EmptyLattice:
33+
return EmptyLattice.bottom()
34+
35+
def collect_errors(self, stmt: ir.Statement, addresses: list[int]):
36+
seen = set()
37+
duplicates = set()
38+
39+
for addr in addresses:
40+
if addr in seen:
41+
duplicates.add(addr)
42+
else:
43+
seen.add(addr)
44+
45+
self.add_validation_error(
46+
stmt,
47+
ir.ValidationError(
48+
stmt,
49+
f"Gate {stmt.name.upper()} applies to the qubits {duplicates} more than once.",
50+
),
51+
)
52+
53+
54+
@dataclass
55+
class FlatKernelNoCloningValidation(ValidationPass):
56+
_analysis: _FlatKernelNoCloningAnalysis = field(init=False)
57+
58+
def name(self) -> str:
59+
return "No-Cloning Validation"
60+
61+
def get_required_analyses(self) -> list[type]:
62+
return [AddressAnalysis]
63+
64+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
65+
analysis = _FlatKernelNoCloningAnalysis(method.dialects)
66+
frame, _ = analysis.run(
67+
method, *(EmptyLattice.bottom() for _ in range(len(method.args) - 1))
68+
)
69+
70+
self._analysis = analysis
71+
errors = analysis.get_validation_errors()
72+
73+
return frame, errors

src/bloqade/gemini/analysis/logical_validation/analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@dataclass
1515
class _GeminiLogicalValidationAnalysis(Forward[EmptyLattice]):
16-
keys = ["gemini.validate.logical"]
16+
keys = ("gemini.validate.logical",)
1717

1818
lattice = EmptyLattice
1919
addr_frame: ForwardFrame[Address]

src/bloqade/gemini/dialects/logical/groups.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from bloqade.squin import gate, qubit
1212
from bloqade.rewrite.passes import AggressiveUnroll
13+
from bloqade.analysis.validation.simple_nocloning import FlatKernelNoCloningValidation
1314

1415
from ._dialect import dialect
1516

@@ -71,7 +72,11 @@ def run_pass(
7172
)
7273

7374
validator = ValidationSuite(
74-
[GeminiLogicalValidation, GeminiTerminalMeasurementValidation]
75+
[
76+
GeminiLogicalValidation,
77+
GeminiTerminalMeasurementValidation,
78+
FlatKernelNoCloningValidation,
79+
]
7580
)
7681
validation_result = validator.validate(mt)
7782
validation_result.raise_if_invalid()

src/bloqade/squin/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
single_qubit_pauli_channel as single_qubit_pauli_channel,
5454
)
5555
from .analysis.fidelity import impls as impls
56+
from .analysis.validation.simple_nocloning import ( # noqa: F401
57+
impls as simple_nocloning_impls,
58+
)
5659

5760
# NOTE: it's important to keep these imports here since they import squin.kernel
5861
# we skip isort here

src/bloqade/squin/analysis/validation/simple_nocloning/__init__.py

Whitespace-only changes.
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from kirin import interp
2+
from kirin.lattice import EmptyLattice
3+
from kirin.analysis import ForwardFrame
4+
5+
from bloqade.squin import gate, noise
6+
from bloqade.analysis.address.lattice import AddressReg, PartialIList
7+
from bloqade.analysis.validation.simple_nocloning import _FlatKernelNoCloningAnalysis
8+
9+
10+
@gate.dialect.register(key="validate.nocloning.flatkernel")
11+
class GateMethods(interp.MethodTable):
12+
@interp.impl(gate.stmts.X)
13+
@interp.impl(gate.stmts.Y)
14+
@interp.impl(gate.stmts.Z)
15+
@interp.impl(gate.stmts.H)
16+
@interp.impl(gate.stmts.S)
17+
@interp.impl(gate.stmts.T)
18+
@interp.impl(gate.stmts.Rx)
19+
@interp.impl(gate.stmts.Ry)
20+
@interp.impl(gate.stmts.Rz)
21+
def single_qubit_gate(
22+
self,
23+
interp_: _FlatKernelNoCloningAnalysis,
24+
frame: ForwardFrame[EmptyLattice],
25+
stmt: gate.stmts.SingleQubitGate,
26+
):
27+
if interp_._address_frame is None:
28+
return
29+
30+
qubit_addrs = interp_._address_frame.get(stmt.qubits)
31+
32+
if not isinstance(qubit_addrs, AddressReg):
33+
return
34+
35+
unique_addrs = set(qubit_addrs.data)
36+
if len(qubit_addrs.data) == len(unique_addrs):
37+
return
38+
39+
interp_.collect_errors(stmt, list(qubit_addrs.data))
40+
41+
@interp.impl(gate.stmts.CX)
42+
@interp.impl(gate.stmts.CY)
43+
@interp.impl(gate.stmts.CZ)
44+
def controlled_gate(
45+
self,
46+
interp_: _FlatKernelNoCloningAnalysis,
47+
frame: ForwardFrame[EmptyLattice],
48+
stmt: gate.stmts.ControlledGate,
49+
):
50+
51+
if interp_._address_frame is None:
52+
return
53+
54+
control_addrs = interp_._address_frame.get(stmt.controls)
55+
target_addrs = interp_._address_frame.get(stmt.targets)
56+
57+
if not isinstance(control_addrs, AddressReg) or not isinstance(
58+
target_addrs, AddressReg
59+
):
60+
return
61+
62+
all_addrs = list(control_addrs.data) + list(target_addrs.data)
63+
unique_addrs = set(all_addrs)
64+
65+
if len(all_addrs) == len(unique_addrs):
66+
return
67+
68+
interp_.collect_errors(stmt, all_addrs)
69+
70+
71+
@noise.dialect.register(key="validate.nocloning.flatkernel")
72+
class NoiseMethods(interp.MethodTable):
73+
@interp.impl(noise.stmts.Depolarize)
74+
@interp.impl(noise.stmts.SingleQubitPauliChannel)
75+
@interp.impl(noise.stmts.QubitLoss)
76+
def single_qubit_noise_channel(
77+
self,
78+
interp_: _FlatKernelNoCloningAnalysis,
79+
frame: ForwardFrame[EmptyLattice],
80+
stmt: (
81+
noise.stmts.SingleQubitPauliChannel
82+
| noise.stmts.Depolarize
83+
| noise.stmts.QubitLoss
84+
),
85+
):
86+
if interp_._address_frame is None:
87+
return
88+
89+
qubit_addrs = interp_._address_frame.get(stmt.qubits)
90+
91+
if not isinstance(qubit_addrs, AddressReg):
92+
return
93+
94+
if len(qubit_addrs.data) == len(set(qubit_addrs.data)):
95+
return
96+
97+
interp_.collect_errors(stmt, list(qubit_addrs.data))
98+
99+
@interp.impl(noise.stmts.Depolarize2)
100+
@interp.impl(noise.stmts.TwoQubitPauliChannel)
101+
def two_qubit_noise_channel(
102+
self,
103+
interp_: _FlatKernelNoCloningAnalysis,
104+
frame: ForwardFrame[EmptyLattice],
105+
stmt: noise.stmts.Depolarize2 | noise.stmts.TwoQubitPauliChannel,
106+
):
107+
if interp_._address_frame is None:
108+
return
109+
110+
control_addrs = interp_._address_frame.get(stmt.controls)
111+
target_addrs = interp_._address_frame.get(stmt.targets)
112+
113+
if not isinstance(control_addrs, AddressReg) or not isinstance(
114+
target_addrs, AddressReg
115+
):
116+
return
117+
118+
all_addrs = list(control_addrs.data) + list(target_addrs.data)
119+
120+
if len(all_addrs) == len(set(all_addrs)):
121+
return
122+
123+
interp_.collect_errors(stmt, all_addrs)
124+
125+
@interp.impl(noise.stmts.CorrelatedQubitLoss)
126+
def correlated_loss(
127+
self,
128+
interp_: _FlatKernelNoCloningAnalysis,
129+
frame: ForwardFrame[EmptyLattice],
130+
stmt: noise.stmts.CorrelatedQubitLoss,
131+
):
132+
if interp_._address_frame is None:
133+
return
134+
135+
qubit_addrs = interp_._address_frame.get(stmt.qubits)
136+
137+
if not isinstance(qubit_addrs, PartialIList):
138+
return
139+
140+
all_addresses = []
141+
for group_addrs in qubit_addrs.data:
142+
if not isinstance(group_addrs, AddressReg):
143+
continue
144+
all_addresses.extend(group_addrs.data)
145+
146+
if len(all_addresses) == len(set(all_addresses)):
147+
return
148+
149+
interp_.collect_errors(stmt, all_addresses)

test/gemini/test_logical_validation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,13 @@ def main():
205205
squin.broadcast.h(q[1:])
206206

207207
main.print()
208+
209+
210+
def test_nocloning():
211+
with pytest.raises(ValidationErrorGroup):
212+
213+
@gemini.logical.kernel
214+
def main():
215+
q = squin.qalloc(2)
216+
squin.broadcast.ry(0.123, [q[0], q[0]])
217+
squin.cx(q[1], q[1])
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
from kirin.validation import ValidationSuite
3+
from kirin.ir.exception import ValidationErrorGroup
4+
5+
from bloqade import squin
6+
from bloqade.rewrite.passes import AggressiveUnroll
7+
from bloqade.analysis.validation.simple_nocloning import FlatKernelNoCloningValidation
8+
9+
10+
def test_gates():
11+
12+
@squin.kernel
13+
def bad_kernel():
14+
q = squin.qalloc(3)
15+
squin.broadcast.x([q[0], q[1], q[0]])
16+
squin.broadcast.rx(0.123, [q[1], q[1]])
17+
squin.cx(q[2], q[2])
18+
19+
AggressiveUnroll(bad_kernel.dialects).fixpoint(bad_kernel)
20+
21+
_, errors = FlatKernelNoCloningValidation().run(bad_kernel)
22+
assert len(errors) == 3
23+
24+
25+
def test_noise():
26+
@squin.kernel
27+
def bad_kernel():
28+
q = squin.qalloc(3)
29+
squin.broadcast.depolarize(0.1, [q[0], q[0]])
30+
squin.broadcast.depolarize2(0.1, [q[0]], [q[0]])
31+
32+
AggressiveUnroll(bad_kernel.dialects).fixpoint(bad_kernel)
33+
34+
validation_suite = ValidationSuite([FlatKernelNoCloningValidation])
35+
result = validation_suite.validate(bad_kernel)
36+
assert result.error_count() == 2
37+
38+
with pytest.raises(ValidationErrorGroup):
39+
result.raise_if_invalid()
40+
41+
42+
def test_correlated_loss():
43+
@squin.kernel
44+
def bad_kernel():
45+
q = squin.qalloc(3)
46+
squin.broadcast.correlated_qubit_loss(0.1, [[q[0], q[1]], [q[1], q[2]]])
47+
48+
AggressiveUnroll(bad_kernel.dialects).fixpoint(bad_kernel)
49+
50+
validation_suite = ValidationSuite([FlatKernelNoCloningValidation])
51+
result = validation_suite.validate(bad_kernel)
52+
assert result.error_count() == 1
53+
54+
with pytest.raises(ValidationErrorGroup):
55+
result.raise_if_invalid()

0 commit comments

Comments
 (0)