Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 166 additions & 17 deletions src/braket/program_sets/circuit_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
import warnings
from collections import defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass

from braket.default_simulator.openqasm.parser.openqasm_ast import (
Identifier,
IntegerLiteral,
QASMNode,
QuantumMeasurementStatement,
QubitDeclaration,
)
from braket.default_simulator.openqasm.parser.openqasm_parser import parse
from braket.ir.openqasm import Program

from braket.circuits import Circuit, Gate, Observable
Expand All @@ -29,10 +38,64 @@
from braket.registers import QubitSet


@dataclass
class _ParsedOpenQASM:
lines: Sequence[str]
declarations_index: int
measure_index: int
qubit_format: str
qubits: QubitSet


def _parse_openqasm(source: str) -> _ParsedOpenQASM:
lines = source.splitlines()
program = parse(source)
declarations_index = 1 if program.version is not None else 0
measure_index = len(lines)
register_name = None
register_size = 0
for stmt in program.statements:
if isinstance(stmt, QubitDeclaration) and register_name is None:
register_name = stmt.qubit.name
# stmt.size is None for a single unindexed qubit
register_size = stmt.size.value if isinstance(stmt.size, IntegerLiteral) else 1
if isinstance(stmt, QuantumMeasurementStatement) and measure_index == len(lines):
# span is 1-indexed; convert to a 0-indexed line index.
measure_index = stmt.span.start_line - 1

qubit_format, qubits = (
(f"{register_name}[{{}}]", QubitSet(range(register_size)))
if register_name
else (
"${}",
QubitSet(
int(node.name[1:])
for node in _walk(program)
if isinstance(node, Identifier) and node.name.startswith("$")
),
)
)
return _ParsedOpenQASM(
lines=lines,
declarations_index=declarations_index,
measure_index=measure_index,
qubit_format=qubit_format,
qubits=qubits,
)


def _walk(node: QASMNode):
yield node
for value in vars(node).values():
for child in value if isinstance(value, list) else [value]:
if isinstance(child, QASMNode):
yield from _walk(child)


