Skip to content

Commit 1299de6

Browse files
authored
feat: Add PauliSum container (#1270)
1 parent 2916f20 commit 1299de6

5 files changed

Lines changed: 520 additions & 3 deletions

File tree

src/braket/quantum_information/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# language governing permissions and limitations under the License.
1313

1414
"""Provides utilities for working with quantum information concepts. It
15-
includes the PauliString class for representing and manipulating tensor products
16-
of Pauli operators.
15+
includes classes for representing and manipulating tensor products and weighted
16+
sums of Pauli operators.
1717
"""
1818

1919
from braket.quantum_information.pauli_string import PauliString # noqa: F401
20+
from braket.quantum_information.pauli_sum import PauliStringSum, PauliSum # noqa: F401

src/braket/quantum_information/pauli_string.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def __mul__(self, other: PauliString) -> PauliString:
234234
See Also:
235235
`braket.quantum_information.PauliString.dot()`
236236
"""
237+
if not isinstance(other, PauliString):
238+
return NotImplemented
237239
return self.dot(other)
238240

239241
def __imul__(self, other: PauliString) -> PauliString:
@@ -255,6 +257,55 @@ def __imul__(self, other: PauliString) -> PauliString:
255257
"""
256258
return self.dot(other, inplace=True)
257259

260+
def __add__(self, other: PauliString | str):
261+
"""Operator overload for addition with another Pauli string.
262+
263+
Args:
264+
other (PauliString | str): The Pauli string to add.
265+
266+
Returns:
267+
PauliSum: The sum containing both Pauli strings.
268+
"""
269+
if not isinstance(other, (PauliString, str)):
270+
return NotImplemented
271+
from braket.quantum_information.pauli_sum import PauliSum # noqa: PLC0415
272+
273+
return PauliSum([(1, self), (1, other)])
274+
275+
def __radd__(self, other: PauliString | str):
276+
"""Operator overload for reverse addition with another Pauli string.
277+
278+
Args:
279+
other (PauliString | str): The Pauli string to add.
280+
281+
Returns:
282+
PauliSum: The sum containing both Pauli strings.
283+
"""
284+
if not isinstance(other, (PauliString, str)):
285+
return NotImplemented
286+
from braket.quantum_information.pauli_sum import PauliSum # noqa: PLC0415
287+
288+
return PauliSum([(1, other), (1, self)])
289+
290+
def commutes_with(self, other: PauliString | str) -> bool:
291+
"""Returns whether this Pauli string commutes with another Pauli string.
292+
293+
Args:
294+
other (PauliString | str): The Pauli string to check against.
295+
296+
Returns:
297+
bool: Whether the Pauli strings commute.
298+
"""
299+
other_pauli = PauliString(other)
300+
anticommuting_factors = 0
301+
qubit_count = max(self._qubit_count, other_pauli._qubit_count)
302+
for qubit in range(qubit_count):
303+
left_factor = self[qubit] if qubit < self._qubit_count else 0
304+
right_factor = other_pauli[qubit] if qubit < other_pauli._qubit_count else 0
305+
if left_factor and right_factor and left_factor != right_factor:
306+
anticommuting_factors += 1
307+
return anticommuting_factors % 2 == 0
308+
258309
def power(self, n: int, inplace: bool = False) -> PauliString:
259310
"""Composes Pauli string with itself n times.
260311
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from __future__ import annotations
15+
16+
import numbers
17+
from collections.abc import Iterable
18+
from itertools import combinations, starmap
19+
20+
from braket.circuits.observable import Observable, StandardObservable
21+
from braket.circuits.observables import I, Sum, TensorProduct, X, Y, Z
22+
from braket.quantum_information.pauli_string import PauliString
23+
24+
_OBSERVABLE_TO_FACTOR = {I: "I", X: "X", Y: "Y", Z: "Z"}
25+
26+
27+
class PauliSum:
28+
"""A weighted sum of Pauli strings."""
29+
30+
def __init__(self, terms: Iterable[tuple[numbers.Number, str | PauliString]] = ()):
31+
"""Initializes a ``PauliSum``.
32+
33+
Args:
34+
terms (Iterable[tuple[numbers.Number, str | PauliString]]): Pairs of coefficient and
35+
Pauli string.
36+
"""
37+
self._terms: dict[str, numbers.Number] = {}
38+
for coefficient, pauli_string in terms:
39+
self._add_term(coefficient, PauliString(pauli_string))
40+
self._all_terms_commute = self._compute_all_terms_commute()
41+
42+
@property
43+
def all_terms_commute(self) -> bool:
44+
"""bool: Whether all terms in the sum commute with each other."""
45+
return self._all_terms_commute
46+
47+
@classmethod
48+
def from_list(cls, terms: Iterable[tuple[numbers.Number, str | PauliString]]) -> PauliSum:
49+
"""Builds a ``PauliSum`` from a list of weighted Pauli strings."""
50+
return cls(terms)
51+
52+
@property
53+
def terms(self) -> tuple[tuple[numbers.Number, PauliString], ...]:
54+
"""tuple[tuple[numbers.Number, PauliString], ...]: The weighted Pauli terms."""
55+
return tuple(
56+
(coefficient, PauliString(pauli)) for pauli, coefficient in self._terms.items()
57+
)
58+
59+
@property
60+
def qubit_count(self) -> int:
61+
"""int: The number of qubits in the largest Pauli string term."""
62+
if not self._terms:
63+
return 0
64+
return max(PauliString(pauli).qubit_count for pauli in self._terms)
65+
66+
def to_list(self) -> list[tuple[numbers.Number, str]]:
67+
"""Returns a list representation of the weighted Pauli strings."""
68+
return [(coefficient, pauli) for pauli, coefficient in self._terms.items()]
69+
70+
def to_sum(self) -> Sum:
71+
"""Converts the weighted Pauli strings into a circuit ``Sum`` observable."""
72+
if not self._terms:
73+
raise ValueError("Cannot convert an empty PauliSum to Sum")
74+
observables = []
75+
for coefficient, pauli in self.terms:
76+
observables.append(coefficient * pauli.to_unsigned_observable(include_trivial=True))
77+
return Sum(observables)
78+
79+
def commutes_with(self, other: PauliSum | PauliString | str) -> bool:
80+
"""Returns whether all terms commute with ``other``."""
81+
other_sum = self._coerce(other)
82+
if not self.all_terms_commute or not other_sum.all_terms_commute:
83+
return False
84+
return all(
85+
self._pauli_strings_commute(left, right)
86+
for _, left in self.terms
87+
for _, right in other_sum.terms
88+
)
89+
90+
def is_self_commuting(self) -> bool:
91+
"""Returns whether all terms in this sum commute with each other."""
92+
return self._all_terms_commute
93+
94+
@classmethod
95+
def from_sum(cls, observable_sum: Sum) -> PauliSum:
96+
"""Builds a ``PauliSum`` from a circuit ``Sum`` observable."""
97+
if not isinstance(observable_sum, Sum):
98+
raise TypeError("Expected a Sum observable")
99+
return cls(cls._term_from_observable(observable) for observable in observable_sum.summands)
100+
101+
def __add__(self, other: PauliSum | PauliString | str) -> PauliSum:
102+
other_sum = self._coerce(other)
103+
return PauliSum((*self.terms, *other_sum.terms))
104+
105+
def __radd__(self, other: PauliSum | PauliString | str) -> PauliSum:
106+
return self + other
107+
108+
def __sub__(self, other: PauliSum | PauliString | str) -> PauliSum:
109+
return self + (-1 * self._coerce(other))
110+
111+
def __rsub__(self, other: PauliSum | PauliString | str) -> PauliSum:
112+
return self._coerce(other) + (-self)
113+
114+
def __neg__(self) -> PauliSum:
115+
return -1 * self
116+
117+
def __mul__(self, other: numbers.Number | PauliString | str) -> PauliSum:
118+
if isinstance(other, numbers.Number):
119+
return PauliSum((coefficient * other, pauli) for coefficient, pauli in self.terms)
120+
pauli = PauliString(other)
121+
qubit_count = max(self.qubit_count, pauli.qubit_count)
122+
right = self._pad_pauli_string(pauli, qubit_count)
123+
return PauliSum(
124+
(coefficient, self._pad_pauli_string(term, qubit_count).dot(right))
125+
for coefficient, term in self.terms
126+
)
127+
128+
def __rmul__(self, other: numbers.Number | PauliString | str) -> PauliSum:
129+
if isinstance(other, numbers.Number):
130+
return self * other
131+
pauli = PauliString(other)
132+
qubit_count = max(self.qubit_count, pauli.qubit_count)
133+
left = self._pad_pauli_string(pauli, qubit_count)
134+
return PauliSum(
135+
(coefficient, left.dot(self._pad_pauli_string(term, qubit_count)))
136+
for coefficient, term in self.terms
137+
)
138+
139+
def __contains__(self, item: str | PauliString) -> bool:
140+
_, pauli = self._canonical_term(PauliString(item))
141+
return pauli in self._terms
142+
143+
def __getitem__(self, item: int) -> tuple[numbers.Number, PauliString]:
144+
return self.terms[item]
145+
146+
def __iter__(self):
147+
return iter(self.terms)
148+
149+
def __len__(self) -> int:
150+
return len(self._terms)
151+
152+
def __eq__(self, other: PauliSum) -> bool:
153+
if not isinstance(other, PauliSum):
154+
return False
155+
return self._terms == other._terms
156+
157+
def __repr__(self) -> str:
158+
return f"PauliSum({self.to_list()})"
159+
160+
def _add_term(self, coefficient: numbers.Number, pauli_string: PauliString) -> None:
161+
if not isinstance(coefficient, numbers.Number):
162+
raise TypeError("PauliSum coefficients must be numbers")
163+
coefficient, pauli = self._canonical_term(pauli_string, coefficient)
164+
if coefficient == 0:
165+
return
166+
new_coefficient = self._terms.get(pauli, 0) + coefficient
167+
if new_coefficient == 0:
168+
self._terms.pop(pauli, None)
169+
else:
170+
self._terms[pauli] = new_coefficient
171+
172+
def _compute_all_terms_commute(self) -> bool:
173+
if len(self._terms) <= 1:
174+
return True
175+
paulis = [PauliString(pauli) for pauli in self._terms]
176+
return all(starmap(self._pauli_strings_commute, combinations(paulis, 2)))
177+
178+
@staticmethod
179+
def _canonical_term(
180+
pauli_string: PauliString, coefficient: numbers.Number = 1
181+
) -> tuple[numbers.Number, str]:
182+
factors = ["I"] * pauli_string.qubit_count
183+
for qubit in range(pauli_string.qubit_count):
184+
factors[qubit] = "IXYZ"[pauli_string[qubit]]
185+
return coefficient * pauli_string.phase, f"+{''.join(factors)}"
186+
187+
@staticmethod
188+
def _coerce(other: PauliSum | PauliString | str) -> PauliSum:
189+
if isinstance(other, PauliSum):
190+
return other
191+
return PauliSum([(1, other)])
192+
193+
@staticmethod
194+
def _pauli_strings_commute(left: PauliString, right: PauliString) -> bool:
195+
anticommuting_factors = 0
196+
qubit_count = max(left.qubit_count, right.qubit_count)
197+
for qubit in range(qubit_count):
198+
left_factor = left[qubit] if qubit < left.qubit_count else 0
199+
right_factor = right[qubit] if qubit < right.qubit_count else 0
200+
if left_factor and right_factor and left_factor != right_factor:
201+
anticommuting_factors += 1
202+
return anticommuting_factors % 2 == 0
203+
204+
@staticmethod
205+
def _pad_pauli_string(pauli_string: PauliString, qubit_count: int) -> PauliString:
206+
if pauli_string.qubit_count == qubit_count:
207+
return pauli_string
208+
factors = ["I"] * qubit_count
209+
for qubit in range(pauli_string.qubit_count):
210+
factors[qubit] = "IXYZ"[pauli_string[qubit]]
211+
sign = "-" if pauli_string.phase < 0 else "+"
212+
return PauliString(f"{sign}{''.join(factors)}")
213+
214+
@staticmethod
215+
def _term_from_observable(observable: Observable) -> tuple[numbers.Number, str]:
216+
coefficient = observable.coefficient
217+
unscaled = observable._unscaled()
218+
if isinstance(unscaled, StandardObservable):
219+
factors = PauliSum._factors_from_standard(unscaled)
220+
elif isinstance(unscaled, TensorProduct):
221+
factors = PauliSum._factors_from_tensor_product(unscaled)
222+
else:
223+
raise TypeError(f"Unsupported observable type {type(observable).__name__}")
224+
return coefficient, f"+{''.join(factors)}"
225+
226+
@staticmethod
227+
def _factors_from_standard(observable: StandardObservable) -> list[str]:
228+
factor = PauliSum._factor_from_standard(observable)
229+
if observable.targets:
230+
factors = ["I"] * (max(observable.targets) + 1)
231+
factors[int(observable.targets[0])] = factor
232+
return factors
233+
return [factor]
234+
235+
@staticmethod
236+
def _factors_from_tensor_product(observable: TensorProduct) -> list[str]:
237+
if observable.targets:
238+
factors = ["I"] * (max(observable.targets) + 1)
239+
for factor, target in zip(observable.factors, observable.targets, strict=True):
240+
factors[int(target)] = PauliSum._factor_from_standard(factor)
241+
return factors
242+
return [PauliSum._factor_from_standard(factor) for factor in observable.factors]
243+
244+
@staticmethod
245+
def _factor_from_standard(observable: StandardObservable) -> str:
246+
for observable_type, factor in _OBSERVABLE_TO_FACTOR.items():
247+
if isinstance(observable, observable_type):
248+
return factor
249+
raise TypeError(f"Unsupported observable factor {type(observable).__name__}")
250+
251+
252+
PauliStringSum = PauliSum

test/unit_tests/braket/quantum_information/test_pauli_string.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from braket.circuits import gates
2222
from braket.circuits.circuit import Circuit
2323
from braket.circuits.observables import I, X, Y, Z
24-
from braket.quantum_information import PauliString
24+
from braket.quantum_information import PauliString, PauliStringSum
2525

2626
ORDER = ["I", "X", "Y", "Z"]
2727
PAULI_INDEX_MATRICES = {
@@ -172,6 +172,35 @@ def test_dot(circ_arg_1, circ_arg_2, circ_res):
172172
assert circ1 == PauliString(circ_res)
173173

174174

175+
def test_pauli_string_addition_builds_pauli_string_sum():
176+
pauli_sum = PauliString("XI") + PauliString("-IZ")
177+
178+
assert pauli_sum == PauliStringSum([(1, "XI"), (1, "-IZ")])
179+
180+
181+
def test_pauli_string_addition_accepts_strings_and_reverse_addition():
182+
assert PauliString("XI") + "IZ" == PauliStringSum([(1, "XI"), (1, "IZ")])
183+
assert "IZ" + PauliString("XI") == PauliStringSum([(1, "IZ"), (1, "XI")])
184+
185+
186+
def test_pauli_string_operator_overloads_reject_unsupported_types():
187+
pauli_string = PauliString("XI")
188+
189+
with pytest.raises(TypeError):
190+
pauli_string * "IZ"
191+
with pytest.raises(TypeError):
192+
pauli_string + 1
193+
with pytest.raises(TypeError):
194+
1 + pauli_string
195+
196+
197+
def test_pauli_string_commutation_check():
198+
assert PauliString("XX").commutes_with("YY")
199+
assert PauliString("X").commutes_with("IZ")
200+
assert PauliString("IZ").commutes_with("X")
201+
assert not PauliString("XI").commutes_with("ZI")
202+
203+
175204
@pytest.mark.xfail(raises=ValueError)
176205
@pytest.mark.parametrize(
177206
"circ1, circ2, operation",

0 commit comments

Comments
 (0)