diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 76be2c2330b..ab9845fc6d3 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -2153,7 +2153,21 @@ def _can_add_op_at(self, moment_index: int, operation: cirq.Operation) -> bool: if not 0 <= moment_index < len(self._moments): return True - return not self._moments[moment_index].operates_on(operation.qubits) + if self._moments[moment_index].operates_on(operation.qubits): + return False + + op_measurement_keys = protocols.measurement_key_objs(operation) + op_control_keys = protocols.control_keys(operation) + + moment_measurement_keys = protocols.measurement_key_objs(self._moments[moment_index]) + moment_control_keys = protocols.control_keys(self._moments[moment_index]) + + # Check that there is no measurement key - control key conflict. + # It is allowed to have already an operation in the moment with + # the same measurement key. + return op_control_keys.isdisjoint( + moment_measurement_keys + ) and moment_control_keys.isdisjoint(op_measurement_keys) def _latest_available_moment(self, op: cirq.Operation, *, start_moment_index: int = 0) -> int: """Finds the index of the latest (i.e. right most) moment which can accommodate `op`. @@ -2999,17 +3013,42 @@ def _group_into_moment_compatible(inputs: Sequence[_MOMENT_OR_OP]) -> Iterator[l """ batch: list[_MOMENT_OR_OP] = [] batch_qubits: set[cirq.Qid] = set() + batch_measurement_keys: set[cirq.MeasurementKey] = set() + batch_control_keys: set[cirq.MeasurementKey] = set() for mop in inputs: - is_moment = isinstance(mop, cirq.Moment) - if (is_moment and batch) or not batch_qubits.isdisjoint(mop.qubits): + if isinstance(mop, cirq.Moment): + if batch: + yield batch + batch = [] + batch_qubits.clear() + batch_measurement_keys.clear() + batch_control_keys.clear() + yield [mop] + continue + + op_qubits = mop.qubits + op_measurement_keys = protocols.measurement_key_objs(mop) + op_control_keys = protocols.control_keys(mop) + + # Check that qubits are different and there is no + # measurement key - control key conflict. + # It is allowed to have already an operation in the batch with + # the same measurement key. + if ( + not batch_qubits.isdisjoint(op_qubits) + or not batch_measurement_keys.isdisjoint(op_control_keys) + or not batch_control_keys.isdisjoint(op_measurement_keys) + ): yield batch batch = [] batch_qubits.clear() - if is_moment: - yield [mop] - continue + batch_measurement_keys.clear() + batch_control_keys.clear() + batch.append(mop) - batch_qubits.update(mop.qubits) + batch_qubits.update(op_qubits) + batch_measurement_keys.update(op_measurement_keys) + batch_control_keys.update(op_control_keys) if batch: yield batch diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 4d86d451d02..128d3b3e668 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -5054,3 +5054,105 @@ def test_insert_moments_and_ops_latest() -> None: cirq.Moment([cirq.H(q[1])]), ) assert c.insert(insert_index, moments_and_ops, cirq.InsertStrategy.LATEST) == index_after + + +def test_insert_earliest_batch_with_measurement_key_dependency() -> None: + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.X(q0), cirq.X(q1)) + ops_to_insert = [cirq.measure(q0, key="k"), cirq.X(q1).with_classical_controls("k")] + c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit( + cirq.Moment(cirq.measure(q0, key="k")), + cirq.Moment(cirq.X(q1).with_classical_controls("k")), + cirq.Moment(cirq.X(q0), cirq.X(q1)), + ) + + +def test_insert_earliest_op_with_control_key_unions_with_existing_moment() -> None: + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.X(q0)) + ops_to_insert = [cirq.measure(q0, key="k"), cirq.X(q1).with_classical_controls("k")] + c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST) + assert c == cirq.Circuit( + cirq.Moment(cirq.measure(q0, key="k")), + cirq.Moment(cirq.X(q0), cirq.X(q1).with_classical_controls("k")), + ) + + +def test_insert_earliest_op_with_control_key() -> None: + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.measure(q0, key="k")) + ops_to_insert = [cirq.X(q1).with_classical_controls("k")] + c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit( + cirq.Moment(cirq.X(q1).with_classical_controls("k")), cirq.Moment(cirq.measure(q0, key="k")) + ) + + +def test_insert_earliest_batch_same_measurement_key() -> None: + # Checks that operations with the same measurement key + # can be put in the same moment. + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.X(q0), cirq.X(q1)) + ops_to_insert = [cirq.measure(q0, key="k"), cirq.measure(q1, key="k")] + c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit( + cirq.Moment(cirq.measure(q0, key="k"), cirq.measure(q1, key="k")), + cirq.Moment(cirq.X(q0), cirq.X(q1)), + ) + + +def test_insert_earliest_batch_with_measurement_key_dependency_reversed() -> None: + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.X(q0), cirq.X(q1)) + ops_to_insert = [cirq.X(q1).with_classical_controls("k"), cirq.measure(q0, key="k")] + c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit( + cirq.Moment(cirq.X(q1).with_classical_controls("k")), + cirq.Moment(cirq.measure(q0, key="k")), + cirq.Moment(cirq.X(q0), cirq.X(q1)), + ) + + +def test_insert_earliest_batch() -> None: + q0, q1 = cirq.LineQubit.range(2) + # Create a circuit with every operation in a separate moment + c = cirq.Circuit( + cirq.X(q0), cirq.measure(q0, key="k"), cirq.X(q1), strategy=cirq.InsertStrategy.NEW + ) + ops_to_insert = [cirq.X(q0).with_classical_controls("k"), cirq.measure(q0, key="k")] + c.insert(2, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit( + cirq.Moment(cirq.X(q0)), + cirq.Moment(cirq.measure(q0, key="k")), + cirq.Moment(cirq.X(q0).with_classical_controls("k"), cirq.X(q1)), + cirq.Moment(cirq.measure(q0, key="k")), + ) + + +def test_insert_in_existing_moment_same_measurement_key_different_qubits() -> None: + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.measure(q0, key="k")) + + op_to_add = cirq.measure(q1, key="k") + + c.insert(0, [op_to_add], strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit(cirq.Moment(cirq.measure(q0, key="k"), cirq.measure(q1, key="k"))) + + +def test_insert_in_existing_moment_measurement_control_key_conflict() -> None: + c = cirq.Circuit(cirq.measure(q0, key="k")) + + op_to_add = cirq.X(q1).with_classical_controls("k") + + c.insert(0, [op_to_add], strategy=cirq.InsertStrategy.EARLIEST) + + assert c == cirq.Circuit( + cirq.Moment(cirq.X(q1).with_classical_controls("k")), cirq.Moment(cirq.measure(q0, key="k")) + )