Skip to content

Commit 02e927f

Browse files
Sub state vector (#8077)
The current implementation appears to be exponential in space and time. The present change, due originally to jamieas@google.com, replaces python loops with numpy functions motivated by standard methods for computing partial traces. The change is runnable in Colab, including tests and a performance benchmark, at https://colab.sandbox.google.com/drive/1qWJTnlT6hPXI8cQsTwyqQveEtn7fGhNc?resourcekey=0-67OmbvgwTlf9fnKp3ppbrw --------- Co-authored-by: Anton Kast <akast@google.com>
1 parent bf16378 commit 02e927f

2 files changed

Lines changed: 38 additions & 32 deletions

File tree

cirq-core/cirq/linalg/transformations.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@
2525
import numpy as np
2626

2727
from cirq import protocols
28-
from cirq.linalg import predicates
29-
30-
# This is a special indicator value used by the `sub_state_vector` method to
31-
# determine whether or not the caller provided a 'default' argument. It must be
32-
# of type np.ndarray to ensure the method has the correct type signature in that
33-
# case. It is checked for using `is`, so it won't have a false positive if the
34-
# user provides a different np.array([]) value.
35-
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])
3628

3729

3830
def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float) -> np.ndarray:
@@ -479,7 +471,7 @@ def sub_state_vector(
479471
state_vector: np.ndarray,
480472
keep_indices: list[int],
481473
*,
482-
default: np.ndarray = RaiseValueErrorIfNotProvided,
474+
default: np.ndarray | None = None,
483475
atol: float = 1e-6,
484476
) -> np.ndarray:
485477
r"""Attempts to factor a state vector into two parts and return one of them.
@@ -503,16 +495,16 @@ def sub_state_vector(
503495
504496
If the provided `state_vector` cannot be factored into a pure state over
505497
`keep_indices`, the method will fall back to return `default`. If `default`
506-
is not provided, the method will fail and raise `ValueError`.
498+
is not provided, the method will fail and raise EntangledStateError.
507499
508500
Args:
509501
state_vector: The target state_vector.
510502
keep_indices: Which indices to attempt to get the separable part of the
511503
`state_vector` on.
512504
default: Determines the fallback behavior when `state_vector` doesn't
513505
have a pure state factorization. If the factored state is not pure
514-
and `default` is not set, a ValueError is raised. If default is set
515-
to a value, that value is returned.
506+
and `default` is not set, an EntangledStateError is raised. If
507+
default is set to a value, that value is returned.
516508
atol: The minimum tolerance for comparing the output state's coherence
517509
measure to 1.
518510
@@ -540,36 +532,42 @@ def sub_state_vector(
540532
ret_shape: tuple[int] | tuple[int, ...]
541533
if state_vector.shape == (state_vector.size,):
542534
ret_shape = (keep_dims,)
543-
state_vector = state_vector.reshape((2,) * n_qubits)
544535
elif state_vector.shape == (2,) * n_qubits:
536+
state_vector = state_vector.reshape(-1)
545537
ret_shape = tuple(2 for _ in range(len(keep_indices)))
546538
else:
547539
raise ValueError("Input state_vector must be shaped like (2 ** n,) or (2,) * n")
548540

549-
keep_dims = 1 << len(keep_indices)
550541
if not np.isclose(np.linalg.norm(state_vector), 1):
551542
raise ValueError("Input state must be normalized.")
552543
if len(set(keep_indices)) != len(keep_indices):
553544
raise ValueError(f"keep_indices were {keep_indices} but must be unique.")
554545
if any(ind >= n_qubits for ind in keep_indices):
555546
raise ValueError("keep_indices {} are an invalid subset of the input state vector.")
556547

557-
other_qubits = sorted(set(range(n_qubits)) - set(keep_indices))
558-
candidates = [
559-
state_vector[predicates.slice_for_qubits_equal_to(other_qubits, k)].reshape(keep_dims)
560-
for k in range(1 << len(other_qubits))
561-
]
562-
# The coherence measure is computed using unnormalized candidates.
563-
best_candidate = max(candidates, key=lambda c: float(np.linalg.norm(c, 2)))
564-
best_candidate = best_candidate / np.linalg.norm(best_candidate)
565-
left = np.conj(best_candidate.reshape((keep_dims,))).T
566-
coherence_measure = sum(abs(np.dot(left, c.reshape((keep_dims,)))) ** 2 for c in candidates)
567-
568-
if protocols.approx_eq(coherence_measure, 1, atol=atol):
569-
return np.exp(2j * np.pi * np.random.random()) * best_candidate.reshape(ret_shape)
548+
# The permutation moves the specified qubits to the start of the qubit order.
549+
keeps = frozenset(keep_indices)
550+
remainder = np.array([i for i in range(n_qubits) if i not in keeps], dtype=np.int64)
551+
permutation = np.concatenate([keep_indices, remainder])
552+
553+
# Permute qubits and construct the pure-state density matrix.
554+
raveled = state_vector.reshape([2] * n_qubits)
555+
raveled = np.transpose(raveled, permutation)
556+
num_qubits_out = len(keep_indices)
557+
c_psi = raveled.reshape([2**num_qubits_out, -1])
558+
rho = c_psi @ c_psi.conj().T
559+
560+
# Return the eigenvector with eigenvalue 1.
561+
evals, evec = np.linalg.eigh(rho)
562+
if np.isclose(evals, 1, atol=atol).any():
563+
factor_index = np.argmax(evals)
564+
factored = evec[:, factor_index]
565+
# Prevent accidental reliance on global phase.
566+
random_phase = np.exp(2j * np.pi * np.random.random())
567+
return random_phase * factored.reshape(ret_shape)
570568

571569
# Method did not yield a pure state. Fall back to `default` argument.
572-
if default is not RaiseValueErrorIfNotProvided:
570+
if default is not None:
573571
return default
574572

575573
raise EntangledStateError(

cirq-core/cirq/linalg/transformations_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,23 +385,31 @@ def test_sub_state_vector() -> None:
385385
cirq.sub_state_vector(reshaped_state, [5, 6, 7, 8], atol=1e-15), c
386386
)
387387

388+
# Make an imperfect product state to probe tolerances.
389+
rng = np.random.default_rng(0)
390+
noise = rng.uniform(-1, 1, size=state.size) + 1j * rng.uniform(-1, 1, state.size)
391+
imperfect_state = state + 1e-2 * noise.reshape((2,) * 9)
392+
imperfect_state /= np.linalg.norm(imperfect_state)
393+
388394
# Reject factoring for very tight tolerance.
389395
assert (
390-
cirq.sub_state_vector(state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY
396+
cirq.sub_state_vector(imperfect_state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16)
397+
is _DEFAULT_ARRAY
391398
)
392399
assert (
393-
cirq.sub_state_vector(state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16)
400+
cirq.sub_state_vector(imperfect_state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16)
394401
is _DEFAULT_ARRAY
395402
)
396403
assert (
397-
cirq.sub_state_vector(state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16)
404+
cirq.sub_state_vector(imperfect_state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16)
398405
is _DEFAULT_ARRAY
399406
)
400407

401408
# Permit invalid factoring for loose tolerance.
402409
for q1 in range(9):
403410
assert (
404-
cirq.sub_state_vector(state, [q1], default=_DEFAULT_ARRAY, atol=1) is not _DEFAULT_ARRAY
411+
cirq.sub_state_vector(imperfect_state, [q1], default=_DEFAULT_ARRAY, atol=1)
412+
is not _DEFAULT_ARRAY
405413
)
406414

407415

0 commit comments

Comments
 (0)