Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 196 additions & 2 deletions src/braket/tasks/program_set_quantum_task_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import warnings
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace

import boto3
import numpy as np
Expand All @@ -31,10 +31,15 @@
ProgramSetTaskMetadata,
ProgramSetTaskResult,
)
from braket.task_result.program_set_executable_result_v1 import (
ProgramSetExecutableResultMetadata,
)
from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata

from braket.circuits import Observable
from braket.circuits import Circuit, Observable
from braket.circuits.observable import EULER_OBSERVABLE_PREFIX
from braket.circuits.observables import Sum
from braket.circuits.serialization import IRType
from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet
from braket.tasks.measurement_utils import (
expectation_from_measurements,
Expand Down Expand Up @@ -370,6 +375,116 @@ def from_object(
program_set=program_set,
)

@staticmethod
def merge(
results: Sequence[ProgramSetQuantumTaskResult],
program_set: ProgramSet,
index_map: list[list[int]],
) -> ProgramSetQuantumTaskResult:
"""Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running
each program set of ``program_set.split(...)``.

``index_map`` is the per-executable map returned alongside the program sets by
``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``,
of the executable that the jth executable of the kth task represents. The kth task's
executables are read in order for its program set, namely across ``results[k].entries``,
and within each ``CompositeEntry`` across its ``entries``.

The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had
been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``,
and ``MeasuredEntry`` objects in the order of the program.

Expectation values and ``Sum`` Hamiltonian expectations are computed
for the original ``ProgramSet``.

Args:
results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same
order as ``program_set.split``'s return.
program_set (ProgramSet): The original unsplit program set.
index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``.

Returns:
ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``.

Raises:
ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map``
doesn't match ``program_set.total_executables``, or if any task produces a
different number of executables than its map expects.
"""
if len(results) != len(index_map):
raise ValueError(
f"Got {len(results)} task results but {len(index_map)} entries in index_map"
)
total_executables = program_set.total_executables
total_mapped = sum(len(m) for m in index_map)
if total_mapped != total_executables:
raise ValueError(
f"Index map covers {total_mapped} executables but the original program set "
f"has {total_executables}"
)

programs = [_binding_to_program(binding) for binding in program_set.entries]
executable_indices = list(program_set.enumerate_executables())
binding_executable_counts = [_count_executables(b) for b in program_set.entries]
shots_per_executable = results[0].entries[0].shots_per_executable
Comment thread
speller26 marked this conversation as resolved.
Outdated

buffer = [None] * total_executables
for k, result in enumerate(results):
_buffer_result(
k=k,
result=result,
map_k=index_map[k],
program_set=program_set,
programs=programs,
executable_indices=executable_indices,
buffer=buffer,
)

entries = []
start = 0
for binding_idx, binding in enumerate(program_set.entries):
count = binding_executable_counts[binding_idx]
Comment thread
speller26 marked this conversation as resolved.
Outdated
program = programs[binding_idx]
observables = binding.observables if isinstance(binding, CircuitBinding) else None
entries.append(
CompositeEntry(
entries=buffer[start : start + count],
program=program,
inputs=CompositeEntry._get_inputs(program, observables),
observables=observables,
shots_per_executable=shots_per_executable,
additional_metadata=None,
)
)
start += count

metas = [r.task_metadata for r in results]
return ProgramSetQuantumTaskResult(
entries=entries,
task_metadata=ProgramSetTaskMetadata(
id=";".join(meta.id for meta in metas), # Better way to do this?
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is a way to denote that this is just synthetic metadata, or an aggregate, and these are just convenience attributes. Unfortunately the best way for that would probably be a new schema, which seems excessive.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or if task_metadata could have a union type with Sequence[ProgramSetTaskMetadata]

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add a flag that triggers a warning the first time the metadata is accessed.

deviceId=metas[0].deviceId,
requestedShots=sum(m.requestedShots for m in metas),
successfulShots=sum(m.successfulShots for m in metas),
programMetadata=[
ProgramMetadata(
executables=[
ProgramSetExecutableResultMetadata()
for _ in range(_count_executables(b))
]
)
for b in program_set.entries
],
deviceParameters=None, # TODO: find a way to fill this in
createdAt=min(m.createdAt for m in metas if m.createdAt),
endedAt=max(m.endedAt for m in metas if m.endedAt),
status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED",
totalFailedExecutables=sum(m.totalFailedExecutables for m in metas),
),
num_executables=total_executables,
program_set=program_set,
)

def __len__(self):
return len(self.entries)

Expand Down Expand Up @@ -481,6 +596,85 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int:
return counter


def _binding_to_program(binding: CircuitBinding | Circuit) -> Program:
if isinstance(binding, Circuit):
return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None)
return binding.to_ir()