class CircuitBinding:
def __init__(
self,
circuit: Circuit,
circuit: Circuit | str,
input_sets: ParameterSetsLike | None = None,
observables: Sequence[Observable | PauliString | str] | Sum | None = None,
):
Expand All @@ -51,7 +114,8 @@ def __init__(
Note: Circuits cannot have result types attached.

Args:
circuit (Circuit): The parametrized circuit
circuit (Circuit | str): The parametrized circuit, either as a Circuit object or as
an OpenQASM string.
input_sets (ParameterSetsLike | None): The inputs to the circuit, if specified.
observables (Sequence[Observable | PauliString | str] | Sum | None): The observables
or Hamiltonian to measure, if specified.
Expand All @@ -70,9 +134,10 @@ def __init__(
and any(isinstance(obs, Sum) for obs in observables)
):
raise TypeError("Cannot have Sum Hamiltonian in list of observables")
if circuit.result_types:
if isinstance(circuit, Circuit) and circuit.result_types:
raise ValueError("Circuit cannot have result types")
self._circuit = circuit
self._parsed_source = _parse_openqasm(circuit) if isinstance(circuit, str) else None
self._input_sets = ParameterSets(input_sets) if input_sets else None
self._observables = CircuitBinding._to_observables(observables)

Expand All @@ -96,9 +161,9 @@ def _to_observables(
return obs

@property
def circuit(self) -> Circuit:
def circuit(self) -> Circuit | str:
"""
Circuit: The parametrized circuit
Circuit | str: The parametrized circuit, either as a Circuit object or an OpenQASM string.
"""
return self._circuit

Expand Down Expand Up @@ -135,23 +200,56 @@ def to_ir(
"""
if not self._observables:
return Program(
source=self._circuit.to_ir(
IRType.OPENQASM, gate_definitions=gate_definitions
).source,
source=self._circuit_source(gate_definitions),
inputs=self._input_sets.as_dict() if self._input_sets else None,
)
# with_euler_angles validates that the observable has valid Euler angle gates
circuit_with_euler_angles = self._circuit.with_euler_angles(self._observables)
euler_angles = self._get_euler_angles()
if isinstance(self._circuit, Circuit):
source = (
self._circuit
.with_euler_angles(self._observables)
.to_ir(IRType.OPENQASM, gate_definitions=gate_definitions)
.source
)
else:
source = _inject_euler_angles(
self._parsed_source, self._euler_rotation_targets(), euler_angles.keys()
)
return Program(
source=circuit_with_euler_angles.to_ir(
IRType.OPENQASM, gate_definitions=gate_definitions
).source,
source=source,
inputs=(
self._input_sets * euler_angles if self._input_sets else ParameterSets(euler_angles)
).as_dict(),
)

def _circuit_source(
self,
gate_definitions: Mapping[tuple[Gate, QubitSet], PulseSequence] | None,
) -> str:
if isinstance(self._circuit, Circuit):
return self._circuit.to_ir(IRType.OPENQASM, gate_definitions=gate_definitions).source
return self._circuit

def _circuit_qubits(self) -> QubitSet:
if isinstance(self._circuit, Circuit):
return self._circuit.qubits
return self._parsed_source.qubits

def _euler_rotation_targets(self) -> QubitSet:
observables = self._observables
circuit_qubits = self._circuit_qubits()
if isinstance(observables, Sum):
if observables.targets:
return QubitSet(t for obs in observables.summands for t in obs.targets)
return circuit_qubits
targets = QubitSet()
for obs in observables:
if obs.targets:
targets |= obs.targets
else:
targets |= circuit_qubits
return targets

def _get_euler_angles(self) -> dict[str, float] | None:
observables = self._observables
return (
Expand All @@ -164,7 +262,7 @@ def _get_euler_angles_sum(self, observables: Sum) -> dict[str, float]:
euler_angles = defaultdict(list)
summands = observables.summands
if not observables.targets:
targets = self._circuit.qubits
targets = self._circuit_qubits()
for obs in summands:
for param, angle in obs.get_euler_angles(targets).items():
euler_angles[param].append(angle)
Expand All @@ -179,7 +277,7 @@ def _get_euler_angles_sum(self, observables: Sum) -> dict[str, float]:

def _get_euler_angles_list(self, observables: Sequence[Observable]) -> dict[str, float]:
euler_angles = defaultdict(list)
circuit_qubits = self._circuit.qubits
circuit_qubits = self._circuit_qubits()
targets = QubitSet(q for obs in observables for q in (obs.targets or circuit_qubits))
for obs in observables:
if not obs.targets:
Expand Down Expand Up @@ -207,12 +305,17 @@ def bind_observables_to_inputs(
well as CompositeEntry.expectation.

Kwargs:
inplace (bool): whether or not to return a new circuit binding or use the same one
add_measure (bool): whether or not to apply Measure instructions to the circuit
inplace (bool): Whether to return a new circuit binding or use the same one
add_measure (bool): Whether to apply Measure instructions to the circuit. Only
applies when the underlying circuit is a `Circuit`; for OpenQASM string
circuits, the source is preserved verbatim aside from injected Euler-angle
rotations.

Returns:
CircuitBinding: A new circuit binding with the observables bound.
"""
if isinstance(self._circuit, str):
return self._bind_observables_to_inputs_str(inplace)
measure = Circuit()
parameters = self._input_sets.as_dict() if self._input_sets else None
if observables := self._observables:
Expand All @@ -236,6 +339,29 @@ def bind_observables_to_inputs(
return self
return CircuitBinding(self._circuit + measure, input_sets=parameters)

def _bind_observables_to_inputs_str(self, inplace: bool) -> CircuitBinding:
source = self._circuit
parameters = self._input_sets.as_dict() if self._input_sets else None
if observables := self._observables:
if isinstance(observables, Sum):
warnings.warn(
"Binding a Sum discards information on observable weights; please "
"distribute your observable in advance using observable.summands.",
stacklevel=2,
)
euler_angles = self._get_euler_angles()
source = _inject_euler_angles(
self._parsed_source, self._euler_rotation_targets(), euler_angles.keys()
)
parameters = self._input_sets * euler_angles if parameters else euler_angles
if inplace:
self._circuit = source
self._parsed_source = _parse_openqasm(source)
self._observables = None
self._input_sets = parameters
return self
return CircuitBinding(source, input_sets=parameters)

def __len__(self):
input_sets = self._input_sets
observables = self._observables
Expand All @@ -260,3 +386,26 @@ def __repr__(self):
f"input_sets={self._input_sets}, "
f"observables={self._observables})"
)


def _inject_euler_angles(
parsed: _ParsedOpenQASM,
targets: QubitSet,
parameter_names: Sequence[str],
) -> str:
rotations = []
for q in targets:
theta, phi, omega = euler_angle_parameter_names(q)
formatted = parsed.qubit_format.format(int(q))
rotations.extend([
f"rz({theta}) {formatted};",
f"rx({phi}) {formatted};",
f"rz({omega}) {formatted};",
])
return "\n".join(
list(parsed.lines[: parsed.declarations_index])
+ [f"input float {name};" for name in parameter_names]
+ list(parsed.lines[parsed.declarations_index : parsed.measure_index])
+ rotations
+ list(parsed.lines[parsed.measure_index :])
)
Loading
Loading