-
Notifications
You must be signed in to change notification settings - Fork 167
feat: Merge program set task results #1254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: split
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| import warnings | ||
| from collections import Counter | ||
| from collections.abc import Sequence | ||
| from dataclasses import dataclass | ||
| from dataclasses import dataclass, replace | ||
|
|
||
| import boto3 | ||
| import numpy as np | ||
|
|
@@ -31,10 +31,15 @@ | |
| ProgramSetTaskMetadata, | ||
| ProgramSetTaskResult, | ||
| ) | ||
| from braket.task_result.program_set_executable_result_v1 import ( | ||
| ProgramSetExecutableResultMetadata, | ||
| ) | ||
| from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata | ||
|
|
||
| from braket.circuits import Observable | ||
| from braket.circuits import Circuit, Observable | ||
| from braket.circuits.observable import EULER_OBSERVABLE_PREFIX | ||
| from braket.circuits.observables import Sum | ||
| from braket.circuits.serialization import IRType | ||
| from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet | ||
| from braket.tasks.measurement_utils import ( | ||
| expectation_from_measurements, | ||
|
|
@@ -141,7 +146,6 @@ def expectation(self) -> float | None: | |
| return self._expectation | ||
|
|
||
|
|
||
| @dataclass | ||
| class CompositeEntry: | ||
| """Results of a program in a program set | ||
|
|
||
|
|
@@ -152,15 +156,70 @@ class CompositeEntry: | |
| observables (Sum | list[Observable] | None): The Sum Hamiltonian or observables | ||
| that were measured, if any. | ||
| shots_per_executable (int): The number of shots each underlying executable was run with | ||
| additional_metadata (AdditionalMetadata): Additional metadata about this program | ||
| additional_metadata (AdditionalMetadata | None): Additional metadata about this program. | ||
| ``None`` for entries produced by ``ProgramSetQuantumTaskResult.merge``, | ||
| since per-program metadata cannot be aggregated meaningfully across underlying tasks. | ||
| """ | ||
|
|
||
| entries: list[MeasuredEntry] | ||
| program: Program | ||
| inputs: ParameterSets | ||
| observables: Sum | list[Observable] | None | ||
| shots_per_executable: int | ||
| additional_metadata: AdditionalMetadata | ||
| def __init__( | ||
| self, | ||
| entries: list[MeasuredEntry], | ||
| program: Program, | ||
| inputs: ParameterSets, | ||
| observables: Sum | list[Observable] | None, | ||
| shots_per_executable: int, | ||
| additional_metadata: AdditionalMetadata | None, | ||
| ): | ||
| self._entries = entries | ||
| self._program = program | ||
| self._inputs = inputs | ||
| self._observables = observables | ||
| self._shots_per_executable = shots_per_executable | ||
| self._additional_metadata = additional_metadata | ||
| self._was_merged = False | ||
| self._expectations = self._compute_expectations() if isinstance(observables, Sum) else None | ||
|
|
||
| @property | ||
| def entries(self) -> list[MeasuredEntry]: | ||
| """list[MeasuredEntry]: The results of each executable in this program.""" | ||
| return self._entries | ||
|
|
||
| @property | ||
| def program(self) -> Program: | ||
| """Program: The program that was run.""" | ||
| return self._program | ||
|
|
||
| @property | ||
| def inputs(self) -> ParameterSets: | ||
| """ParameterSets: The input values this program was run with.""" | ||
| return self._inputs | ||
|
|
||
| @property | ||
| def observables(self) -> Sum | list[Observable] | None: | ||
| """Sum | list[Observable] | None: The Sum Hamiltonian or observables measured, | ||
| if any.""" | ||
| return self._observables | ||
|
|
||
| @property | ||
| def shots_per_executable(self) -> int: | ||
| """int: The number of shots each underlying executable was run with.""" | ||
| return self._shots_per_executable | ||
|
|
||
| @property | ||
| def additional_metadata(self) -> AdditionalMetadata | None: | ||
| """AdditionalMetadata | None: Additional metadata about this program. | ||
|
|
||
| For entries produced by ``ProgramSetQuantumTaskResult.merge``, this will be ``None``; | ||
| Use the original per-task results for true per-program metadata. | ||
| """ | ||
| if self._was_merged: | ||
| warnings.warn( | ||
| "additional_metadata for a CompositeEntry on a merged " | ||
| "ProgramSetQuantumTaskResult is None; " | ||
| "use the original per-task results for true per-program metadata.", | ||
| stacklevel=2, | ||
| ) | ||
| return self._additional_metadata | ||
|
|
||
| @staticmethod | ||
| def _from_object( | ||
|
|
@@ -192,11 +251,6 @@ def _from_object( | |
| additional_metadata=program_result.additionalMetadata, | ||
| ) | ||
|
|
||
| def __post_init__(self): | ||
| self._expectations = ( | ||
| self._compute_expectations() if isinstance(self.observables, Sum) else None | ||
| ) | ||
|
|
||
| def __len__(self): | ||
| return len(self.entries) | ||
|
|
||
|
|
@@ -310,7 +364,6 @@ def _dispatch_executable_result( | |
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ProgramSetQuantumTaskResult: | ||
| """The result of a program set task. | ||
|
|
||
|
|
@@ -323,10 +376,45 @@ class ProgramSetQuantumTaskResult: | |
| can be automatically computed. | ||
| """ | ||
|
|
||
| entries: list[CompositeEntry] | ||
| task_metadata: ProgramSetTaskMetadata | ||
| num_executables: int | ||
| program_set: ProgramSet | None | ||
| def __init__( | ||
| self, | ||
| entries: list[CompositeEntry], | ||
| task_metadata: ProgramSetTaskMetadata, | ||
| num_executables: int, | ||
| program_set: ProgramSet | None, | ||
| ): | ||
| self._entries = entries | ||
| self._task_metadata = task_metadata | ||
| self._num_executables = num_executables | ||
| self._program_set = program_set | ||
| self._was_merged = False | ||
|
|
||
| @property | ||
| def entries(self) -> list[CompositeEntry]: | ||
| """list[CompositeEntry]: The results of each program in this program set.""" | ||
| return self._entries | ||
|
|
||
| @property | ||
| def task_metadata(self) -> ProgramSetTaskMetadata: | ||
| """ProgramSetTaskMetadata: The metadata of the task.""" | ||
| if self._was_merged: | ||
| warnings.warn( | ||
| "task_metadata for a merged ProgramSetQuantumTaskResult is synthesized " | ||
| "from multiple underlying tasks; it does not reflect any one underlying task. " | ||
| "Use the original per-task results for true task metadata.", | ||
| stacklevel=2, | ||
| ) | ||
| return self._task_metadata | ||
|
|
||
| @property | ||
| def num_executables(self) -> int: | ||
| """int: The total number of executables in this program set task.""" | ||
| return self._num_executables | ||
|
|
||
| @property | ||
| def program_set(self) -> ProgramSet | None: | ||
| """ProgramSet | None: The program set that was run, if provided to the constructor.""" | ||
| return self._program_set | ||
|
|
||
| @staticmethod | ||
| def from_object( | ||
|
|
@@ -370,6 +458,117 @@ def from_object( | |
| program_set=program_set, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def merge( | ||
| results: Sequence[ProgramSetQuantumTaskResult], | ||
| program_set: ProgramSet, | ||
| index_map: list[list[int]], | ||
| ) -> ProgramSetQuantumTaskResult: | ||
| """Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running | ||
| each program set of ``program_set.split(...)``. | ||
|
|
||
| ``index_map`` is the per-executable map returned alongside the program sets by | ||
| ``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``, | ||
| of the executable that the jth executable of the kth task represents. The kth task's | ||
| executables are read in order for its program set, namely across ``results[k].entries``, | ||
| and within each ``CompositeEntry`` across its ``entries``. | ||
|
|
||
| The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had | ||
| been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``, | ||
| and ``MeasuredEntry`` objects in the order of the program. | ||
|
|
||
| Expectation values and ``Sum`` Hamiltonian expectations are computed | ||
| for the original ``ProgramSet``. | ||
|
|
||
| Args: | ||
| results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same | ||
| order as ``program_set.split``'s return. | ||
| program_set (ProgramSet): The original unsplit program set. | ||
| index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``. | ||
|
|
||
| Returns: | ||
| ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``. | ||
|
|
||
| Raises: | ||
| ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map`` | ||
| doesn't match ``program_set.total_executables``, or if any task produces a | ||
| different number of executables than its map expects. | ||
| """ | ||
| if len(results) != len(index_map): | ||
| raise ValueError( | ||
| f"Got {len(results)} task results but {len(index_map)} entries in index_map" | ||
| ) | ||
| total_executables = program_set.total_executables | ||
| total_mapped = sum(len(m) for m in index_map) | ||
| if total_mapped != total_executables: | ||
| raise ValueError( | ||
| f"Index map covers {total_mapped} executables but the original program set " | ||
| f"has {total_executables}" | ||
| ) | ||
|
|
||
| programs = [_binding_to_program(binding) for binding in program_set.entries] | ||
| executable_indices = list(program_set.enumerate_executables()) | ||
| shots_per_executable = program_set.shots_per_executable | ||
|
|
||
| buffer = [None] * total_executables | ||
| for k, result in enumerate(results): | ||
| _buffer_result( | ||
| k=k, | ||
| result=result, | ||
| parent_indices=index_map[k], | ||
| program_set=program_set, | ||
| programs=programs, | ||
| executable_indices=executable_indices, | ||
| buffer=buffer, | ||
| ) | ||
|
|
||
| entries = [] | ||
| start = 0 | ||
| for binding_idx, binding in enumerate(program_set.entries): | ||
| count = _count_executables(binding) | ||
| program = programs[binding_idx] | ||
| observables = binding.observables if isinstance(binding, CircuitBinding) else None | ||
| entry = CompositeEntry( | ||
| entries=buffer[start : start + count], | ||
| program=program, | ||
| inputs=CompositeEntry._get_inputs(program, observables), | ||
| observables=observables, | ||
| shots_per_executable=shots_per_executable, | ||
| additional_metadata=None, | ||
| ) | ||
| entry._was_merged = True | ||
| entries.append(entry) | ||
| start += count | ||
|
|
||
| metas = [r._task_metadata for r in results] | ||
| merged = ProgramSetQuantumTaskResult( | ||
| entries=entries, | ||
| task_metadata=ProgramSetTaskMetadata( | ||
| id=";".join(meta.id for meta in metas), # Better way to do this? | ||
| deviceId=metas[0].deviceId, | ||
| requestedShots=sum(m.requestedShots for m in metas), | ||
| successfulShots=sum(m.successfulShots for m in metas), | ||
| programMetadata=[ | ||
| ProgramMetadata( | ||
| executables=[ | ||
| ProgramSetExecutableResultMetadata() | ||
| for _ in range(_count_executables(b)) | ||
| ] | ||
| ) | ||
| for b in program_set.entries | ||
| ], | ||
| deviceParameters=None, # TODO: find a way to fill this in | ||
| createdAt=min(m.createdAt for m in metas if m.createdAt), | ||
| endedAt=max(m.endedAt for m in metas if m.endedAt), | ||
| status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", | ||
| totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), | ||
| ), | ||
| num_executables=total_executables, | ||
| program_set=program_set, | ||
| ) | ||
| merged._was_merged = True | ||
| return merged | ||
|
|
||
| def __len__(self): | ||
| return len(self.entries) | ||
|
|
||
|
|
@@ -481,6 +680,85 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int: | |
| return counter | ||
|
|
||
|
|
||
| def _binding_to_program(binding: CircuitBinding | Circuit) -> Program: | ||
| if isinstance(binding, Circuit): | ||
| return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None) | ||
| return binding.to_ir() | ||
|
|
||
|
|
||
| def _count_executables(binding: CircuitBinding | Circuit) -> int: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Surprised CircuitBinding doesn't actually have a simple .executables property here.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mainly because there isn't really a class that encapsulates a single executable, so such a method would just return a list of tuples |
||
| if isinstance(binding, Circuit): | ||
| return 1 | ||
| num_ps = len(binding.input_sets) if binding.input_sets is not None else 1 | ||
| num_obs = len(binding.observables) if binding.observables is not None else 1 | ||
| return num_ps * num_obs | ||
|
|
||
|
|
||
| def _buffer_result( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a description to this function? The other ones are straightforward but this has more complexity and |
||
| k: int, | ||
| result: ProgramSetQuantumTaskResult, | ||
| parent_indices: list[int], | ||
| program_set: ProgramSet, | ||
| programs: list[Program], | ||
| executable_indices: list[tuple[int, int, int]], | ||
| buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], | ||
| ) -> None: | ||
| j = 0 | ||
| for composite in result.entries: | ||
| for entry in composite.entries: | ||
| if j >= len(parent_indices): | ||
| raise ValueError( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this realistically happen? I guess if you have two programsets, where each one has the same total number of different numbers in the constituents?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, if you pass in the wrong merge list; this is just defensive, like the |
||
| f"t=Task {result.task_metadata.id} at index {k} " | ||
| "produced more executables than index map expects" | ||
| ) | ||
| orig_idx = parent_indices[j] | ||
| binding_idx, ps_idx, obs_idx = executable_indices[orig_idx] | ||
| buffer[orig_idx] = _convert_measured_entry( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, |
||
| entry, | ||
| program_set.entries[binding_idx], | ||
| programs[binding_idx], | ||
| ps_idx, | ||
| obs_idx, | ||
| ) | ||
| j += 1 | ||
| if j != len(parent_indices): | ||
| raise ValueError( | ||
| f"Task {result.task_metadata.id} at index {k} produced {j} executables " | ||
| f"but index map expected {len(parent_indices)}" | ||
| ) | ||
|
|
||
|
|
||
| def _convert_measured_entry( | ||
| entry: MeasuredEntry | ProgramSetExecutableFailure, | ||
| original_binding: CircuitBinding | Circuit, | ||
| original_program: Program, | ||
| parameter_set_index: int, | ||
| observable_index: int, | ||
| ) -> MeasuredEntry | ProgramSetExecutableFailure: | ||
| if isinstance(entry, ProgramSetExecutableFailure): | ||
| return entry | ||
| if isinstance(original_binding, Circuit): | ||
| return replace(entry, program=original_program.source, inputs=None, observable=None) | ||
| observables = original_binding.observables | ||
| if observables is None: | ||
| observable: Observable | None = None | ||
| num_obs = 1 | ||
| elif isinstance(observables, Sum): | ||
| observable = observables.summands[observable_index] | ||
| num_obs = len(observables.summands) | ||
| else: | ||
| observable = observables[observable_index] | ||
| num_obs = len(observables) | ||
| orig_inputs_index = parameter_set_index * num_obs + observable_index | ||
| program_inputs = original_program.inputs or {} | ||
| return replace( | ||
| entry, | ||
| program=original_program.source, | ||
| inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None, | ||
| observable=observable, | ||
| ) | ||
|
|
||
|
|
||
| def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str: | ||
| """Retrieve the S3 object body. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if there is a way to denote that this is just synthetic metadata, or an aggregate, and these are just convenience attributes. Unfortunately the best way for that would probably be a new schema, which seems excessive.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or if task_metadata could have a union type with
Sequence[ProgramSetTaskMetadata]There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add a flag that triggers a warning the first time the metadata is accessed.