def _count_executables(binding: CircuitBinding | Circuit) -> int:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surprised CircuitBinding doesn't actually have a simple .executables property here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly because there isn't really a class that encapsulates a single executable, so such a method would just return a list of tuples

if isinstance(binding, Circuit):
return 1
num_ps = len(binding.input_sets) if binding.input_sets is not None else 1
num_obs = len(binding.observables) if binding.observables is not None else 1
return num_ps * num_obs


def _buffer_result(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a description to this function? The other ones are straightforward but this has more complexity and _buffer_result is a bit vague.

k: int,
result: ProgramSetQuantumTaskResult,
map_k: list[int],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is like the composite_entry_map?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Map to the parent program set index

program_set: ProgramSet,
programs: list[Program],
executable_indices: list[tuple[int, int, int]],
buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None],
) -> None:
j = 0
for composite in result.entries:
for entry in composite.entries:
if j >= len(map_k):
raise ValueError(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this realistically happen? I guess if you have two programsets, where each one has the same total number of different numbers in the constituents?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if you pass in the wrong merge list; this is just defensive, like the if j != len(parent_indices) below.

f"t=Task {result.task_metadata.id} at index {k} "
"produced more executables than index map expects"
)
orig_idx = map_k[j]
binding_idx, ps_idx, obs_idx = executable_indices[orig_idx]
buffer[orig_idx] = _convert_measured_entry(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So map_k is some contiguous list between [0] and the composite entry number of executables, right? j is essentially a subindex, or the rank-2 index. So why is the buffer at along this index when the original buffer is len(total_executables), i.e. unfolded?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, map_k's sublist actually starts where the previous sublist ended; it's the index in the parent program set. I'll rename accordingly.

entry,
program_set.entries[binding_idx],
programs[binding_idx],
ps_idx,
obs_idx,
)
j += 1
if j != len(map_k):
raise ValueError(
f"Task {result.task_metadata.id} at index {k} produced {j} executables "
f"but index map expected {len(map_k)}"
)


def _convert_measured_entry(
entry: MeasuredEntry | ProgramSetExecutableFailure,
original_binding: CircuitBinding | Circuit,
original_program: Program,
parameter_set_index: int,
observable_index: int,
) -> MeasuredEntry | ProgramSetExecutableFailure:
if isinstance(entry, ProgramSetExecutableFailure):
return entry
if isinstance(original_binding, Circuit):
return replace(entry, program=original_program.source, inputs=None, observable=None)
observables = original_binding.observables
if observables is None:
observable: Observable | None = None
num_obs = 1
elif isinstance(observables, Sum):
observable = observables.summands[observable_index]
num_obs = len(observables.summands)
else:
observable = observables[observable_index]
num_obs = len(observables)
orig_inputs_index = parameter_set_index * num_obs + observable_index
program_inputs = original_program.inputs or {}
return replace(
entry,
program=original_program.source,
inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None,
observable=observable,
)


def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str:
"""Retrieve the S3 object body.

Expand Down
Loading
Loading