Skip to content

Commit d00a422

Browse files
committed
Add support for handling CircuitOperations recursively
1 parent e3bddbe commit d00a422

3 files changed

Lines changed: 177 additions & 5 deletions

File tree

cirq-core/cirq/transformers/routing/line_initial_mapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
import networkx as nx
3939

40-
from cirq import protocols, value
40+
from cirq import circuits, protocols, value
4141
from cirq.transformers.routing import initial_mapper
4242

4343
if TYPE_CHECKING:
@@ -107,6 +107,8 @@ def degree_lt_two(q: cirq.Qid):
107107
return any(circuit_graph[component_id[q]][i] == q for i in [-1, 0])
108108

109109
for op in circuit.all_operations():
110+
if isinstance(op.untagged, circuits.CircuitOperation):
111+
continue
110112
if protocols.num_qubits(op) != 2:
111113
continue
112114

cirq-core/cirq/transformers/routing/route_circuit_cqc.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __call__(
121121
lookahead_radius: int = 8,
122122
tag_inserted_swaps: bool = False,
123123
initial_mapper: cirq.AbstractInitialMapper | None = None,
124+
min_qubit_mapping_threshold: float = 0.5,
124125
context: cirq.TransformerContext | None = None,
125126
) -> cirq.AbstractCircuit:
126127
"""Transforms the given circuit to make it executable on the device.
@@ -137,6 +138,10 @@ def __call__(
137138
initial_mapper: an initial mapping strategy (placement) of logical qubits in the
138139
circuit onto physical qubits on the device. If not provided, defaults to an
139140
instance of `cirq.LineInitialMapper`.
141+
min_qubit_mapping_threshold: the minimum fraction (0.0 to 1.0) of qubits that should
142+
have their initial mapping computed from outer (non-CircuitOperation) 2-qubit gates
143+
before proceeding with routing. If there are not enough outer 2-qubit gates,
144+
CircuitOperations will be partially unrolled to reach this threshold.
140145
context: transformer context storing common configurable options for transformers.
141146
142147
Returns:
@@ -152,6 +157,7 @@ def __call__(
152157
lookahead_radius=lookahead_radius,
153158
tag_inserted_swaps=tag_inserted_swaps,
154159
initial_mapper=initial_mapper,
160+
min_qubit_mapping_threshold=min_qubit_mapping_threshold,
155161
context=context,
156162
)
157163
return routed_circuit
@@ -163,15 +169,15 @@ def route_circuit(
163169
lookahead_radius: int = 8,
164170
tag_inserted_swaps: bool = False,
165171
initial_mapper: cirq.AbstractInitialMapper | None = None,
172+
min_qubit_mapping_threshold: float = 0.5,
166173
context: cirq.TransformerContext | None = None,
167174
) -> tuple[cirq.AbstractCircuit, dict[cirq.Qid, cirq.Qid], dict[cirq.Qid, cirq.Qid]]:
168175
"""Transforms the given circuit to make it executable on the device.
169176
170177
This transformer assumes that all multi-qubit operations have been decomposed into 2-qubit
171178
operations and will raise an error if `circuit` a n-qubit operation where n > 2. If
172-
`circuit` contains `cirq.CircuitOperation`s and `context.deep` is True then they are first
173-
unrolled before proceeding. If `context.deep` is False or `context` is None then any
174-
`cirq.CircuitOperation` that acts on more than 2-qubits will also raise an error.
179+
`circuit` contains `cirq.CircuitOperation`s and `min_qubit_mapping_threshold` < 1.0,
180+
they are handled using a recursive routing strategy instead of being fully unrolled.
175181
176182
The algorithm tries to find the best swap at each timestep by ranking a set of candidate
177183
swaps against operations starting from the current timestep (say s) to the timestep at index
@@ -191,6 +197,11 @@ def route_circuit(
191197
operations.
192198
initial_mapper: an initial mapping strategy (placement) of logical qubits in the
193199
circuit onto physical qubits on the device.
200+
min_qubit_mapping_threshold: the minimum fraction (0.0 to 1.0) of qubits that should
201+
have their initial mapping computed from outer (non-CircuitOperation) 2-qubit gates
202+
before proceeding with routing. If there are not enough outer 2-qubit gates,
203+
CircuitOperations will be partially unrolled to reach this threshold. A value of 1.0
204+
disables recursive routing and falls back to unrolling all CircuitOperations.
194205
context: transformer context storing common configurable options for transformers.
195206
196207
Returns:
@@ -206,7 +217,20 @@ def route_circuit(
206217
ValueError: if circuit has operations that act on 3 or more qubits, except measurements.
207218
"""
208219

209-
# 0. Handle CircuitOperations by unrolling them.
220+
# 0. Handle CircuitOperations - use recursive routing if threshold < 1.0
221+
has_circuit_ops = self._has_circuit_operations(circuit)
222+
use_recursive_routing = has_circuit_ops and min_qubit_mapping_threshold < 1.0
223+
224+
if use_recursive_routing:
225+
return self._route_circuit_recursive(
226+
circuit=circuit,
227+
min_qubit_mapping_threshold=min_qubit_mapping_threshold,
228+
lookahead_radius=lookahead_radius,
229+
tag_inserted_swaps=tag_inserted_swaps,
230+
initial_mapper=initial_mapper,
231+
)
232+
233+
# Legacy behavior: unroll CircuitOperations if deep=True
210234
if context is not None and context.deep is True:
211235
circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True)
212236
if any(
@@ -563,6 +587,87 @@ def _cost(
563587
mm.apply_swap(*swap)
564588
return max_length, sum_length
565589

590+
def _has_circuit_operations(self, circuit: cirq.AbstractCircuit) -> bool:
591+
"""Check if the circuit contains any CircuitOperations."""
592+
return any(
593+
isinstance(op.untagged, circuits.CircuitOperation) for op in circuit.all_operations()
594+
)
595+
596+
def _get_ops_outside_circuit_ops(
597+
self, circuit: cirq.AbstractCircuit
598+
) -> tuple[list[list[cirq.Operation]], list[list[cirq.Operation]]]:
599+
"""Get 2-qubit and single-qubit ops that are NOT inside CircuitOperations."""
600+
outer_circuit = circuits.Circuit()
601+
for moment in circuit:
602+
outer_moment = circuits.Moment(
603+
op for op in moment if not isinstance(op.untagged, circuits.CircuitOperation)
604+
)
605+
outer_circuit.append(outer_moment)
606+
return self._get_one_and_two_qubit_ops_as_timesteps(outer_circuit)
607+
608+
def _route_circuit_recursive(
609+
self,
610+
circuit: cirq.AbstractCircuit,
611+
min_qubit_mapping_threshold: float,
612+
lookahead_radius: int,
613+
tag_inserted_swaps: bool,
614+
initial_mapper: cirq.AbstractInitialMapper | None,
615+
) -> tuple[cirq.AbstractCircuit, dict[cirq.Qid, cirq.Qid], dict[cirq.Qid, cirq.Qid]]:
616+
"""Route a circuit containing CircuitOperations using recursive strategy."""
617+
if initial_mapper is None:
618+
initial_mapper = line_initial_mapper.LineInitialMapper(self.device_graph)
619+
620+
num_total_qubits = len(list(circuit.all_qubits()))
621+
outer_two_qubit_ops, outer_single_qubit_ops = self._get_ops_outside_circuit_ops(circuit)
622+
outer_qubits = {q for ops in outer_two_qubit_ops for op in ops for q in op.qubits}
623+
624+
if len(outer_qubits) / num_total_qubits >= min_qubit_mapping_threshold:
625+
outer_for_map = circuits.Circuit(op for ops in outer_two_qubit_ops for op in ops)
626+
initial_mapping = initial_mapper.initial_mapping(outer_for_map)
627+
else:
628+
initial_mapping = initial_mapper.initial_mapping(circuit)
629+
630+
mm = mapping_manager.MappingManager(self.device_graph, initial_mapping)
631+
632+
circuit_ops = [
633+
(i, op, op.untagged)
634+
for i, m in enumerate(circuit)
635+
for op in m
636+
if isinstance(op.untagged, circuits.CircuitOperation)
637+
]
638+
639+
routed_ops, routing_swaps = self._route(
640+
mm,
641+
outer_two_qubit_ops,
642+
outer_single_qubit_ops,
643+
lookahead_radius,
644+
tag_inserted_swaps=tag_inserted_swaps,
645+
)
646+
647+
routed_circuit = circuits.Circuit(circuits.Circuit(m) for m in routed_ops)
648+
649+
for _, _, circuit_op in circuit_ops:
650+
inner = circuit_op.circuit.unfreeze(copy=True)
651+
inner_routed, inner_init, _ = self.route_circuit(
652+
inner,
653+
lookahead_radius=lookahead_radius,
654+
tag_inserted_swaps=tag_inserted_swaps,
655+
initial_mapper=initial_mapper,
656+
min_qubit_mapping_threshold=1.0,
657+
)
658+
routed_circuit.append(circuits.Circuit(inner_routed).transform_qubits(inner_init))
659+
660+
if routing_swaps and nx.is_directed(self.device_graph):
661+
routed_circuit = circuits.Circuit(
662+
self._replace_swaps_with_directional_decomposition(routed_circuit, routing_swaps)
663+
)
664+
665+
final_mapping = {
666+
mm.int_to_logical_qid[k]: mm.int_to_physical_qid[v]
667+
for k, v in enumerate(mm.logical_to_physical)
668+
}
669+
return routed_circuit, initial_mapping, final_mapping
670+
566671
def __eq__(self, other) -> bool:
567672
return nx.utils.graphs_equal(self.device_graph, other.device_graph)
568673

cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,3 +414,68 @@ def test_repr() -> None:
414414
device_graph = device.metadata.nx_graph
415415
router = cirq.RouteCQC(device_graph)
416416
cirq.testing.assert_equivalent_repr(router, setup_code='import cirq\nimport networkx as nx')
417+
418+
419+
@pytest.mark.parametrize(
420+
"test_type, threshold, n_qubits",
421+
[
422+
("single_op", 0.1, 4),
423+
("single_op", 0.25, 4),
424+
("single_op", 0.5, 4),
425+
("single_op", 0.75, 4),
426+
("multiple_ops", 0.1, 3),
427+
("multiple_ops", 0.1, 4),
428+
("multiple_ops", 0.1, 5),
429+
("threshold_behavior", 0.25, 4),
430+
("threshold_behavior", 0.5, 4),
431+
("threshold_behavior", 0.75, 4),
432+
],
433+
)
434+
def test_circuit_operations_recursive_routing(test_type, threshold, n_qubits) -> None:
435+
"""Test recursive routing of circuits containing CircuitOperations."""
436+
device = cirq.testing.construct_grid_device(4, 4)
437+
router = cirq.RouteCQC(device.metadata.nx_graph)
438+
q = cirq.LineQubit.range(n_qubits)
439+
440+
if test_type == "single_op":
441+
inner_circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[1], q[2]))
442+
outer_circuit = cirq.Circuit(
443+
cirq.CircuitOperation(inner_circuit.freeze()), cirq.CNOT(q[0], q[1])
444+
)
445+
elif test_type == "multiple_ops":
446+
inner1 = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CZ(q[1], q[2]))
447+
inner2 = cirq.Circuit(cirq.CNOT(q[-2], q[-1]), cirq.CZ(q[0], q[1]))
448+
outer_circuit = cirq.Circuit(
449+
cirq.CircuitOperation(inner1.freeze()),
450+
cirq.CNOT(q[0], q[n_qubits // 2]),
451+
cirq.CircuitOperation(inner2.freeze()),
452+
)
453+
elif test_type == "threshold_behavior":
454+
inner_circuit = cirq.Circuit(
455+
cirq.CNOT(q[0], q[1]), cirq.CNOT(q[1], q[2]), cirq.CNOT(q[2], q[3])
456+
)
457+
outer_circuit = cirq.Circuit(cirq.H(q[0]), cirq.CircuitOperation(inner_circuit.freeze()))
458+
459+
routed, _, _ = router.route_circuit(outer_circuit, min_qubit_mapping_threshold=threshold)
460+
device.validate_circuit(routed)
461+
assert len(list(routed.all_operations())) > 0
462+
463+
464+
def test_directed_device_recursive_routing() -> None:
465+
# Use a directed ring (strongly connected) so LineInitialMapper works
466+
device = cirq.testing.construct_ring_device(4, directed=True)
467+
device_graph = device.metadata.nx_graph
468+
router = cirq.RouteCQC(device_graph)
469+
470+
q = cirq.LineQubit.range(3)
471+
# Inner circuit with adjacent gates; outer circuit forces a swap
472+
inner_circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]))
473+
outer_circuit = cirq.Circuit(
474+
cirq.CNOT(q[0], q[2]), # non-adjacent: forces a swap
475+
cirq.CircuitOperation(inner_circuit.freeze()),
476+
)
477+
478+
routed, _, _ = router.route_circuit(
479+
outer_circuit, min_qubit_mapping_threshold=0.5, tag_inserted_swaps=True
480+
)
481+
assert len(list(routed.all_operations())) > 0

0 commit comments

Comments
 (0)