Skip to content

Commit be76e86

Browse files
codrut3pavoljuhas
andauthored
Fix bug in circuit.insert (#7823)
Change `_group_into_moment_compatible` to take into account measurement and control keys. Otherwise two incompatible operations can be put in the same moment. I also changed `_can_add_op_at` to look at measurement and control keys. I added unit tests that show the issue. --------- Co-authored-by: Pavol Juhas <juhas@google.com>
1 parent ccf35e6 commit be76e86

2 files changed

Lines changed: 152 additions & 7 deletions

File tree

cirq-core/cirq/circuits/circuit.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,7 +2153,27 @@ def _can_add_op_at(self, moment_index: int, operation: cirq.Operation) -> bool:
21532153
if not 0 <= moment_index < len(self._moments):
21542154
return True
21552155

2156-
return not self._moments[moment_index].operates_on(operation.qubits)
2156+
if self._moments[moment_index].operates_on(operation.qubits):
2157+
return False
2158+
2159+
op_measurement_keys = protocols.measurement_key_objs(operation)
2160+
op_control_keys = protocols.control_keys(operation)
2161+
2162+
# defer extraction of moment keys until truly needed
2163+
result = True
2164+
if op_control_keys or op_measurement_keys:
2165+
moment_measurement_keys = protocols.measurement_key_objs(self._moments[moment_index])
2166+
result = op_control_keys.isdisjoint(moment_measurement_keys) and (
2167+
(
2168+
op_measurement_keys.isdisjoint(moment_measurement_keys)
2169+
and op_measurement_keys.isdisjoint(
2170+
protocols.control_keys(self._moments[moment_index])
2171+
)
2172+
)
2173+
if op_measurement_keys
2174+
else True
2175+
)
2176+
return result
21572177

21582178
def _latest_available_moment(self, op: cirq.Operation, *, start_moment_index: int = 0) -> int:
21592179
"""Finds the index of the latest (i.e. right most) moment which can accommodate `op`.
@@ -2999,17 +3019,39 @@ def _group_into_moment_compatible(inputs: Sequence[_MOMENT_OR_OP]) -> Iterator[l
29993019
"""
30003020
batch: list[_MOMENT_OR_OP] = []
30013021
batch_qubits: set[cirq.Qid] = set()
3022+
batch_measurement_keys: set[cirq.MeasurementKey] = set()
3023+
batch_control_keys: set[cirq.MeasurementKey] = set()
30023024
for mop in inputs:
3003-
is_moment = isinstance(mop, cirq.Moment)
3004-
if (is_moment and batch) or not batch_qubits.isdisjoint(mop.qubits):
3025+
if isinstance(mop, cirq.Moment):
3026+
if batch:
3027+
yield batch
3028+
batch = []
3029+
batch_qubits.clear()
3030+
batch_measurement_keys.clear()
3031+
batch_control_keys.clear()
3032+
yield [mop]
3033+
continue
3034+
3035+
op_qubits = mop.qubits
3036+
op_measurement_keys = protocols.measurement_key_objs(mop)
3037+
op_control_keys = protocols.control_keys(mop)
3038+
3039+
if (
3040+
not batch_qubits.isdisjoint(op_qubits)
3041+
or not batch_measurement_keys.isdisjoint(op_measurement_keys)
3042+
or not batch_measurement_keys.isdisjoint(op_control_keys)
3043+
or not batch_control_keys.isdisjoint(op_measurement_keys)
3044+
):
30053045
yield batch
30063046
batch = []
30073047
batch_qubits.clear()
3008-
if is_moment:
3009-
yield [mop]
3010-
continue
3048+
batch_measurement_keys.clear()
3049+
batch_control_keys.clear()
3050+
30113051
batch.append(mop)
3012-
batch_qubits.update(mop.qubits)
3052+
batch_qubits.update(op_qubits)
3053+
batch_measurement_keys.update(op_measurement_keys)
3054+
batch_control_keys.update(op_control_keys)
30133055
if batch:
30143056
yield batch
30153057

cirq-core/cirq/circuits/circuit_test.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5054,3 +5054,106 @@ def test_insert_moments_and_ops_latest() -> None:
50545054
cirq.Moment([cirq.H(q[1])]),
50555055
)
50565056
assert c.insert(insert_index, moments_and_ops, cirq.InsertStrategy.LATEST) == index_after
5057+
5058+
5059+
def test_insert_earliest_batch_with_measurement_key_dependency() -> None:
5060+
q0, q1 = cirq.LineQubit.range(2)
5061+
c = cirq.Circuit(cirq.X(q0), cirq.X(q1))
5062+
ops_to_insert = [cirq.measure(q0, key="k"), cirq.X(q1).with_classical_controls("k")]
5063+
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
5064+
5065+
assert c == cirq.Circuit(
5066+
cirq.Moment(cirq.measure(q0, key="k")),
5067+
cirq.Moment(cirq.X(q1).with_classical_controls("k")),
5068+
cirq.Moment(cirq.X(q0), cirq.X(q1)),
5069+
)
5070+
5071+
5072+
def test_insert_earliest_op_with_control_key_unions_with_existing_moment() -> None:
5073+
q0, q1 = cirq.LineQubit.range(2)
5074+
c = cirq.Circuit(cirq.X(q0))
5075+
ops_to_insert = [cirq.measure(q0, key="k"), cirq.X(q1).with_classical_controls("k")]
5076+
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
5077+
assert c == cirq.Circuit(
5078+
cirq.Moment(cirq.measure(q0, key="k")),
5079+
cirq.Moment(cirq.X(q0), cirq.X(q1).with_classical_controls("k")),
5080+
)
5081+
5082+
5083+
def test_insert_earliest_op_with_control_key() -> None:
5084+
q0, q1 = cirq.LineQubit.range(2)
5085+
c = cirq.Circuit(cirq.measure(q0, key="k"))
5086+
ops_to_insert = [cirq.X(q1).with_classical_controls("k")]
5087+
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
5088+
5089+
assert c == cirq.Circuit(
5090+
cirq.Moment(cirq.X(q1).with_classical_controls("k")), cirq.Moment(cirq.measure(q0, key="k"))
5091+
)
5092+
5093+
5094+
def test_insert_earliest_batch_same_measurement_key() -> None:
5095+
q0, q1 = cirq.LineQubit.range(2)
5096+
c = cirq.Circuit(cirq.X(q0), cirq.X(q1))
5097+
ops_to_insert = [cirq.measure(q0, key="k"), cirq.measure(q1, key="k")]
5098+
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
5099+
5100+
assert c == cirq.Circuit(
5101+
cirq.Moment(cirq.measure(q0, key="k")),
5102+
cirq.Moment(cirq.measure(q1, key="k")),
5103+
cirq.Moment(cirq.X(q0), cirq.X(q1)),
5104+
)
5105+
5106+
5107+
def test_insert_earliest_batch_with_measurement_key_dependency_reversed() -> None:
5108+
q0, q1 = cirq.LineQubit.range(2)
5109+
c = cirq.Circuit(cirq.X(q0), cirq.X(q1))
5110+
ops_to_insert = [cirq.X(q1).with_classical_controls("k"), cirq.measure(q0, key="k")]
5111+
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
5112+
5113+
assert c == cirq.Circuit(
5114+
cirq.Moment(cirq.X(q1).with_classical_controls("k")),
5115+
cirq.Moment(cirq.measure(q0, key="k")),
5116+
cirq.Moment(cirq.X(q0), cirq.X(q1)),
5117+
)
5118+
5119+
5120+
def test_insert_earliest_batch() -> None:
5121+
q0, q1 = cirq.LineQubit.range(2)
5122+
# Create a circuit with every operation in a separate moment
5123+
c = cirq.Circuit(
5124+
cirq.X(q0), cirq.measure(q0, key="k"), cirq.X(q1), strategy=cirq.InsertStrategy.NEW
5125+
)
5126+
ops_to_insert = [cirq.X(q0).with_classical_controls("k"), cirq.measure(q0, key="k")]
5127+
c.insert(2, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
5128+
5129+
assert c == cirq.Circuit(
5130+
cirq.Moment(cirq.X(q0)),
5131+
cirq.Moment(cirq.measure(q0, key="k")),
5132+
cirq.Moment(cirq.X(q0).with_classical_controls("k"), cirq.X(q1)),
5133+
cirq.Moment(cirq.measure(q0, key="k")),
5134+
)
5135+
5136+
5137+
def test_insert_in_existing_moment_same_measurement_key_different_qubits() -> None:
5138+
q0, q1 = cirq.LineQubit.range(2)
5139+
c = cirq.Circuit(cirq.measure(q0, key="k"))
5140+
5141+
op_to_add = cirq.measure(q1, key="k")
5142+
5143+
c.insert(0, [op_to_add], strategy=cirq.InsertStrategy.EARLIEST)
5144+
5145+
assert c == cirq.Circuit(
5146+
cirq.Moment(cirq.measure(q1, key="k")), cirq.Moment(cirq.measure(q0, key="k"))
5147+
)
5148+
5149+
5150+
def test_insert_in_existing_moment_measurement_control_key_conflict() -> None:
5151+
c = cirq.Circuit(cirq.measure(q0, key="k"))
5152+
5153+
op_to_add = cirq.X(q1).with_classical_controls("k")
5154+
5155+
c.insert(0, [op_to_add], strategy=cirq.InsertStrategy.EARLIEST)
5156+
5157+
assert c == cirq.Circuit(
5158+
cirq.Moment(cirq.X(q1).with_classical_controls("k")), cirq.Moment(cirq.measure(q0, key="k"))
5159+
)

0 commit comments

Comments
 (0)