diff --git a/src/braket/program_sets/circuit_binding.py b/src/braket/program_sets/circuit_binding.py index 320309f40..8b8008420 100644 --- a/src/braket/program_sets/circuit_binding.py +++ b/src/braket/program_sets/circuit_binding.py @@ -16,7 +16,16 @@ import warnings from collections import defaultdict from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from braket.default_simulator.openqasm.parser.openqasm_ast import ( + Identifier, + IntegerLiteral, + QASMNode, + QuantumMeasurementStatement, + QubitDeclaration, +) +from braket.default_simulator.openqasm.parser.openqasm_parser import parse from braket.ir.openqasm import Program from braket.circuits import Circuit, Gate, Observable @@ -29,10 +38,64 @@ from braket.registers import QubitSet +@dataclass +class _ParsedOpenQASM: + lines: Sequence[str] + declarations_index: int + measure_index: int + qubit_format: str + qubits: QubitSet + + +def _parse_openqasm(source: str) -> _ParsedOpenQASM: + lines = source.splitlines() + program = parse(source) + declarations_index = 1 if program.version is not None else 0 + measure_index = len(lines) + register_name = None + register_size = 0 + for stmt in program.statements: + if isinstance(stmt, QubitDeclaration) and register_name is None: + register_name = stmt.qubit.name + # stmt.size is None for a single unindexed qubit + register_size = stmt.size.value if isinstance(stmt.size, IntegerLiteral) else 1 + if isinstance(stmt, QuantumMeasurementStatement) and measure_index == len(lines): + # span is 1-indexed; convert to a 0-indexed line index. + measure_index = stmt.span.start_line - 1 + + qubit_format, qubits = ( + (f"{register_name}[{{}}]", QubitSet(range(register_size))) + if register_name + else ( + "${}", + QubitSet( + int(node.name[1:]) + for node in _walk(program) + if isinstance(node, Identifier) and node.name.startswith("$") + ), + ) + ) + return _ParsedOpenQASM( + lines=lines, + declarations_index=declarations_index, + measure_index=measure_index, + qubit_format=qubit_format, + qubits=qubits, + ) + + +def _walk(node: QASMNode): + yield node + for value in vars(node).values(): + for child in value if isinstance(value, list) else [value]: + if isinstance(child, QASMNode): + yield from _walk(child) + + class CircuitBinding: def __init__( self, - circuit: Circuit, + circuit: Circuit | str, input_sets: ParameterSetsLike | None = None, observables: Sequence[Observable | PauliString | str] | Sum | None = None, ): @@ -51,7 +114,8 @@ def __init__( Note: Circuits cannot have result types attached. Args: - circuit (Circuit): The parametrized circuit + circuit (Circuit | str): The parametrized circuit, either as a Circuit object or as + an OpenQASM string. input_sets (ParameterSetsLike | None): The inputs to the circuit, if specified. observables (Sequence[Observable | PauliString | str] | Sum | None): The observables or Hamiltonian to measure, if specified. @@ -70,9 +134,10 @@ def __init__( and any(isinstance(obs, Sum) for obs in observables) ): raise TypeError("Cannot have Sum Hamiltonian in list of observables") - if circuit.result_types: + if isinstance(circuit, Circuit) and circuit.result_types: raise ValueError("Circuit cannot have result types") self._circuit = circuit + self._parsed_source = _parse_openqasm(circuit) if isinstance(circuit, str) else None self._input_sets = ParameterSets(input_sets) if input_sets else None self._observables = CircuitBinding._to_observables(observables) @@ -96,9 +161,9 @@ def _to_observables( return obs @property - def circuit(self) -> Circuit: + def circuit(self) -> Circuit | str: """ - Circuit: The parametrized circuit + Circuit | str: The parametrized circuit, either as a Circuit object or an OpenQASM string. """ return self._circuit @@ -135,23 +200,56 @@ def to_ir( """ if not self._observables: return Program( - source=self._circuit.to_ir( - IRType.OPENQASM, gate_definitions=gate_definitions - ).source, + source=self._circuit_source(gate_definitions), inputs=self._input_sets.as_dict() if self._input_sets else None, ) - # with_euler_angles validates that the observable has valid Euler angle gates - circuit_with_euler_angles = self._circuit.with_euler_angles(self._observables) euler_angles = self._get_euler_angles() + if isinstance(self._circuit, Circuit): + source = ( + self._circuit + .with_euler_angles(self._observables) + .to_ir(IRType.OPENQASM, gate_definitions=gate_definitions) + .source + ) + else: + source = _inject_euler_angles( + self._parsed_source, self._euler_rotation_targets(), euler_angles.keys() + ) return Program( - source=circuit_with_euler_angles.to_ir( - IRType.OPENQASM, gate_definitions=gate_definitions - ).source, + source=source, inputs=( self._input_sets * euler_angles if self._input_sets else ParameterSets(euler_angles) ).as_dict(), ) + def _circuit_source( + self, + gate_definitions: Mapping[tuple[Gate, QubitSet], PulseSequence] | None, + ) -> str: + if isinstance(self._circuit, Circuit): + return self._circuit.to_ir(IRType.OPENQASM, gate_definitions=gate_definitions).source + return self._circuit + + def _circuit_qubits(self) -> QubitSet: + if isinstance(self._circuit, Circuit): + return self._circuit.qubits + return self._parsed_source.qubits + + def _euler_rotation_targets(self) -> QubitSet: + observables = self._observables + circuit_qubits = self._circuit_qubits() + if isinstance(observables, Sum): + if observables.targets: + return QubitSet(t for obs in observables.summands for t in obs.targets) + return circuit_qubits + targets = QubitSet() + for obs in observables: + if obs.targets: + targets |= obs.targets + else: + targets |= circuit_qubits + return targets + def _get_euler_angles(self) -> dict[str, float] | None: observables = self._observables return ( @@ -164,7 +262,7 @@ def _get_euler_angles_sum(self, observables: Sum) -> dict[str, float]: euler_angles = defaultdict(list) summands = observables.summands if not observables.targets: - targets = self._circuit.qubits + targets = self._circuit_qubits() for obs in summands: for param, angle in obs.get_euler_angles(targets).items(): euler_angles[param].append(angle) @@ -179,7 +277,7 @@ def _get_euler_angles_sum(self, observables: Sum) -> dict[str, float]: def _get_euler_angles_list(self, observables: Sequence[Observable]) -> dict[str, float]: euler_angles = defaultdict(list) - circuit_qubits = self._circuit.qubits + circuit_qubits = self._circuit_qubits() targets = QubitSet(q for obs in observables for q in (obs.targets or circuit_qubits)) for obs in observables: if not obs.targets: @@ -207,12 +305,17 @@ def bind_observables_to_inputs( well as CompositeEntry.expectation. Kwargs: - inplace (bool): whether or not to return a new circuit binding or use the same one - add_measure (bool): whether or not to apply Measure instructions to the circuit + inplace (bool): Whether to return a new circuit binding or use the same one + add_measure (bool): Whether to apply Measure instructions to the circuit. Only + applies when the underlying circuit is a `Circuit`; for OpenQASM string + circuits, the source is preserved verbatim aside from injected Euler-angle + rotations. Returns: CircuitBinding: A new circuit binding with the observables bound. """ + if isinstance(self._circuit, str): + return self._bind_observables_to_inputs_str(inplace) measure = Circuit() parameters = self._input_sets.as_dict() if self._input_sets else None if observables := self._observables: @@ -236,6 +339,29 @@ def bind_observables_to_inputs( return self return CircuitBinding(self._circuit + measure, input_sets=parameters) + def _bind_observables_to_inputs_str(self, inplace: bool) -> CircuitBinding: + source = self._circuit + parameters = self._input_sets.as_dict() if self._input_sets else None + if observables := self._observables: + if isinstance(observables, Sum): + warnings.warn( + "Binding a Sum discards information on observable weights; please " + "distribute your observable in advance using observable.summands.", + stacklevel=2, + ) + euler_angles = self._get_euler_angles() + source = _inject_euler_angles( + self._parsed_source, self._euler_rotation_targets(), euler_angles.keys() + ) + parameters = self._input_sets * euler_angles if parameters else euler_angles + if inplace: + self._circuit = source + self._parsed_source = _parse_openqasm(source) + self._observables = None + self._input_sets = parameters + return self + return CircuitBinding(source, input_sets=parameters) + def __len__(self): input_sets = self._input_sets observables = self._observables @@ -260,3 +386,26 @@ def __repr__(self): f"input_sets={self._input_sets}, " f"observables={self._observables})" ) + + +def _inject_euler_angles( + parsed: _ParsedOpenQASM, + targets: QubitSet, + parameter_names: Sequence[str], +) -> str: + rotations = [] + for q in targets: + theta, phi, omega = euler_angle_parameter_names(q) + formatted = parsed.qubit_format.format(int(q)) + rotations.extend([ + f"rz({theta}) {formatted};", + f"rx({phi}) {formatted};", + f"rz({omega}) {formatted};", + ]) + return "\n".join( + list(parsed.lines[: parsed.declarations_index]) + + [f"input float {name};" for name in parameter_names] + + list(parsed.lines[parsed.declarations_index : parsed.measure_index]) + + rotations + + list(parsed.lines[parsed.measure_index :]) + ) diff --git a/test/unit_tests/braket/program_sets/test_circuit_binding.py b/test/unit_tests/braket/program_sets/test_circuit_binding.py index 9af083fad..049805a26 100644 --- a/test/unit_tests/braket/program_sets/test_circuit_binding.py +++ b/test/unit_tests/braket/program_sets/test_circuit_binding.py @@ -15,9 +15,15 @@ from braket.circuits import Circuit from braket.circuits.observables import X, Y, Z +from braket.circuits.serialization import IRType from braket.parametric import FreeParameter from braket.program_sets import CircuitBinding from braket.quantum_information import PauliString +from braket.registers import QubitSet + + +def _source(circuit): + return circuit.to_ir(IRType.OPENQASM).source def test_equality(circuit_rx_parametrized): @@ -168,3 +174,179 @@ def test_binding_without_measure(circuit_rx_parametrized): circ = cb2.circuit circ.measure(range(2)) assert circ == cb3.circuit + + +def test_string_circuit_no_observables(circuit_rx_parametrized): + src = _source(circuit_rx_parametrized) + cb = CircuitBinding(src, input_sets={"theta": [1.23, 3.21]}) + program = cb.to_ir() + assert program.source == src + assert program.inputs == {"theta": [1.23, 3.21]} + + +def test_string_circuit_matches_circuit_to_ir(circuit_rx_parametrized): + circuit = Circuit(circuit_rx_parametrized).cnot(0, 1) + src = _source(circuit) + observable = [X(0) @ Z(1)] + cb_circ = CircuitBinding(circuit, {"theta": [1.23]}, observable) + cb_str = CircuitBinding(src, {"theta": [1.23]}, observable) + circ_program = cb_circ.to_ir() + str_program = cb_str.to_ir() + assert circ_program.inputs == str_program.inputs + # Same set of statements; declaration ordering may differ between paths. + assert sorted(circ_program.source.splitlines()) == sorted(str_program.source.splitlines()) + + +def test_string_circuit_targetless_observable(circuit_rx_parametrized): + circuit = Circuit(circuit_rx_parametrized).cnot(0, 1) + src = _source(circuit) + cb_circ = CircuitBinding(circuit, observables=[X() @ Y()]) + cb_str = CircuitBinding(src, observables=[X() @ Y()]) + assert cb_str.to_ir().inputs == cb_circ.to_ir().inputs + + +def test_string_circuit_sum_observable(circuit_rx_parametrized): + circuit = Circuit(circuit_rx_parametrized).cnot(0, 1) + src = _source(circuit) + h = 0.5 * X(0) @ Z(1) + 2 * Y(0) + cb_circ = CircuitBinding(circuit, observables=h) + cb_str = CircuitBinding(src, observables=h) + assert cb_str.to_ir().inputs == cb_circ.to_ir().inputs + + +def test_string_circuit_targetless_sum_observable(circuit_rx_parametrized): + circuit = Circuit(circuit_rx_parametrized).cnot(0, 1) + src = _source(circuit) + h_targetless = X() @ Y() - 3 * Z() @ X() + cb_circ = CircuitBinding(circuit, observables=h_targetless) + cb_str = CircuitBinding(src, observables=h_targetless) + assert cb_str.to_ir().inputs == cb_circ.to_ir().inputs + + +def test_string_circuit_bind_sum_warning(circuit_rx_parametrized): + src = _source(circuit_rx_parametrized) + cb = CircuitBinding(src, observables=0.5 * X(0) @ Z(1) + 2 * Y(0)) + with pytest.warns(UserWarning): + cb.bind_observables_to_inputs() + + +def test_string_circuit_custom_register_name(): + src = ( + "OPENQASM 3.0;\n" + "input float theta;\n" + "bit[2] b;\n" + "qubit[2] foo;\n" + "rx(theta) foo[0];\n" + "cnot foo[0], foo[1];\n" + "b[0] = measure foo[0];\n" + "b[1] = measure foo[1];" + ) + cb = CircuitBinding(src, input_sets={"theta": [0.5]}, observables=[X(0) @ Y(1)]) + serialized = cb.to_ir().source + assert "rz(_OBSERVABLE_THETA_0) foo[0];" in serialized + assert "rz(_OBSERVABLE_THETA_1) foo[1];" in serialized + assert "q[0]" not in serialized + assert "q[1]" not in serialized + + +def test_string_circuit_no_measure_or_version_line(): + # Source with no `OPENQASM 3.0;` line and no measurement — exercises the fallback branches + # of _parse_source for declarations_index and measure_index. + src = "qubit[2] q;\nh q[0];\ncnot q[0], q[1];" + cb = CircuitBinding(src, observables=[X(0) @ Y(1)]) + serialized = cb.to_ir().source + # Declarations should be prepended (no OPENQASM line to anchor on). + assert serialized.startswith("input float ") + # Rotations should land at the end (no measurement to anchor on). + assert serialized.rstrip().endswith("rz(_OBSERVABLE_OMEGA_1) q[1];") + + +def test_parse_openqasm_physical_when_no_register_declared(): + # With no `qubit[N] ;` declaration, the source is treated as addressing physical + # qubits (`$N`). + from braket.program_sets.circuit_binding import _parse_openqasm # noqa: PLC0415 + + parsed = _parse_openqasm("rx(0.5) $0;\ncnot $0, $1;") + assert parsed.qubit_format == "${}" + assert parsed.qubits == QubitSet([0, 1]) + + +def test_parse_openqasm_qubit_count_from_declaration(): + # The qubit set comes from the register declaration size, covering broadcast gates + # like `h q;` that carry no explicit index. + from braket.program_sets.circuit_binding import _parse_openqasm # noqa: PLC0415 + + parsed = _parse_openqasm("OPENQASM 3.0;\nqubit[3] q;\nh q;") + assert parsed.qubit_format == "q[{}]" + assert parsed.qubits == QubitSet([0, 1, 2]) + + +def test_parse_openqasm_single_unindexed_qubit(): + # `qubit q;` (no size) declares a single qubit at index 0. + from braket.program_sets.circuit_binding import _parse_openqasm # noqa: PLC0415 + + parsed = _parse_openqasm("OPENQASM 3.0;\nqubit q;\nh q;") + assert parsed.qubits == QubitSet([0]) + + +def test_string_circuit_broadcast_gate(): + # A register-broadcast gate (`h q;`) should still yield Euler rotations on every qubit. + src = "OPENQASM 3.0;\nqubit[2] q;\nh q;" + cb = CircuitBinding(src, observables=[X(0) @ Y(1)]) + serialized = cb.to_ir().source + assert "rz(_OBSERVABLE_THETA_0) q[0];" in serialized + assert "rz(_OBSERVABLE_THETA_1) q[1];" in serialized + + +def test_circuit_binding_dunders(circuit_rx_parametrized): + # Exercises the input_sets/observables properties and __len__/__repr__. + input_sets = {"theta": [1.23, 3.21, 0.5]} + observable = [X(0), Y(0)] + cb = CircuitBinding(circuit_rx_parametrized, input_sets, observable) + assert cb.input_sets.as_dict() == input_sets + assert list(cb.observables) == observable + assert len(cb) == 6 + assert len(CircuitBinding(circuit_rx_parametrized, input_sets)) == 3 + assert len(CircuitBinding(circuit_rx_parametrized, observables=observable)) == 2 + assert "CircuitBinding(circuit=" in repr(cb) + + +def test_string_circuit_physical_qubits(circuit_rx_parametrized): + verbatim = Circuit().add_verbatim_box(Circuit().rx(1, FreeParameter("theta")).cnot(1, 0)) + src = _source(verbatim) + cb = CircuitBinding(src, input_sets={"theta": [0.5]}, observables=[X(0) @ Y(1)]) + serialized = cb.to_ir().source + # Inserted Euler rotations target physical qubits, not virtual q[i]. + assert "rz(_OBSERVABLE_THETA_0) $0;" in serialized + assert "rz(_OBSERVABLE_THETA_1) $1;" in serialized + assert "q[" not in serialized.replace("bit[", "") + + +def test_string_circuit_bind_observables_to_inputs(circuit_rx_parametrized): + src = _source(circuit_rx_parametrized) + observable = [X(0) @ Z(1), Y(0), Z(0)] + cb1 = CircuitBinding(src, input_sets={"theta": [1.35, 1.58]}, observables=observable) + + cb2 = cb1.bind_observables_to_inputs(inplace=False) + assert cb1 != cb2 + assert cb1.to_ir() == cb2.to_ir() + + cb1.bind_observables_to_inputs(inplace=True) + assert cb1 == cb2 + + +def test_string_circuit_bind_no_observables(circuit_rx_parametrized): + src = _source(circuit_rx_parametrized) + cb1 = CircuitBinding(src, input_sets={"theta": [1.35, 1.58]}) + cb2 = cb1.bind_observables_to_inputs(inplace=False) + assert cb1 == cb2 + + +def test_string_circuit_equality(): + src = "OPENQASM 3.0;\ninput float theta;\nbit[1] b;\nqubit[1] q;\nrx(theta) q[0];" + input_sets = {"theta": [1.23, 3.21]} + observable = [X(0)] + cb = CircuitBinding(src, input_sets, observable) + assert cb == CircuitBinding(src, input_sets, observable) + assert cb != CircuitBinding(src, observables=observable) + assert cb != CircuitBinding(src, input_sets) diff --git a/tox.ini b/tox.ini index c43008c4b..52f78bde2 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ extras = test deps = {[test-deps]deps} commands = - pytest {posargs} --cov=braket --cov-report=term-missing --cov-report=html --cov-report=xml --cov-append + pytest {posargs} --cov --cov-report=term-missing --cov-report=html --cov-report=xml --cov-append [testenv:integ-tests] extras = test