Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/braket/circuits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
AngledGate, # noqa: F401
DoubleAngledGate, # noqa: F401
)
from braket.circuits.circuit import Circuit # noqa: F401
from braket.circuits.circuit import Circuit, QubitMatch # noqa: F401
from braket.circuits.circuit_diagram import CircuitDiagram # noqa: F401
from braket.circuits.compiler_directive import CompilerDirective # noqa: F401
from braket.circuits.free_parameter import FreeParameter # noqa: F401
Expand Down
91 changes: 91 additions & 0 deletions src/braket/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import warnings
from collections import Counter
from collections.abc import Callable, Iterable, Sequence
from enum import StrEnum
from numbers import Number
from typing import Any, TypeVar

Expand Down Expand Up @@ -45,6 +46,7 @@
)
from braket.circuits.observable import Observable, euler_angle_parameter_names
from braket.circuits.observables import Sum, TensorProduct
from braket.circuits.operator import Operator
from braket.circuits.parameterizable import Parameterizable
from braket.circuits.result_type import (
ObservableParameterResultType,
Expand Down Expand Up @@ -73,6 +75,16 @@
AddableTypes = TypeVar("AddableTypes", SubroutineReturn, SubroutineCallable)


class QubitMatch(StrEnum):
"""Controls how multiple qubits are matched in count."""

ANY = "ANY"
ALL = "ALL"


OperatorIdentifier = str | type[Operator] | Operator


class Circuit:
"""A representation of a quantum circuit that contains the instructions to be performed on a
quantum device and the requested result types.
Expand Down Expand Up @@ -243,6 +255,85 @@ def parameters(self) -> set[FreeParameter]:
"""
return self._parameters

@staticmethod
def _normalize_operator_name(identifier: OperatorIdentifier) -> str:
if isinstance(identifier, type):
return identifier.__name__.upper()
if isinstance(identifier, str):
return identifier.upper()
return identifier.name.upper()

@staticmethod
def _to_operator_names(
operators: OperatorIdentifier | Iterable[OperatorIdentifier] | None,
) -> list[str]:
if operators is None:
return []
if isinstance(operators, (str, type, Operator)):
return [Circuit._normalize_operator_name(operators)]
return [Circuit._normalize_operator_name(op) for op in operators]

def count(
self,
operators: OperatorIdentifier | Iterable[OperatorIdentifier] | None = None,
qubits: QubitInput | Iterable[QubitInput] | None = None,
qubit_match: QubitMatch = QubitMatch.ANY,
include_types: Iterable[MomentType] = (MomentType.GATE,),
) -> Counter[str]:
"""
Count instructions in the circuit with optional filtering.

When both ``operators`` and ``qubits`` are specified, an instruction must satisfy
both filters to be counted (AND semantics).

Args:
operators: Filter by operator name or type. Defaults to None (no filter).
qubits: Filter by qubit. Matched against the union of target and control qubits.
qubit_match (QubitMatch): How multiple qubits relate. ANY = instruction on
any specified qubit; ALL = instruction on all specified qubits. Default ANY.
include_types (Iterable[MomentType]): Moment types to count. Default: GATE only.
Pass additional MomentType values to include noise, measures, etc.

Returns:
Counter[str]: Operator names mapped to occurrence counts.

Examples:
>>> circ = Circuit().h(0).cnot(0, 1).rx(0, 0.5)
>>> circ.count()
Counter({'H': 1, 'CNot': 1, 'Rx': 1})
>>> circ.count("h")
Counter({'H': 1})
>>> circ.count(["H", "CNot"])
Counter({'H': 1, 'CNot': 1})
>>> circ.count(qubits=0)
Counter({'H': 1, 'CNot': 1, 'Rx': 1})
"""
include_types_set = set(include_types)
operator_names_set = set(self._to_operator_names(operators))
qs = QubitSet(qubits) if qubits is not None else None
filter_qubits = qs or None # empty QubitSet treated as no filter

result: Counter[str] = Counter()

for key, instruction in self.moments.items():
if key.moment_type not in include_types_set:
continue

instr_qubits = instruction.target.union(instruction.control)
instr_name_upper = instruction.operator.name.upper()

qubit_pass = filter_qubits is None or (
any(q in instr_qubits for q in filter_qubits)
if qubit_match == QubitMatch.ANY
else all(q in instr_qubits for q in filter_qubits)
)
operator_pass = not operator_names_set or instr_name_upper in operator_names_set

if qubit_pass and operator_pass:
result[instruction.operator.name] += 1

return result

def with_euler_angles(self, observables: Sequence[Observable] | Sum) -> Circuit:
"""Returns a copy of the circuit with parametrized Euler angles on the observables' qubits

Expand Down
71 changes: 71 additions & 0 deletions test/unit_tests/braket/circuits/test_circuit_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from collections import Counter

import pytest

from braket.circuits import Circuit, gates
from braket.circuits.circuit import QubitMatch
from braket.circuits.moments import MomentType
from braket.circuits.noises import BitFlip


@pytest.fixture
def mixed_circuit():
return Circuit().h(0).cnot(0, 1).rx(0, 0.5).h(1)


def test_no_filters_returns_all_gates(mixed_circuit):
assert mixed_circuit.count() == Counter({"H": 2, "CNot": 1, "Rx": 1})


def test_operator_filter_multiple_mixed_identifiers(mixed_circuit):
assert mixed_circuit.count(operators=["h", gates.CNot]) == Counter({"H": 2, "CNot": 1})


def test_operator_filter_gate_instance(mixed_circuit):
assert mixed_circuit.count(operators=gates.CNot()) == Counter({"CNot": 1})


def test_include_gate_noise_type():
circ = Circuit().h(0)
circ.apply_gate_noise(BitFlip(0.1))
result = circ.count(include_types=[MomentType.GATE, MomentType.GATE_NOISE])
assert result == Counter({"H": 1, "BitFlip": 1})


def test_gate_noise_excluded_by_default():
circ = Circuit().h(0)
circ.apply_gate_noise(BitFlip(0.1))
assert circ.count() == Counter({"H": 1})


def test_qubit_filter_single_qubit(mixed_circuit):
assert mixed_circuit.count(qubits=0) == Counter({"H": 1, "CNot": 1, "Rx": 1})


def test_qubit_filter_multiple_qubits_any():
circ = Circuit().h(0).h(2).cnot(0, 1)
assert circ.count(qubits=[0, 2]) == Counter({"H": 2, "CNot": 1})


def test_qubit_filter_multiple_qubits_all(mixed_circuit):
assert mixed_circuit.count(qubits=[0, 1], qubit_match=QubitMatch.ALL) == Counter({"CNot": 1})


def test_operator_and_qubit_filters_intersect(mixed_circuit):
assert mixed_circuit.count(operators="CNot", qubits=[1]) == Counter({"CNot": 1})


def test_unknown_operator_returns_empty(mixed_circuit):
assert mixed_circuit.count(operators="ZZZ") == Counter()


def test_qubit_not_in_circuit_returns_empty(mixed_circuit):
assert mixed_circuit.count(qubits=99) == Counter()


def test_partial_qubits_not_in_circuit_any(mixed_circuit):
assert mixed_circuit.count(qubits=[0, 99]) == Counter({"H": 1, "CNot": 1, "Rx": 1})


def test_partial_qubits_not_in_circuit_all(mixed_circuit):
assert mixed_circuit.count(qubits=[0, 99], qubit_match=QubitMatch.ALL) == Counter()
Loading