Skip to content

Commit d2e3fa5

Browse files
authored
fix: use sparse dict for PauliString (#1246)
1 parent bca955b commit d2e3fa5

1 file changed

Lines changed: 42 additions & 30 deletions

File tree

src/braket/quantum_information/pauli_string.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,17 @@ def weight_n_substrings(self, weight: int) -> tuple[PauliString, ...]:
115115
"""
116116
substrings = []
117117
for indices in itertools.combinations(self._nontrivial, weight):
118-
factors = [
119-
(
120-
self._nontrivial[qubit]
121-
if qubit in set(indices).intersection(self._nontrivial)
122-
else "I"
123-
)
124-
for qubit in range(self._qubit_count)
125-
]
126-
substrings.append(
127-
PauliString(f"{PauliString._phase_to_str(self._phase)}{''.join(factors)}")
128-
)
118+
idx_set = set(indices)
119+
nontrivial = {i: self._nontrivial[i] for i in idx_set}
120+
# Bypass __init__ via __new__ to skip string parsing. The internal
121+
# state (phase, qubit_count, nontrivial dict) is already known here,
122+
# so going through PauliString(str) would only round-trip the data
123+
# through a dense string representation for no benefit.
124+
ps = PauliString.__new__(PauliString)
125+
ps._phase = self._phase
126+
ps._qubit_count = self._qubit_count
127+
ps._nontrivial = nontrivial
128+
substrings.append(ps)
129129
return tuple(substrings)
130130

131131
def eigenstate(self, signs: str | list[int] | tuple[int, ...] | None = None) -> Circuit:
@@ -186,27 +186,30 @@ def dot(self, other: PauliString, inplace: bool = False) -> PauliString:
186186
f"Input Pauli string must be of length ({self._qubit_count}), "
187187
f"not {other._qubit_count}"
188188
)
189-
pauli_result = ""
189+
pauli_result = {}
190190
phase_result = self._phase * other._phase
191-
for i in range(self._qubit_count):
192-
# Are either identity?
193-
if i not in self._nontrivial and i not in other._nontrivial:
194-
pauli_result += "I"
195-
elif i not in self._nontrivial:
196-
pauli_result += other._nontrivial[i]
197-
elif i not in other._nontrivial:
198-
pauli_result += self._nontrivial[i]
199-
elif self._nontrivial[i] == other._nontrivial[i]:
200-
pauli_result += "I"
201-
else:
191+
for i in self._nontrivial:
192+
if i not in other._nontrivial:
193+
pauli_result[i] = self._nontrivial[i]
194+
elif self._nontrivial[i] != other._nontrivial[i]:
202195
gate, phase = _PRODUCT_MAP[self._nontrivial[i]][other._nontrivial[i]]
203-
pauli_result += gate
196+
pauli_result[i] = gate
204197
phase_result *= phase
198+
for i in other._nontrivial:
199+
if i not in self._nontrivial:
200+
pauli_result[i] = other._nontrivial[i]
205201

206202
# ignore complex global phase
207-
if phase_result.real < 0 or phase_result.imag < 0:
208-
pauli_result = f"-{pauli_result}"
209-
out_pauli_string = PauliString(pauli_result)
203+
out_phase = -1 if (phase_result.real < 0 or phase_result.imag < 0) else 1
204+
205+
# Bypass __init__ via __new__ to avoid serializing the computed dict
206+
# back into a string just to have __init__ parse it again. The fields
207+
# below fully define a valid PauliString, so direct assignment is both
208+
# faster and avoids an O(qubit_count) dense-string round trip.
209+
out_pauli_string = PauliString.__new__(PauliString)
210+
out_pauli_string._phase = out_phase
211+
out_pauli_string._qubit_count = self._qubit_count
212+
out_pauli_string._nontrivial = pauli_result
210213

211214
if inplace:
212215
self._phase = out_pauli_string._phase
@@ -270,11 +273,18 @@ def power(self, n: int, inplace: bool = False) -> PauliString:
270273
if not isinstance(n, int):
271274
raise TypeError("Must be raised to integer power")
272275

273-
# Since pauli ops involutory, result is either identity or unchanged
274-
pauli_other = PauliString(self)
276+
# Since pauli ops involutory, result is either identity or unchanged.
277+
# Bypass __init__ via __new__ to skip the PauliString(self) copy path,
278+
# which would re-validate fields we already know are consistent. Direct
279+
# field assignment keeps the hot path allocation-light.
280+
pauli_other = PauliString.__new__(PauliString)
281+
pauli_other._qubit_count = self._qubit_count
275282
if n % 2 == 0:
276283
pauli_other._phase = 1
277284
pauli_other._nontrivial = {}
285+
else:
286+
pauli_other._phase = self._phase
287+
pauli_other._nontrivial = dict(self._nontrivial)
278288

279289
if inplace:
280290
self._phase = pauli_other._phase
@@ -360,7 +370,9 @@ def __len__(self):
360370
return self._qubit_count
361371

362372
def __repr__(self):
363-
factors = [self._nontrivial.get(qubit, "I") for qubit in range(self._qubit_count)]
373+
factors = ["I"] * self._qubit_count
374+
for i, p in self._nontrivial.items():
375+
factors[i] = p
364376
return f"{PauliString._phase_to_str(self._phase)}{''.join(factors)}"
365377

366378
@staticmethod

0 commit comments

Comments
 (0)