Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 298 additions & 20 deletions src/braket/tasks/program_set_quantum_task_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import warnings
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace

import boto3
import numpy as np
Expand All @@ -31,10 +31,15 @@
ProgramSetTaskMetadata,
ProgramSetTaskResult,
)
from braket.task_result.program_set_executable_result_v1 import (
ProgramSetExecutableResultMetadata,
)
from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata

from braket.circuits import Observable
from braket.circuits import Circuit, Observable
from braket.circuits.observable import EULER_OBSERVABLE_PREFIX
from braket.circuits.observables import Sum
from braket.circuits.serialization import IRType
from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet
from braket.tasks.measurement_utils import (
expectation_from_measurements,
Expand Down Expand Up @@ -141,7 +146,6 @@ def expectation(self) -> float | None:
return self._expectation


@dataclass
class CompositeEntry:
"""Results of a program in a program set

Expand All @@ -152,15 +156,70 @@ class CompositeEntry:
observables (Sum | list[Observable] | None): The Sum Hamiltonian or observables
that were measured, if any.
shots_per_executable (int): The number of shots each underlying executable was run with
additional_metadata (AdditionalMetadata): Additional metadata about this program
additional_metadata (AdditionalMetadata | None): Additional metadata about this program.
``None`` for entries produced by ``ProgramSetQuantumTaskResult.merge``,
since per-program metadata cannot be aggregated meaningfully across underlying tasks.
"""

entries: list[MeasuredEntry]
program: Program
inputs: ParameterSets
observables: Sum | list[Observable] | None
shots_per_executable: int
additional_metadata: AdditionalMetadata
def __init__(
self,
entries: list[MeasuredEntry],
program: Program,
inputs: ParameterSets,
observables: Sum | list[Observable] | None,
shots_per_executable: int,
additional_metadata: AdditionalMetadata | None,
):
self._entries = entries
self._program = program
self._inputs = inputs
self._observables = observables
self._shots_per_executable = shots_per_executable
self._additional_metadata = additional_metadata
self._was_merged = False
self._expectations = self._compute_expectations() if isinstance(observables, Sum) else None

@property
def entries(self) -> list[MeasuredEntry]:
"""list[MeasuredEntry]: The results of each executable in this program."""
return self._entries

@property
def program(self) -> Program:
"""Program: The program that was run."""
return self._program

@property
def inputs(self) -> ParameterSets:
"""ParameterSets: The input values this program was run with."""
return self._inputs

@property
def observables(self) -> Sum | list[Observable] | None:
"""Sum | list[Observable] | None: The Sum Hamiltonian or observables measured,
if any."""
return self._observables

@property
def shots_per_executable(self) -> int:
"""int: The number of shots each underlying executable was run with."""
return self._shots_per_executable

@property
def additional_metadata(self) -> AdditionalMetadata | None:
"""AdditionalMetadata | None: Additional metadata about this program.

For entries produced by ``ProgramSetQuantumTaskResult.merge``, this will be ``None``;
Use the original per-task results for true per-program metadata.
"""
if self._was_merged:
warnings.warn(
"additional_metadata for a CompositeEntry on a merged "
"ProgramSetQuantumTaskResult is None; "
"use the original per-task results for true per-program metadata.",
stacklevel=2,
)
return self._additional_metadata

@staticmethod
def _from_object(
Expand Down Expand Up @@ -192,11 +251,6 @@ def _from_object(
additional_metadata=program_result.additionalMetadata,
)

def __post_init__(self):
self._expectations = (
self._compute_expectations() if isinstance(self.observables, Sum) else None
)

def __len__(self):
return len(self.entries)

Expand Down Expand Up @@ -310,7 +364,6 @@ def _dispatch_executable_result(
)


@dataclass
class ProgramSetQuantumTaskResult:
"""The result of a program set task.

Expand All @@ -323,10 +376,45 @@ class ProgramSetQuantumTaskResult:
can be automatically computed.
"""

entries: list[CompositeEntry]
task_metadata: ProgramSetTaskMetadata
num_executables: int
program_set: ProgramSet | None
def __init__(
self,
entries: list[CompositeEntry],
task_metadata: ProgramSetTaskMetadata,
num_executables: int,
program_set: ProgramSet | None,
):
self._entries = entries
self._task_metadata = task_metadata
self._num_executables = num_executables
self._program_set = program_set
self._was_merged = False

@property
def entries(self) -> list[CompositeEntry]:
"""list[CompositeEntry]: The results of each program in this program set."""
return self._entries

@property
def task_metadata(self) -> ProgramSetTaskMetadata:
"""ProgramSetTaskMetadata: The metadata of the task."""
if self._was_merged:
warnings.warn(
"task_metadata for a merged ProgramSetQuantumTaskResult is synthesized "
"from multiple underlying tasks; it does not reflect any one underlying task. "
"Use the original per-task results for true task metadata.",
stacklevel=2,
)
return self._task_metadata

@property
def num_executables(self) -> int:
"""int: The total number of executables in this program set task."""
return self._num_executables

@property
def program_set(self) -> ProgramSet | None:
"""ProgramSet | None: The program set that was run, if provided to the constructor."""
return self._program_set

@staticmethod
def from_object(
Expand Down Expand Up @@ -370,6 +458,117 @@ def from_object(
program_set=program_set,
)

@staticmethod
def merge(
results: Sequence[ProgramSetQuantumTaskResult],
program_set: ProgramSet,
index_map: list[list[int]],
) -> ProgramSetQuantumTaskResult:
"""Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running
each program set of ``program_set.split(...)``.

``index_map`` is the per-executable map returned alongside the program sets by
``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``,
of the executable that the jth executable of the kth task represents. The kth task's
executables are read in order for its program set, namely across ``results[k].entries``,
and within each ``CompositeEntry`` across its ``entries``.

The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had
been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``,
and ``MeasuredEntry`` objects in the order of the program.

Expectation values and ``Sum`` Hamiltonian expectations are computed
for the original ``ProgramSet``.

Args:
results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same
order as ``program_set.split``'s return.
program_set (ProgramSet): The original unsplit program set.
index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``.

Returns:
ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``.

Raises:
ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map``
doesn't match ``program_set.total_executables``, or if any task produces a
different number of executables than its map expects.
"""
if len(results) != len(index_map):
raise ValueError(
f"Got {len(results)} task results but {len(index_map)} entries in index_map"
)
total_executables = program_set.total_executables
total_mapped = sum(len(m) for m in index_map)
if total_mapped != total_executables:
raise ValueError(
f"Index map covers {total_mapped} executables but the original program set "
f"has {total_executables}"
)

programs = [_binding_to_program(binding) for binding in program_set.entries]
executable_indices = list(program_set.enumerate_executables())
shots_per_executable = program_set.shots_per_executable

buffer = [None] * total_executables
for k, result in enumerate(results):
_buffer_result(
k=k,
result=result,
parent_indices=index_map[k],
program_set=program_set,
programs=programs,
executable_indices=executable_indices,
buffer=buffer,
)

entries = []
start = 0
for binding_idx, binding in enumerate(program_set.entries):
count = _count_executables(binding)
program = programs[binding_idx]
observables = binding.observables if isinstance(binding, CircuitBinding) else None
entry = CompositeEntry(
entries=buffer[start : start + count],
program=program,
inputs=CompositeEntry._get_inputs(program, observables),
observables=observables,
shots_per_executable=shots_per_executable,
additional_metadata=None,
)
entry._was_merged = True
entries.append(entry)
start += count

metas = [r._task_metadata for r in results]
merged = ProgramSetQuantumTaskResult(
entries=entries,
task_metadata=ProgramSetTaskMetadata(
id=";".join(meta.id for meta in metas), # Better way to do this?
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is a way to denote that this is just synthetic metadata, or an aggregate, and these are just convenience attributes. Unfortunately the best way for that would probably be a new schema, which seems excessive.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or if task_metadata could have a union type with Sequence[ProgramSetTaskMetadata]

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add a flag that triggers a warning the first time the metadata is accessed.

deviceId=metas[0].deviceId,
requestedShots=sum(m.requestedShots for m in metas),
successfulShots=sum(m.successfulShots for m in metas),
programMetadata=[
ProgramMetadata(
executables=[
ProgramSetExecutableResultMetadata()
for _ in range(_count_executables(b))
]
)
for b in program_set.entries
],
deviceParameters=None, # TODO: find a way to fill this in
createdAt=min(m.createdAt for m in metas if m.createdAt),
endedAt=max(m.endedAt for m in metas if m.endedAt),
status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED",
totalFailedExecutables=sum(m.totalFailedExecutables for m in metas),
),
num_executables=total_executables,
program_set=program_set,
)
merged._was_merged = True
return merged

def __len__(self):
return len(self.entries)

Expand Down Expand Up @@ -481,6 +680,85 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int:
return counter


def _binding_to_program(binding: CircuitBinding | Circuit) -> Program:
if isinstance(binding, Circuit):
return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None)
return binding.to_ir()


def _count_executables(binding: CircuitBinding | Circuit) -> int:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surprised CircuitBinding doesn't actually have a simple .executables property here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly because there isn't really a class that encapsulates a single executable, so such a method would just return a list of tuples

if isinstance(binding, Circuit):
return 1
num_ps = len(binding.input_sets) if binding.input_sets is not None else 1
num_obs = len(binding.observables) if binding.observables is not None else 1
return num_ps * num_obs


def _buffer_result(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a description to this function? The other ones are straightforward but this has more complexity and _buffer_result is a bit vague.

k: int,
result: ProgramSetQuantumTaskResult,
parent_indices: list[int],
program_set: ProgramSet,
programs: list[Program],
executable_indices: list[tuple[int, int, int]],
buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None],
) -> None:
j = 0
for composite in result.entries:
for entry in composite.entries:
if j >= len(parent_indices):
raise ValueError(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this realistically happen? I guess if you have two programsets, where each one has the same total number of different numbers in the constituents?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if you pass in the wrong merge list; this is just defensive, like the if j != len(parent_indices) below.

f"t=Task {result.task_metadata.id} at index {k} "
"produced more executables than index map expects"
)
orig_idx = parent_indices[j]
binding_idx, ps_idx, obs_idx = executable_indices[orig_idx]
buffer[orig_idx] = _convert_measured_entry(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So map_k is some contiguous list between [0] and the composite entry number of executables, right? j is essentially a subindex, or the rank-2 index. So why is the buffer at along this index when the original buffer is len(total_executables), i.e. unfolded?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, map_k's sublist actually starts where the previous sublist ended; it's the index in the parent program set. I'll rename accordingly.

entry,
program_set.entries[binding_idx],
programs[binding_idx],
ps_idx,
obs_idx,
)
j += 1
if j != len(parent_indices):
raise ValueError(
f"Task {result.task_metadata.id} at index {k} produced {j} executables "
f"but index map expected {len(parent_indices)}"
)


def _convert_measured_entry(
entry: MeasuredEntry | ProgramSetExecutableFailure,
original_binding: CircuitBinding | Circuit,
original_program: Program,
parameter_set_index: int,
observable_index: int,
) -> MeasuredEntry | ProgramSetExecutableFailure:
if isinstance(entry, ProgramSetExecutableFailure):
return entry
if isinstance(original_binding, Circuit):
return replace(entry, program=original_program.source, inputs=None, observable=None)
observables = original_binding.observables
if observables is None:
observable: Observable | None = None
num_obs = 1
elif isinstance(observables, Sum):
observable = observables.summands[observable_index]
num_obs = len(observables.summands)
else:
observable = observables[observable_index]
num_obs = len(observables)
orig_inputs_index = parameter_set_index * num_obs + observable_index
program_inputs = original_program.inputs or {}
return replace(
entry,
program=original_program.source,
inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None,
observable=observable,
)


def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str:
"""Retrieve the S3 object body.

Expand Down
Loading
Loading