Skip to content

Commit fd4353f

Browse files
committed
feat: add PauliStringSum container
1 parent ec31c80 commit fd4353f

4 files changed

Lines changed: 363 additions & 2 deletions

File tree

src/braket/quantum_information/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# language governing permissions and limitations under the License.
1313

1414
"""Provides utilities for working with quantum information concepts. It
15-
includes the PauliString class for representing and manipulating tensor products
16-
of Pauli operators.
15+
includes classes for representing and manipulating tensor products and weighted
16+
sums of Pauli operators.
1717
"""
1818

1919
from braket.quantum_information.pauli_string import PauliString # noqa: F401
20+
from braket.quantum_information.pauli_string_sum import PauliStringSum # noqa: F401

src/braket/quantum_information/pauli_string.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def __mul__(self, other: PauliString) -> PauliString:
234234
See Also:
235235
`braket.quantum_information.PauliString.dot()`
236236
"""
237+
if not isinstance(other, PauliString):
238+
return NotImplemented
237239
return self.dot(other)
238240

239241
def __imul__(self, other: PauliString) -> PauliString:
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from __future__ import annotations
15+
16+
import numbers
17+
from collections.abc import Iterable
18+
from itertools import combinations
19+
20+
from braket.circuits.observable import Observable, StandardObservable
21+
from braket.circuits.observables import I, Sum, TensorProduct, X, Y, Z
22+
from braket.quantum_information.pauli_string import PauliString
23+
24+
_OBSERVABLE_TO_FACTOR = {I: "I", X: "X", Y: "Y", Z: "Z"}
25+
26+
27+
class PauliStringSum:
28+
"""A weighted sum of Pauli strings."""
29+
30+
def __init__(self, terms: Iterable[tuple[numbers.Number, str | PauliString]] = ()):
31+
"""Initializes a ``PauliStringSum``.
32+
33+
Args:
34+
terms (Iterable[tuple[numbers.Number, str | PauliString]]): Pairs of coefficient and
35+
Pauli string.
36+
"""
37+
self._terms: dict[str, numbers.Number] = {}
38+
for coefficient, pauli_string in terms:
39+
self._add_term(coefficient, PauliString(pauli_string))
40+
41+
@classmethod
42+
def from_list(cls, terms: Iterable[tuple[numbers.Number, str | PauliString]]) -> PauliStringSum:
43+
"""Builds a ``PauliStringSum`` from a list of weighted Pauli strings."""
44+
return cls(terms)
45+
46+
@property
47+
def terms(self) -> tuple[tuple[numbers.Number, PauliString], ...]:
48+
"""tuple[tuple[numbers.Number, PauliString], ...]: The weighted Pauli terms."""
49+
return tuple(
50+
(coefficient, PauliString(pauli)) for pauli, coefficient in self._terms.items()
51+
)
52+
53+
@property
54+
def qubit_count(self) -> int:
55+
"""int: The number of qubits in the largest Pauli string term."""
56+
if not self._terms:
57+
return 0
58+
return max(PauliString(pauli).qubit_count for pauli in self._terms)
59+
60+
def to_list(self) -> list[tuple[numbers.Number, str]]:
61+
"""Returns a list representation of the weighted Pauli strings."""
62+
return [(coefficient, pauli) for pauli, coefficient in self._terms.items()]
63+
64+
def to_sum(self) -> Sum:
65+
"""Converts the weighted Pauli strings into a circuit ``Sum`` observable."""
66+
if not self._terms:
67+
raise ValueError("Cannot convert an empty PauliStringSum to Sum")
68+
observables = []
69+
for coefficient, pauli in self.terms:
70+
observables.append(coefficient * pauli.to_unsigned_observable(include_trivial=True))
71+
return Sum(observables)
72+
73+
def commutes_with(self, other: PauliStringSum | PauliString | str) -> bool:
74+
"""Returns whether all terms commute with ``other``."""
75+
other_sum = self._coerce(other)
76+
return all(
77+
self._pauli_strings_commute(left, right)
78+
for _, left in self.terms
79+
for _, right in other_sum.terms
80+
)
81+
82+
def is_self_commuting(self) -> bool:
83+
"""Returns whether all terms in this sum commute with each other."""
84+
return all(
85+
self._pauli_strings_commute(left, right)
86+
for (_, left), (_, right) in combinations(self.terms, 2)
87+
)
88+
89+
@classmethod
90+
def from_sum(cls, observable_sum: Sum) -> PauliStringSum:
91+
"""Builds a ``PauliStringSum`` from a circuit ``Sum`` observable."""
92+
if not isinstance(observable_sum, Sum):
93+
raise TypeError("Expected a Sum observable")
94+
return cls(cls._term_from_observable(observable) for observable in observable_sum.summands)
95+
96+
def __add__(self, other: PauliStringSum | PauliString | str) -> PauliStringSum:
97+
other_sum = self._coerce(other)
98+
return PauliStringSum((*self.terms, *other_sum.terms))
99+
100+
def __radd__(self, other: PauliStringSum | PauliString | str) -> PauliStringSum:
101+
return self + other
102+
103+
def __sub__(self, other: PauliStringSum | PauliString | str) -> PauliStringSum:
104+
return self + (-1 * self._coerce(other))
105+
106+
def __rsub__(self, other: PauliStringSum | PauliString | str) -> PauliStringSum:
107+
return self._coerce(other) + (-self)
108+
109+
def __neg__(self) -> PauliStringSum:
110+
return -1 * self
111+
112+
def __mul__(self, other: numbers.Number | PauliString | str) -> PauliStringSum:
113+
if isinstance(other, numbers.Number):
114+
return PauliStringSum((coefficient * other, pauli) for coefficient, pauli in self.terms)
115+
pauli = PauliString(other)
116+
qubit_count = max(self.qubit_count, pauli.qubit_count)
117+
right = self._pad_pauli_string(pauli, qubit_count)
118+
return PauliStringSum(
119+
(coefficient, self._pad_pauli_string(term, qubit_count).dot(right))
120+
for coefficient, term in self.terms
121+
)
122+
123+
def __rmul__(self, other: numbers.Number | PauliString | str) -> PauliStringSum:
124+
if isinstance(other, numbers.Number):
125+
return self * other
126+
pauli = PauliString(other)
127+
qubit_count = max(self.qubit_count, pauli.qubit_count)
128+
left = self._pad_pauli_string(pauli, qubit_count)
129+
return PauliStringSum(
130+
(coefficient, left.dot(self._pad_pauli_string(term, qubit_count)))
131+
for coefficient, term in self.terms
132+
)
133+
134+
def __contains__(self, item: str | PauliString) -> bool:
135+
_, pauli = self._canonical_term(PauliString(item))
136+
return pauli in self._terms
137+
138+
def __getitem__(self, item: int) -> tuple[numbers.Number, PauliString]:
139+
return self.terms[item]
140+
141+
def __iter__(self):
142+
return iter(self.terms)
143+
144+
def __len__(self) -> int:
145+
return len(self._terms)
146+
147+
def __eq__(self, other: PauliStringSum) -> bool:
148+
if not isinstance(other, PauliStringSum):
149+
return False
150+
return self._terms == other._terms
151+
152+
def __repr__(self) -> str:
153+
return f"PauliStringSum({self.to_list()})"
154+
155+
def _add_term(self, coefficient: numbers.Number, pauli_string: PauliString) -> None:
156+
if not isinstance(coefficient, numbers.Number):
157+
raise TypeError("PauliStringSum coefficients must be numbers")
158+
coefficient, pauli = self._canonical_term(pauli_string, coefficient)
159+
if coefficient == 0:
160+
return
161+
new_coefficient = self._terms.get(pauli, 0) + coefficient
162+
if new_coefficient == 0:
163+
self._terms.pop(pauli, None)
164+
else:
165+
self._terms[pauli] = new_coefficient
166+
167+
@staticmethod
168+
def _canonical_term(
169+
pauli_string: PauliString, coefficient: numbers.Number = 1
170+
) -> tuple[numbers.Number, str]:
171+
factors = ["I"] * pauli_string.qubit_count
172+
for qubit in range(pauli_string.qubit_count):
173+
factors[qubit] = "IXYZ"[pauli_string[qubit]]
174+
return coefficient * pauli_string.phase, f"+{''.join(factors)}"
175+
176+
@staticmethod
177+
def _coerce(other: PauliStringSum | PauliString | str) -> PauliStringSum:
178+
if isinstance(other, PauliStringSum):
179+
return other
180+
return PauliStringSum([(1, other)])
181+
182+
@staticmethod
183+
def _pauli_strings_commute(left: PauliString, right: PauliString) -> bool:
184+
anticommuting_factors = 0
185+
qubit_count = max(left.qubit_count, right.qubit_count)
186+
for qubit in range(qubit_count):
187+
left_factor = left[qubit] if qubit < left.qubit_count else 0
188+
right_factor = right[qubit] if qubit < right.qubit_count else 0
189+
if left_factor and right_factor and left_factor != right_factor:
190+
anticommuting_factors += 1
191+
return anticommuting_factors % 2 == 0
192+
193+
@staticmethod
194+
def _pad_pauli_string(pauli_string: PauliString, qubit_count: int) -> PauliString:
195+
if pauli_string.qubit_count == qubit_count:
196+
return pauli_string
197+
factors = ["I"] * qubit_count
198+
for qubit in range(pauli_string.qubit_count):
199+
factors[qubit] = "IXYZ"[pauli_string[qubit]]
200+
sign = "-" if pauli_string.phase < 0 else "+"
201+
return PauliString(f"{sign}{''.join(factors)}")
202+
203+
@staticmethod
204+
def _term_from_observable(observable: Observable) -> tuple[numbers.Number, str]:
205+
coefficient = observable.coefficient
206+
unscaled = observable._unscaled()
207+
if isinstance(unscaled, StandardObservable):
208+
factors = PauliStringSum._factors_from_standard(unscaled)
209+
elif isinstance(unscaled, TensorProduct):
210+
factors = PauliStringSum._factors_from_tensor_product(unscaled)
211+
else:
212+
raise TypeError(f"Unsupported observable type {type(observable).__name__}")
213+
return coefficient, f"+{''.join(factors)}"
214+
215+
@staticmethod
216+
def _factors_from_standard(observable: StandardObservable) -> list[str]:
217+
factor = PauliStringSum._factor_from_standard(observable)
218+
if observable.targets:
219+
factors = ["I"] * (max(observable.targets) + 1)
220+
factors[int(observable.targets[0])] = factor
221+
return factors
222+
return [factor]
223+
224+
@staticmethod
225+
def _factors_from_tensor_product(observable: TensorProduct) -> list[str]:
226+
if observable.targets:
227+
factors = ["I"] * (max(observable.targets) + 1)
228+
for factor, target in zip(observable.factors, observable.targets, strict=True):
229+
factors[int(target)] = PauliStringSum._factor_from_standard(factor)
230+
return factors
231+
return [PauliStringSum._factor_from_standard(factor) for factor in observable.factors]
232+
233+
@staticmethod
234+
def _factor_from_standard(observable: StandardObservable) -> str:
235+
for observable_type, factor in _OBSERVABLE_TO_FACTOR.items():
236+
if isinstance(observable, observable_type):
237+
return factor
238+
raise TypeError(f"Unsupported observable factor {type(observable).__name__}")
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
import pytest
15+
16+
from braket.circuits.observables import H, X, Y, Z
17+
from braket.quantum_information import PauliString, PauliStringSum
18+
19+
20+
def test_initializes_from_weighted_list_and_combines_terms():
21+
pauli_sum = PauliStringSum.from_list([(1.5, "XZ"), (2.0, "-XZ"), (3.0, "YY")])
22+
23+
assert pauli_sum.to_list() == [(-0.5, "+XZ"), (3.0, "+YY")]
24+
25+
26+
def test_addition_subtraction_and_scalar_multiplication():
27+
first = PauliStringSum([(1.0, "XI")])
28+
second = PauliStringSum([(2.0, "IZ")])
29+
30+
assert (first + second - PauliString("XI")).to_list() == [(2.0, "+IZ")]
31+
assert (0.5 * second).to_list() == [(1.0, "+IZ")]
32+
assert (-second).to_list() == [(-2.0, "+IZ")]
33+
34+
35+
def test_multiplication_by_pauli_string():
36+
pauli_sum = PauliStringSum([(2.0, "XY"), (3.0, "ZZ")])
37+
38+
assert (pauli_sum * PauliString("YZ")).to_list() == [(-2.0, "+ZX"), (-3.0, "+XI")]
39+
40+
41+
def test_multiplication_by_pauli_string_pads_mixed_width_terms():
42+
pauli_sum = PauliStringSum([(1.0, "X"), (2.0, "IZ")])
43+
44+
assert (pauli_sum * PauliString("ZZ")).to_list() == [(-1.0, "+YZ"), (2.0, "+ZI")]
45+
46+
47+
def test_left_multiplication_by_pauli_string_preserves_order():
48+
pauli_sum = PauliStringSum([(2.0, "X")])
49+
50+
assert (pauli_sum * PauliString("Z")).to_list() == [(-2.0, "+Y")]
51+
assert (PauliString("Z") * pauli_sum).to_list() == [(2.0, "+Y")]
52+
53+
54+
def test_indexing_and_membership():
55+
pauli_sum = PauliStringSum([(2.0, "XI"), (3.0, "IZ")])
56+
57+
assert "XI" in pauli_sum
58+
assert PauliString("-IZ") in pauli_sum
59+
assert pauli_sum[0] == (2.0, PauliString("XI"))
60+
assert list(pauli_sum) == [(2.0, PauliString("XI")), (3.0, PauliString("IZ"))]
61+
62+
63+
def test_to_sum_and_from_sum_round_trip():
64+
pauli_sum = PauliStringSum([(2.0, "XY"), (-3.0, "ZI")])
65+
66+
observable_sum = pauli_sum.to_sum()
67+
68+
assert PauliStringSum.from_sum(observable_sum) == pauli_sum
69+
70+
71+
def test_equality_is_independent_of_term_insertion_order():
72+
first = PauliStringSum([(1.0, "XI"), (2.0, "IZ")])
73+
second = PauliStringSum([(2.0, "IZ"), (1.0, "XI")])
74+
75+
assert first == second
76+
77+
78+
def test_reverse_subtraction():
79+
pauli_sum = PauliStringSum([(2.0, "XI")])
80+
81+
assert (PauliString("IZ") - pauli_sum).to_list() == [(1, "+IZ"), (-2.0, "+XI")]
82+
83+
84+
def test_from_sum_with_targeted_observables():
85+
observable_sum = 2.0 * (X(0) @ Y(2)) - 3.0 * Z(1)
86+
87+
assert PauliStringSum.from_sum(observable_sum).to_list() == [
88+
(2.0, "+XIY"),
89+
(-3.0, "+IZ"),
90+
]
91+
92+
93+
def test_commutation_checks():
94+
commuting_sum = PauliStringSum([(1.0, "XX"), (2.0, "YY")])
95+
non_commuting_sum = PauliStringSum([(1.0, "XI"), (2.0, "ZI")])
96+
97+
assert commuting_sum.is_self_commuting()
98+
assert commuting_sum.commutes_with("ZZ")
99+
assert not non_commuting_sum.is_self_commuting()
100+
assert not non_commuting_sum.commutes_with("YI")
101+
102+
103+
def test_empty_sum_cannot_convert_to_observable_sum():
104+
with pytest.raises(ValueError, match="empty PauliStringSum"):
105+
PauliStringSum().to_sum()
106+
107+
108+
def test_non_numeric_coefficients_are_rejected():
109+
with pytest.raises(TypeError, match="coefficients must be numbers"):
110+
PauliStringSum([("weight", "XI")])
111+
112+
113+
def test_from_sum_rejects_non_sum_observable():
114+
with pytest.raises(TypeError, match="Expected a Sum observable"):
115+
PauliStringSum.from_sum(X())
116+
117+
118+
def test_from_sum_rejects_non_pauli_observable_terms():
119+
with pytest.raises(TypeError, match="Unsupported observable factor H"):
120+
PauliStringSum.from_sum(H() + X())

0 commit comments

Comments
 (0)