Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,7 +2153,21 @@ def _can_add_op_at(self, moment_index: int, operation: cirq.Operation) -> bool:
if not 0 <= moment_index < len(self._moments):
return True

return not self._moments[moment_index].operates_on(operation.qubits)
if self._moments[moment_index].operates_on(operation.qubits):
return False

op_measurement_keys = protocols.measurement_key_objs(operation)
op_control_keys = protocols.control_keys(operation)

moment_measurement_keys = protocols.measurement_key_objs(self._moments[moment_index])
moment_control_keys = protocols.control_keys(self._moments[moment_index])

# Check that there is no measurement key - control key conflict.
# It is allowed to have already an operation in the moment with
# the same measurement key.
return op_control_keys.isdisjoint(
moment_measurement_keys
) and moment_control_keys.isdisjoint(op_measurement_keys)

def _latest_available_moment(self, op: cirq.Operation, *, start_moment_index: int = 0) -> int:
"""Finds the index of the latest (i.e. right most) moment which can accommodate `op`.
Expand Down Expand Up @@ -2999,17 +3013,42 @@ def _group_into_moment_compatible(inputs: Sequence[_MOMENT_OR_OP]) -> Iterator[l
"""
batch: list[_MOMENT_OR_OP] = []
batch_qubits: set[cirq.Qid] = set()
batch_measurement_keys: set[cirq.MeasurementKey] = set()
batch_control_keys: set[cirq.MeasurementKey] = set()
for mop in inputs:
is_moment = isinstance(mop, cirq.Moment)
if (is_moment and batch) or not batch_qubits.isdisjoint(mop.qubits):
if isinstance(mop, cirq.Moment):
if batch:
yield batch
batch = []
batch_qubits.clear()
batch_measurement_keys.clear()
batch_control_keys.clear()
yield [mop]
continue

op_qubits = mop.qubits
op_measurement_keys = protocols.measurement_key_objs(mop)
op_control_keys = protocols.control_keys(mop)

# Check that qubits are different and there is no
# measurement key - control key conflict.
# It is allowed to have already an operation in the batch with
# the same measurement key.
if (
not batch_qubits.isdisjoint(op_qubits)
or not batch_measurement_keys.isdisjoint(op_control_keys)
or not batch_control_keys.isdisjoint(op_measurement_keys)
):
yield batch
batch = []
batch_qubits.clear()
if is_moment:
yield [mop]
continue
batch_measurement_keys.clear()
batch_control_keys.clear()

batch.append(mop)
batch_qubits.update(mop.qubits)
batch_qubits.update(op_qubits)
batch_measurement_keys.update(op_measurement_keys)
batch_control_keys.update(op_control_keys)
if batch:
yield batch

Expand Down
102 changes: 102 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5054,3 +5054,105 @@ def test_insert_moments_and_ops_latest() -> None:
cirq.Moment([cirq.H(q[1])]),
)
assert c.insert(insert_index, moments_and_ops, cirq.InsertStrategy.LATEST) == index_after


def test_insert_earliest_batch_with_measurement_key_dependency() -> None:
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(q0), cirq.X(q1))
ops_to_insert = [cirq.measure(q0, key="k"), cirq.X(q1).with_classical_controls("k")]
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(
cirq.Moment(cirq.measure(q0, key="k")),
cirq.Moment(cirq.X(q1).with_classical_controls("k")),
cirq.Moment(cirq.X(q0), cirq.X(q1)),
)


def test_insert_earliest_op_with_control_key_unions_with_existing_moment() -> None:
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(q0))
ops_to_insert = [cirq.measure(q0, key="k"), cirq.X(q1).with_classical_controls("k")]
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)
assert c == cirq.Circuit(
cirq.Moment(cirq.measure(q0, key="k")),
cirq.Moment(cirq.X(q0), cirq.X(q1).with_classical_controls("k")),
)


def test_insert_earliest_op_with_control_key() -> None:
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.measure(q0, key="k"))
ops_to_insert = [cirq.X(q1).with_classical_controls("k")]
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(
cirq.Moment(cirq.X(q1).with_classical_controls("k")), cirq.Moment(cirq.measure(q0, key="k"))
)


def test_insert_earliest_batch_same_measurement_key() -> None:
# Checks that operations with the same measurement key
# can be put in the same moment.
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(q0), cirq.X(q1))
ops_to_insert = [cirq.measure(q0, key="k"), cirq.measure(q1, key="k")]
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(
cirq.Moment(cirq.measure(q0, key="k"), cirq.measure(q1, key="k")),
cirq.Moment(cirq.X(q0), cirq.X(q1)),
)


def test_insert_earliest_batch_with_measurement_key_dependency_reversed() -> None:
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(q0), cirq.X(q1))
ops_to_insert = [cirq.X(q1).with_classical_controls("k"), cirq.measure(q0, key="k")]
c.insert(0, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(
cirq.Moment(cirq.X(q1).with_classical_controls("k")),
cirq.Moment(cirq.measure(q0, key="k")),
cirq.Moment(cirq.X(q0), cirq.X(q1)),
)


def test_insert_earliest_batch() -> None:
q0, q1 = cirq.LineQubit.range(2)
# Create a circuit with every operation in a separate moment
c = cirq.Circuit(
cirq.X(q0), cirq.measure(q0, key="k"), cirq.X(q1), strategy=cirq.InsertStrategy.NEW
)
ops_to_insert = [cirq.X(q0).with_classical_controls("k"), cirq.measure(q0, key="k")]
c.insert(2, ops_to_insert, strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(
cirq.Moment(cirq.X(q0)),
cirq.Moment(cirq.measure(q0, key="k")),
cirq.Moment(cirq.X(q0).with_classical_controls("k"), cirq.X(q1)),
cirq.Moment(cirq.measure(q0, key="k")),
)


def test_insert_in_existing_moment_same_measurement_key_different_qubits() -> None:
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.measure(q0, key="k"))

op_to_add = cirq.measure(q1, key="k")

c.insert(0, [op_to_add], strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(cirq.Moment(cirq.measure(q0, key="k"), cirq.measure(q1, key="k")))


def test_insert_in_existing_moment_measurement_control_key_conflict() -> None:
c = cirq.Circuit(cirq.measure(q0, key="k"))

op_to_add = cirq.X(q1).with_classical_controls("k")

c.insert(0, [op_to_add], strategy=cirq.InsertStrategy.EARLIEST)

assert c == cirq.Circuit(
cirq.Moment(cirq.X(q1).with_classical_controls("k")), cirq.Moment(cirq.measure(q0, key="k"))
)
Loading