Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 221 additions & 1 deletion src/braket/program_sets/program_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass

from braket.ir.openqasm import ProgramSet as OpenQASMProgramSet

Expand Down Expand Up @@ -97,6 +98,167 @@ def total_shots(self) -> int:
raise ValueError("No per-executable shots defined")
return self._shots_per_executable * self.total_executables

def enumerate_executables(self) -> Iterator[tuple[int, int, int]]:
"""Yield ``(binding_index, parameter_set_index, observable_index)`` tuples in order,
one per executable.

The iteration order is: iterate over ``self.entries``; within each entry,
iterate over parameter set indices; within each parameter set index,
iterate over observable indices. The total number of yields is ``self.total_executables``.

For ``Circuit``s and ``CircuitBinding``s with no input sets, ``parameter_set_index`` is 0.
For entries with no observables, ``observable_index`` is 0. For ``CircuitBinding``s with a
``Sum`` Hamiltonian, ``observable_index`` ranges over the summands.

This ordering is used by ``split`` to build its index map and by
``ProgramSetQuantumTaskResult.merge`` to merge results back into the original shape.

Yields:
tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``.
"""
for binding_idx, prog in enumerate(self._programs):
if isinstance(prog, Circuit):
yield binding_idx, 0, 0
continue
num_obs = len(prog.observables) if prog.observables is not None else 1
for ps_idx in range(len(prog.input_sets) if prog.input_sets is not None else 1):
for obs_idx in range(num_obs):
yield binding_idx, ps_idx, obs_idx

def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]]]:
"""Split this program set into program sets of at most ``max_executables`` executables,
alongside a map that records the position in the original program set of each executable
in each of the generated program sets.

When a single parameter set index of a ``CircuitBinding`` would by itself exceed
``max_executables`` due to its observable list or ``Sum`` Hamiltonian being larger than
the budget, the observable list is split into chunks of at most ``max_executables`` entries
(``Sum`` summands are sliced with coefficients preserved). Observable splitting is only
performed when necessary; otherwise the full observable list or ``Sum`` is kept intact.

Args:
max_executables (int): The maximum number of executables per program
set. Must be positive.

Returns:
tuple[list[ProgramSet], list[list[int]]]: ``(program_sets, index_map)``.
``index_map[k][j]`` is the index of the executable that the j-th executable of
``program_sets[k]`` represents.
If this program set already fits within ``max_executables``, the returned
program-set list is ``[self]`` and the index_map is ``[[0, 1, ...,
total_executables - 1]]``.

Raises:
ValueError: If ``max_executables`` is not positive.

Examples:
>>> ps = ProgramSet([
... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables
... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables
... ])
>>> subs, index_map = ps.split(120)
>>> [s.total_executables for s in subs]
[120, 120, 120, 120, 20]
>>> sum(len(m) for m in index_map) == ps.total_executables
True
"""
if max_executables <= 0:
raise ValueError(f"max_executables must be positive, got {max_executables}")

if self.total_executables <= max_executables:
return [self], [list(range(self.total_executables))]

program_sets = []
index_map = []
current = []
current_size = 0
for block in self._executable_blocks(max_executables):
if current and current_size + block.size > max_executables:
sub, sub_map = self._build_program_set(current)
program_sets.append(sub)
index_map.append(sub_map)
current = []
current_size = 0
current.append(block)
current_size += block.size
sub, sub_map = self._build_program_set(current)
program_sets.append(sub)
index_map.append(sub_map)

return program_sets, index_map

def _executable_blocks(self, max_executables: int) -> list[_ExecutableBlock]:
blocks = []
orig_idx = 0
for prog_idx, prog in enumerate(self._programs):
if isinstance(prog, Circuit):
blocks.append(
_ExecutableBlock(
prog_idx=prog_idx,
param_set_index=None,
obs_slice=None,
size=1,
original_indices=[orig_idx],
)
)
orig_idx += 1
continue

