Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions src/braket/program_sets/program_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,114 @@ def total_shots(self) -> int:
raise ValueError("No per-executable shots defined")
return self._shots_per_executable * self.total_executables

def split(self, max_executables: int) -> list[ProgramSet]:
"""
Split this program set into a list of program sets with
at most ``max_executables`` executables.

Sum Hamiltonians and lists of observables will not be broken into separate program sets;
consequently, this method will fail if the size of any Hamiltonian or observable list
Comment thread
speller26 marked this conversation as resolved.
Outdated
exceeds ``max_executables``.

Adjacent triples originating from the same ``CircuitBinding`` are coalesced into
Comment thread
speller26 marked this conversation as resolved.
Outdated
a single multi-parameter-set ``CircuitBinding`` in the resulting sub-program set.

Concatenating the executables of the program sets in the list in order reproduces
Comment thread
speller26 marked this conversation as resolved.
Outdated
the executables of the original program set.

Args:
max_executables (int): The maximum number of executables allowed per
sub-program set. Must be positive.

Returns:
list[ProgramSet]: The sub-program sets. If this program set already fits
within ``max_executables``, a single-element list containing ``self`` is
returned.

Raises:
ValueError: If ``max_executables`` is not positive, or if a single triple
(one parameter-set index of a single ``CircuitBinding``) requires
more than ``max_executables`` executables, because its observable list or
``Sum`` Hamiltonian is larger than allowed.

Examples:
>>> ps = ProgramSet([
... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables
... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables
... ])
>>> sub = ps.split(120)
>>> [s.total_executables for s in sub]
[120, 120, 120, 80, 60]
Comment thread
speller26 marked this conversation as resolved.
Outdated
"""
if max_executables <= 0:
raise ValueError(f"max_executables must be positive, got {max_executables}")

if self.total_executables <= max_executables:
return [self]

program_sets = []
current = []
current_size = 0
for triple in self._enumerate_triples(max_executables):
size = triple[2]
if current and current_size + size > max_executables:
program_sets.append(self._build_sub_program_set(current))
current = []
current_size = 0
current.append(triple)
current_size += size
program_sets.append(self._build_sub_program_set(current))

return program_sets

def _enumerate_triples(self, max_executables: int) -> list[tuple[int, int | None, int]]:
triples = []
for prog_idx, prog in enumerate(self._programs):
if isinstance(prog, Circuit):
triples.append((prog_idx, None, 1))
continue
obs = prog.observables
class_size = max(1, len(obs)) if obs is not None else 1
Comment thread
speller26 marked this conversation as resolved.
Outdated
if class_size > max_executables:
raise ValueError(
f"Program at index {prog_idx} has a single parameter-set index with "
f"{class_size} executables, exceeding max_executables={max_executables}"
)
input_sets = prog.input_sets
if input_sets is None:
triples.append((prog_idx, None, class_size))
else:
triples.extend((prog_idx, i, class_size) for i in range(len(input_sets)))
Comment thread
speller26 marked this conversation as resolved.
Outdated
return triples

def _build_sub_program_set(self, triples: list[tuple[int, int | None, int]]) -> ProgramSet:
entries = []
i = 0
while i < len(triples):
prog_idx, param_idx, _ = triples[i]
prog = self._programs[prog_idx]
if param_idx is None:
entries.append(prog)
i += 1
continue
j = i
while (
j + 1 < len(triples)
and triples[j + 1][0] == prog_idx
and triples[j + 1][1] == triples[j][1] + 1
):
j += 1
start, stop = triples[i][1], triples[j][1] + 1
entries.append(
CircuitBinding(
prog.circuit,
input_sets=prog.input_sets.as_list()[start:stop],
observables=prog.observables,
)
)
i = j + 1
return ProgramSet(entries, self._shots_per_executable)

@staticmethod
def zip(
circuits: Sequence[Circuit] | CircuitBinding,
Expand Down
167 changes: 167 additions & 0 deletions test/unit_tests/braket/program_sets/test_program_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,170 @@ def test_inequality(circuit_rx_parametrized):
program_set = ProgramSet([binding, binding])
assert program_set != ProgramSet([binding, circuit_rx_parametrized])
assert program_set != circuit_rx_parametrized


def test_split_already_fits(circuit_rx_parametrized):
binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}])
program_set = ProgramSet(binding)
sub = program_set.split(10)
assert sub == [program_set]
assert sub[0] is program_set


def test_split_exact_fit(circuit_rx_parametrized):
binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}])
program_set = ProgramSet(binding)
sub = program_set.split(2)
assert sub == [program_set]
assert sub[0] is program_set


def test_split_plain_circuits():
circs = [ghz(1), ghz(2), ghz(3), ghz(1), ghz(2)]
program_set = ProgramSet(circs, shots_per_executable=10)
sub = program_set.split(2)
assert [s.total_executables for s in sub] == [2, 2, 1]
assert sub[0].entries == circs[0:2]
assert sub[1].entries == circs[2:4]
assert sub[2].entries == circs[4:5]


def test_split_single_binding_packed(circuit_rx_parametrized):
inputs = {"theta": [float(i) for i in range(10)]}
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs)
program_set = ProgramSet(binding)
sub = program_set.split(3)
assert [s.total_executables for s in sub] == [3, 3, 3, 1]
# Each sub-program-set is a single coalesced binding over a contiguous slice.
for s in sub:
assert len(s) == 1
assert s.entries[0].circuit == circuit_rx_parametrized
assert s.entries[0].observables is None
thetas = []
for s in sub:
thetas.extend(s.entries[0].input_sets.as_dict()["theta"])
assert thetas == inputs["theta"]


