-
Notifications
You must be signed in to change notification settings - Fork 167
feat: Merge program set task results #1254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: split
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| import warnings | ||
| from collections import Counter | ||
| from collections.abc import Sequence | ||
| from dataclasses import dataclass | ||
| from dataclasses import dataclass, replace | ||
|
|
||
| import boto3 | ||
| import numpy as np | ||
|
|
@@ -31,10 +31,15 @@ | |
| ProgramSetTaskMetadata, | ||
| ProgramSetTaskResult, | ||
| ) | ||
| from braket.task_result.program_set_executable_result_v1 import ( | ||
| ProgramSetExecutableResultMetadata, | ||
| ) | ||
| from braket.task_result.program_set_task_metadata_v1 import ProgramMetadata | ||
|
|
||
| from braket.circuits import Observable | ||
| from braket.circuits import Circuit, Observable | ||
| from braket.circuits.observable import EULER_OBSERVABLE_PREFIX | ||
| from braket.circuits.observables import Sum | ||
| from braket.circuits.serialization import IRType | ||
| from braket.program_sets import CircuitBinding, ParameterSets, ProgramSet | ||
| from braket.tasks.measurement_utils import ( | ||
| expectation_from_measurements, | ||
|
|
@@ -370,6 +375,116 @@ def from_object( | |
| program_set=program_set, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def merge( | ||
| results: Sequence[ProgramSetQuantumTaskResult], | ||
| program_set: ProgramSet, | ||
| index_map: list[list[int]], | ||
| ) -> ProgramSetQuantumTaskResult: | ||
| """Reconstruct a ``ProgramSetQuantumTaskResult`` from the task results produced by running | ||
| each program set of ``program_set.split(...)``. | ||
|
|
||
| ``index_map`` is the per-executable map returned alongside the program sets by | ||
| ``ProgramSet.split``: ``index_map[k][j]`` gives the index, in the order of ``program_set``, | ||
| of the executable that the jth executable of the kth task represents. The kth task's | ||
| executables are read in order for its program set, namely across ``results[k].entries``, | ||
| and within each ``CompositeEntry`` across its ``entries``. | ||
|
|
||
| The returned ``ProgramSetQuantumTaskResult`` has the same shape as if ``program_set`` had | ||
| been run unsplit, namely one ``CompositeEntry`` per entry of ``program_set.entries``, | ||
| and ``MeasuredEntry`` objects in the order of the program. | ||
|
|
||
| Expectation values and ``Sum`` Hamiltonian expectations are computed | ||
| for the original ``ProgramSet``. | ||
|
|
||
| Args: | ||
| results (Sequence[ProgramSetQuantumTaskResult]): The result of each task, in the same | ||
| order as ``program_set.split``'s return. | ||
| program_set (ProgramSet): The original unsplit program set. | ||
| index_map (list[list[int]]): The per-executable map from ``ProgramSet.split``. | ||
|
|
||
| Returns: | ||
| ProgramSetQuantumTaskResult: A result matching the shape of ``program_set``. | ||
|
|
||
| Raises: | ||
| ValueError: If ``len(results) != len(index_map)``, if the total size of ``index_map`` | ||
| doesn't match ``program_set.total_executables``, or if any task produces a | ||
| different number of executables than its map expects. | ||
| """ | ||
| if len(results) != len(index_map): | ||
| raise ValueError( | ||
| f"Got {len(results)} task results but {len(index_map)} entries in index_map" | ||
| ) | ||
| total_executables = program_set.total_executables | ||
| total_mapped = sum(len(m) for m in index_map) | ||
| if total_mapped != total_executables: | ||
| raise ValueError( | ||
| f"Index map covers {total_mapped} executables but the original program set " | ||
| f"has {total_executables}" | ||
| ) | ||
|
|
||
| programs = [_binding_to_program(binding) for binding in program_set.entries] | ||
| executable_indices = list(program_set.enumerate_executables()) | ||
| binding_executable_counts = [_count_executables(b) for b in program_set.entries] | ||
| shots_per_executable = results[0].entries[0].shots_per_executable | ||
|
|
||
| buffer = [None] * total_executables | ||
| for k, result in enumerate(results): | ||
| _buffer_result( | ||
| k=k, | ||
| result=result, | ||
| map_k=index_map[k], | ||
| program_set=program_set, | ||
| programs=programs, | ||
| executable_indices=executable_indices, | ||
| buffer=buffer, | ||
| ) | ||
|
|
||
| entries = [] | ||
| start = 0 | ||
| for binding_idx, binding in enumerate(program_set.entries): | ||
| count = binding_executable_counts[binding_idx] | ||
|
speller26 marked this conversation as resolved.
Outdated
|
||
| program = programs[binding_idx] | ||
| observables = binding.observables if isinstance(binding, CircuitBinding) else None | ||
| entries.append( | ||
| CompositeEntry( | ||
| entries=buffer[start : start + count], | ||
| program=program, | ||
| inputs=CompositeEntry._get_inputs(program, observables), | ||
| observables=observables, | ||
| shots_per_executable=shots_per_executable, | ||
| additional_metadata=None, | ||
| ) | ||
| ) | ||
| start += count | ||
|
|
||
| metas = [r.task_metadata for r in results] | ||
| return ProgramSetQuantumTaskResult( | ||
| entries=entries, | ||
| task_metadata=ProgramSetTaskMetadata( | ||
| id=";".join(meta.id for meta in metas), # Better way to do this? | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if there is a way to denote that this is just synthetic metadata, or an aggregate, and these are just convenience attributes. Unfortunately the best way for that would probably be a new schema, which seems excessive.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or if task_metadata could have a union type with
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can add a flag that triggers a warning the first time the metadata is accessed. |
||
| deviceId=metas[0].deviceId, | ||
| requestedShots=sum(m.requestedShots for m in metas), | ||
| successfulShots=sum(m.successfulShots for m in metas), | ||
| programMetadata=[ | ||
| ProgramMetadata( | ||
| executables=[ | ||
| ProgramSetExecutableResultMetadata() | ||
| for _ in range(_count_executables(b)) | ||
| ] | ||
| ) | ||
| for b in program_set.entries | ||
| ], | ||
| deviceParameters=None, # TODO: find a way to fill this in | ||
| createdAt=min(m.createdAt for m in metas if m.createdAt), | ||
| endedAt=max(m.endedAt for m in metas if m.endedAt), | ||
| status="COMPLETED" if any(m.status == "COMPLETED" for m in metas) else "FAILED", | ||
| totalFailedExecutables=sum(m.totalFailedExecutables for m in metas), | ||
| ), | ||
| num_executables=total_executables, | ||
| program_set=program_set, | ||
| ) | ||
|
|
||
| def __len__(self): | ||
| return len(self.entries) | ||
|
|
||
|
|
@@ -481,6 +596,85 @@ def _compute_num_executables(metadata: ProgramSetTaskMetadata) -> int: | |
| return counter | ||
|
|
||
|
|
||
| def _binding_to_program(binding: CircuitBinding | Circuit) -> Program: | ||
| if isinstance(binding, Circuit): | ||
| return Program(source=binding.to_ir(IRType.OPENQASM).source, inputs=None) | ||
| return binding.to_ir() | ||
|
|
||
|
|
||
| def _count_executables(binding: CircuitBinding | Circuit) -> int: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Surprised CircuitBinding doesn't actually have a simple .executables property here.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mainly because there isn't really a class that encapsulates a single executable, so such a method would just return a list of tuples |
||
| if isinstance(binding, Circuit): | ||
| return 1 | ||
| num_ps = len(binding.input_sets) if binding.input_sets is not None else 1 | ||
| num_obs = len(binding.observables) if binding.observables is not None else 1 | ||
| return num_ps * num_obs | ||
|
|
||
|
|
||
| def _buffer_result( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a description to this function? The other ones are straightforward but this has more complexity and |
||
| k: int, | ||
| result: ProgramSetQuantumTaskResult, | ||
| map_k: list[int], | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is like the composite_entry_map?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Map to the parent program set index |
||
| program_set: ProgramSet, | ||
| programs: list[Program], | ||
| executable_indices: list[tuple[int, int, int]], | ||
| buffer: list[MeasuredEntry | ProgramSetExecutableFailure | None], | ||
| ) -> None: | ||
| j = 0 | ||
| for composite in result.entries: | ||
| for entry in composite.entries: | ||
| if j >= len(map_k): | ||
| raise ValueError( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this realistically happen? I guess if you have two programsets, where each one has the same total number of different numbers in the constituents?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, if you pass in the wrong merge list; this is just defensive, like the |
||
| f"t=Task {result.task_metadata.id} at index {k} " | ||
| "produced more executables than index map expects" | ||
| ) | ||
| orig_idx = map_k[j] | ||
| binding_idx, ps_idx, obs_idx = executable_indices[orig_idx] | ||
| buffer[orig_idx] = _convert_measured_entry( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, |
||
| entry, | ||
| program_set.entries[binding_idx], | ||
| programs[binding_idx], | ||
| ps_idx, | ||
| obs_idx, | ||
| ) | ||
| j += 1 | ||
| if j != len(map_k): | ||
| raise ValueError( | ||
| f"Task {result.task_metadata.id} at index {k} produced {j} executables " | ||
| f"but index map expected {len(map_k)}" | ||
| ) | ||
|
|
||
|
|
||
| def _convert_measured_entry( | ||
| entry: MeasuredEntry | ProgramSetExecutableFailure, | ||
| original_binding: CircuitBinding | Circuit, | ||
| original_program: Program, | ||
| parameter_set_index: int, | ||
| observable_index: int, | ||
| ) -> MeasuredEntry | ProgramSetExecutableFailure: | ||
| if isinstance(entry, ProgramSetExecutableFailure): | ||
| return entry | ||
| if isinstance(original_binding, Circuit): | ||
| return replace(entry, program=original_program.source, inputs=None, observable=None) | ||
| observables = original_binding.observables | ||
| if observables is None: | ||
| observable: Observable | None = None | ||
| num_obs = 1 | ||
| elif isinstance(observables, Sum): | ||
| observable = observables.summands[observable_index] | ||
| num_obs = len(observables.summands) | ||
| else: | ||
| observable = observables[observable_index] | ||
| num_obs = len(observables) | ||
| orig_inputs_index = parameter_set_index * num_obs + observable_index | ||
| program_inputs = original_program.inputs or {} | ||
| return replace( | ||
| entry, | ||
| program=original_program.source, | ||
| inputs={key: value[orig_inputs_index] for key, value in program_inputs.items()} or None, | ||
| observable=observable, | ||
| ) | ||
|
|
||
|
|
||
| def _retrieve_s3_object_body(s3_bucket: str, s3_object_key: str, s3_client: BaseClient) -> str: | ||
| """Retrieve the S3 object body. | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.