num_ps = len(prog.input_sets) if prog.input_sets is not None else 1
obs_windows = _observable_windows(
len(prog.observables) if prog.observables is not None else 1, max_executables
)
split_observables = len(obs_windows) > 1
for ps_idx in range(num_ps) if prog.input_sets is not None else [None]:
for start, stop in obs_windows:
size = stop - start
blocks.append(
_ExecutableBlock(
prog_idx=prog_idx,
param_set_index=ps_idx,
obs_slice=slice(start, stop) if split_observables else None,
size=size,
original_indices=list(range(orig_idx, orig_idx + size)),
)
)
orig_idx += size
return blocks

def _build_program_set(self, blocks: list[_ExecutableBlock]) -> tuple[ProgramSet, list[int]]:
entries = []
sub_map = []
i = 0
while i < len(blocks):
head = blocks[i]
prog = self._programs[head.prog_idx]
if head.param_set_index is None:
entries.append(_apply_obs_slice(prog, head.obs_slice))
sub_map.extend(head.original_indices)
i += 1
continue

j = i
while (
j + 1 < len(blocks)
and blocks[j + 1].prog_idx == head.prog_idx
and blocks[j + 1].obs_slice == blocks[j].obs_slice
and blocks[j + 1].param_set_index == blocks[j].param_set_index + 1
):
j += 1
start = head.param_set_index
stop = blocks[j].param_set_index + 1
entries.append(
CircuitBinding(
prog.circuit,
input_sets=prog.input_sets.as_list()[start:stop],
observables=_slice_observables(prog.observables, head.obs_slice),
)
)
for k in range(i, j + 1):
sub_map.extend(blocks[k].original_indices)
i = j + 1
return ProgramSet(entries, self._shots_per_executable), sub_map

@staticmethod
def zip(
circuits: Sequence[Circuit] | CircuitBinding,
Expand Down Expand Up @@ -206,6 +368,64 @@ def __repr__(self):
)


@dataclass
class _ExecutableBlock:
"""Multi-index range for an equivalence class of executables sharing the same combination of
``(circuit, observable list/Sum Hamiltonian, single parameter assignment)``.

Attributes:
prog_idx: Index of the originating program in ``ProgramSet.entries``.
param_set_index: Index into the originating ``CircuitBinding``'s ``input_sets``, or ``None``
for ``Circuit`` entries and ``CircuitBinding``s with no input sets.
obs_slice: Slice into the originating observable list or ``Sum`` summands when observables
were split to fit the budget; ``None`` means the full original observable list
(or no observables).
size: Number of executables this block represents (== ``len(original_indices)``).
original_indices: The indices of this block's executables
in the order of the original program set.
"""

prog_idx: int
param_set_index: int | None
obs_slice: slice | None
size: int
original_indices: list[int]


def _observable_windows(num_observables: int, max_executables: int) -> list[tuple[int, int]]:
if num_observables <= max_executables:
return [(0, num_observables)]
windows = []
start = 0
while start < num_observables:
stop = min(start + max_executables, num_observables)
windows.append((start, stop))
start = stop
return windows


def _slice_observables(
observables: Sum | Sequence[Observable] | None, obs_slice: slice | None
) -> Sum | Sequence[Observable] | None:
if obs_slice is None or observables is None:
return observables
if isinstance(observables, Sum):
return Sum(list(observables.summands)[obs_slice])
return list(observables)[obs_slice]


def _apply_obs_slice(
prog: CircuitBinding | Circuit, obs_slice: slice | None
) -> CircuitBinding | Circuit:
if obs_slice is None or isinstance(prog, Circuit) or prog.observables is None:
return prog
return CircuitBinding(
prog.circuit,
input_sets=prog.input_sets,
observables=_slice_observables(prog.observables, obs_slice),
)


def _zip_circuit_bindings(
circuit_binding: CircuitBinding,
input_sets: Sequence[Mapping[str, float]] | None,
Expand Down
Loading
Loading