def test_split_with_observables(circuit_rx_parametrized):
# 5 parameter-set indices, 4 observables => 5 classes of size 4.
inputs = {"theta": [float(i) for i in range(5)]}
observables = [X(0), Y(0), Z(0), X(0) @ Y(1)]
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables)
program_set = ProgramSet(binding)
sub = program_set.split(8)
assert [s.total_executables for s in sub] == [8, 8, 4]
# Observables propagate unchanged (never split across sub-program-sets).
for s in sub:
assert s.entries[0].observables == observables


def test_split_with_sum_hamiltonian(circuit_rx_parametrized):
# Sum with 3 summands => class size = 3 per parameter-set index.
inputs = {"theta": [float(i) for i in range(4)]}
hamiltonian = 1.0 * X(0) + 2.0 * Y(0) + 3.0 * Z(0)
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=hamiltonian)
program_set = ProgramSet(binding)
sub = program_set.split(6)
assert [s.total_executables for s in sub] == [6, 6]
# Sum preserved intact.
for s in sub:
assert s.entries[0].observables is hamiltonian


def test_split_worked_example(circuit_rx_parametrized):
# Two bindings: c1 with 100 param sets × 4 obs, c2 with 50 param sets × 2 obs.
c1 = circuit_rx_parametrized
c2 = Circuit().rx(0, FreeParameter("phi"))
obs1 = [X(0), Y(0), Z(0), X(0) @ Y(1)]
obs2 = [X(0), Z(0)]
binding1 = CircuitBinding(c1, {"theta": [float(i) for i in range(100)]}, obs1)
binding2 = CircuitBinding(c2, {"phi": [float(i) for i in range(50)]}, obs2)
program_set = ProgramSet([binding1, binding2])

sub = program_set.split(120)
# Greedy packing fills each bucket up to the budget before flushing.
assert [s.total_executables for s in sub] == [120, 120, 120, 120, 20]
assert sum(s.total_executables for s in sub) == program_set.total_executables
# First three buckets are pure c1 (30 × 4 each).
for i in range(3):
assert len(sub[i]) == 1
assert sub[i].entries[0].circuit == c1
assert len(sub[i].entries[0].input_sets) == 30
# Bucket 3 straddles both bindings (10 × 4 + 40 × 2 = 120); coalesced per binding.
assert len(sub[3]) == 2
assert sub[3].entries[0].circuit == c1
assert len(sub[3].entries[0].input_sets) == 10
assert sub[3].entries[1].circuit == c2
assert len(sub[3].entries[1].input_sets) == 40
# Last bucket is pure c2 remainder (10 × 2 = 20).
assert len(sub[4]) == 1
assert sub[4].entries[0].circuit == c2
assert len(sub[4].entries[0].input_sets) == 10


def test_split_preserves_shots(circuit_rx_parametrized):
inputs = {"theta": [float(i) for i in range(5)]}
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs)
program_set = ProgramSet(binding, shots_per_executable=100)
sub = program_set.split(2)
assert all(s.shots_per_executable == 100 for s in sub)
assert sum(s.total_shots for s in sub) == program_set.total_shots


def test_split_coalesces_adjacent_same_binding(circuit_rx_parametrized):
# 6 parameter-set indices, class size 1, max_executables=4 => buckets of 4, 2.
# Each bucket should contain one coalesced multi-parameter-set binding,
# not four (resp. two) separate single-parameter-set bindings.
inputs = {"theta": [float(i) for i in range(6)]}
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs)
program_set = ProgramSet(binding)
sub = program_set.split(4)
assert [len(s) for s in sub] == [1, 1]
assert len(sub[0].entries[0].input_sets) == 4
assert len(sub[1].entries[0].input_sets) == 2


def test_split_binding_without_input_sets(circuit_rx_parametrized):
# A binding with only observables is a single class of size len(observables).
c1 = circuit_rx_parametrized
c2 = Circuit().rx(0, FreeParameter("phi"))
binding_a = CircuitBinding(c1, observables=[X(0), Y(0)]) # size 2
binding_b = CircuitBinding(c2, observables=[X(0), Y(0), Z(0)]) # size 3
program_set = ProgramSet([binding_a, binding_b])
sub = program_set.split(3)
assert [s.total_executables for s in sub] == [2, 3]
assert sub[0].entries == [binding_a]
assert sub[1].entries == [binding_b]


def test_split_non_positive_raises(circuit_rx_parametrized):
binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}])
program_set = ProgramSet(binding)
with pytest.raises(ValueError, match="must be positive"):
program_set.split(0)
with pytest.raises(ValueError, match="must be positive"):
program_set.split(-3)


def test_split_oversize_class_raises(circuit_rx_parametrized):
Comment thread
speller26 marked this conversation as resolved.
Outdated
# One parameter-set index with 3 observables exceeds max_executables=2.
inputs = {"theta": [1.0, 2.0]}
observables = [X(0), Y(0), Z(0)]
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables)
program_set = ProgramSet(binding)
with pytest.raises(ValueError, match="exceeding max_executables"):
program_set.split(2)


def test_split_sub_program_sets_are_serializable(circuit_rx_parametrized):
inputs = {"theta": [float(i) for i in range(10)]}
observables = [X(0), Y(0)]
binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables)
program_set = ProgramSet(binding)
sub = program_set.split(6)
# Each sub-program set is a fully formed ProgramSet: to_ir() works and returns a
# single-program IR (one coalesced CircuitBinding per sub-program set here).
for s in sub:
ir = s.to_ir()
assert len(ir.programs) == len(s)
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ basepython = python3
deps =
{[test-deps]deps}
commands =
pytest {posargs} --cov=braket --cov-report term-missing --cov-report html --cov-report xml --cov-append
pytest {posargs} --cov --cov-report term-missing --cov-report html --cov-report xml --cov-append
extras = test

[testenv:integ-tests]
Expand Down
Loading