From 5b644c4fc6518a3c2ab6a5c5e0770da845ffe0c2 Mon Sep 17 00:00:00 2001 From: Anton Kast Date: Thu, 14 May 2026 14:33:43 -0700 Subject: [PATCH 1/4] Speed up cirq.sub_state_vector(). The current implementation appears to be exponential in space and time. The present change, due originally to jamieas@google.com, replaces python loops with numpy functions motivated by standard methods for computing partial traces. The change is runnable in Colab, including tests and a performance benchmark, at https://colab.sandbox.google.com/drive/1qWJTnlT6hPXI8cQsTwyqQveEtn7fGhNc?resourcekey=0-67OmbvgwTlf9fnKp3ppbrw --- cirq-core/cirq/linalg/transformations.py | 60 ++++++++++--------- cirq-core/cirq/linalg/transformations_test.py | 14 +++-- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 1acc9b4c08f..8d4415e0ae7 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -27,13 +27,6 @@ from cirq import protocols from cirq.linalg import predicates -# This is a special indicator value used by the `sub_state_vector` method to -# determine whether or not the caller provided a 'default' argument. It must be -# of type np.ndarray to ensure the method has the correct type signature in that -# case. It is checked for using `is`, so it won't have a false positive if the -# user provides a different np.array([]) value. -RaiseValueErrorIfNotProvided: np.ndarray = np.array([]) - def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float) -> np.ndarray: """Raises a matrix with two opposing eigenvalues to a power. @@ -477,9 +470,9 @@ def partial_trace_of_state_vector_as_mixture( def sub_state_vector( state_vector: np.ndarray, - keep_indices: list[int], + keep_indices: np.ndarray, *, - default: np.ndarray = RaiseValueErrorIfNotProvided, + default: np.ndarray | None = None, atol: float = 1e-6, ) -> np.ndarray: r"""Attempts to factor a state vector into two parts and return one of them. @@ -511,8 +504,8 @@ def sub_state_vector( `state_vector` on. default: Determines the fallback behavior when `state_vector` doesn't have a pure state factorization. If the factored state is not pure - and `default` is not set, a ValueError is raised. If default is set - to a value, that value is returned. + and `default` is not set, an EntangledStateError is raised. If + default is set to a value, that value is returned. atol: The minimum tolerance for comparing the output state's coherence measure to 1. @@ -528,7 +521,6 @@ def sub_state_vector( `default` is not provided. """ - if not np.log2(state_vector.size).is_integer(): raise ValueError( f"Input state_vector of size {state_vector.size} does not represent a " @@ -540,36 +532,46 @@ def sub_state_vector( ret_shape: tuple[int] | tuple[int, ...] if state_vector.shape == (state_vector.size,): ret_shape = (keep_dims,) - state_vector = state_vector.reshape((2,) * n_qubits) elif state_vector.shape == (2,) * n_qubits: + state_vector = state_vector.reshape(-1) ret_shape = tuple(2 for _ in range(len(keep_indices))) else: raise ValueError("Input state_vector must be shaped like (2 ** n,) or (2,) * n") - keep_dims = 1 << len(keep_indices) if not np.isclose(np.linalg.norm(state_vector), 1): raise ValueError("Input state must be normalized.") if len(set(keep_indices)) != len(keep_indices): raise ValueError(f"keep_indices were {keep_indices} but must be unique.") - if any(ind >= n_qubits for ind in keep_indices): + if any([ind >= n_qubits for ind in keep_indices]): raise ValueError("keep_indices {} are an invalid subset of the input state vector.") - other_qubits = sorted(set(range(n_qubits)) - set(keep_indices)) - candidates = [ - state_vector[predicates.slice_for_qubits_equal_to(other_qubits, k)].reshape(keep_dims) - for k in range(1 << len(other_qubits)) - ] - # The coherence measure is computed using unnormalized candidates. - best_candidate = max(candidates, key=lambda c: float(np.linalg.norm(c, 2))) - best_candidate = best_candidate / np.linalg.norm(best_candidate) - left = np.conj(best_candidate.reshape((keep_dims,))).T - coherence_measure = sum(abs(np.dot(left, c.reshape((keep_dims,)))) ** 2 for c in candidates) - - if protocols.approx_eq(coherence_measure, 1, atol=atol): - return np.exp(2j * np.pi * np.random.random()) * best_candidate.reshape(ret_shape) + # The number of output qubits. + # keep_indices = np.array(keep_indices) + num_qubits_out = len(keep_indices)#.shape[0] + + # The permutation moves the specified qubits to the start of the qubit order. + keeps = set(keep_indices) + remainder = np.array( + [i for i in range(n_qubits) if i not in keeps], dtype=np.int64) + permutation = np.concatenate([keep_indices, remainder]) + + # Permute qubits and construct the pure-state density matrix. + raveled = state_vector.reshape([2] * n_qubits) + raveled = np.transpose(raveled, permutation) + c_psi = raveled.reshape([2 ** num_qubits_out, -1]) + rho = c_psi @ c_psi.conj().T + + # Return the eigenvector with eigenvalue 1. + evals, evec = np.linalg.eigh(rho) + if np.isclose(evals, 1, atol=atol).any(): + factor_index = np.argmax(evals) + factored = evec[:, factor_index] + # Prevent accidental reliance on global phase. + random_phase = np.exp(2j * np.pi * np.random.random()) + return random_phase * factored.reshape(ret_shape) # Method did not yield a pure state. Fall back to `default` argument. - if default is not RaiseValueErrorIfNotProvided: + if default is not None: return default raise EntangledStateError( diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index c7f82a7ce2e..90ab75a5f76 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -385,23 +385,29 @@ def test_sub_state_vector() -> None: cirq.sub_state_vector(reshaped_state, [5, 6, 7, 8], atol=1e-15), c ) + # Make an imperfect product state to probe tolerances. + rng = np.random.default_rng(0) + noise = rng.uniform(-1, 1, size=state.size) + 1j * rng.uniform(-1, 1, state.size) + imperfect_state = state + 1e-2 * noise.reshape((2,) * 9) + imperfect_state /= np.linalg.norm(imperfect_state) + # Reject factoring for very tight tolerance. assert ( - cirq.sub_state_vector(state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY + cirq.sub_state_vector(imperfect_state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY ) assert ( - cirq.sub_state_vector(state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16) + cirq.sub_state_vector(imperfect_state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY ) assert ( - cirq.sub_state_vector(state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16) + cirq.sub_state_vector(imperfect_state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY ) # Permit invalid factoring for loose tolerance. for q1 in range(9): assert ( - cirq.sub_state_vector(state, [q1], default=_DEFAULT_ARRAY, atol=1) is not _DEFAULT_ARRAY + cirq.sub_state_vector(imperfect_state, [q1], default=_DEFAULT_ARRAY, atol=1) is not _DEFAULT_ARRAY ) From c13efc6ffa5ba6cb1cefa40625c54ebcb195cb4d Mon Sep 17 00:00:00 2001 From: Anton Kast Date: Thu, 14 May 2026 21:11:26 -0700 Subject: [PATCH 2/4] Tidy the diff. --- cirq-core/cirq/linalg/transformations.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 8d4415e0ae7..2d81fe45343 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -470,7 +470,7 @@ def partial_trace_of_state_vector_as_mixture( def sub_state_vector( state_vector: np.ndarray, - keep_indices: np.ndarray, + keep_indices: list[int], *, default: np.ndarray | None = None, atol: float = 1e-6, @@ -496,7 +496,7 @@ def sub_state_vector( If the provided `state_vector` cannot be factored into a pure state over `keep_indices`, the method will fall back to return `default`. If `default` - is not provided, the method will fail and raise `ValueError`. + is not provided, the method will fail and raise EntangledStateError. Args: state_vector: The target state_vector. @@ -521,6 +521,7 @@ def sub_state_vector( `default` is not provided. """ + if not np.log2(state_vector.size).is_integer(): raise ValueError( f"Input state_vector of size {state_vector.size} does not represent a " @@ -542,7 +543,7 @@ def sub_state_vector( raise ValueError("Input state must be normalized.") if len(set(keep_indices)) != len(keep_indices): raise ValueError(f"keep_indices were {keep_indices} but must be unique.") - if any([ind >= n_qubits for ind in keep_indices]): + if any(ind >= n_qubits for ind in keep_indices): raise ValueError("keep_indices {} are an invalid subset of the input state vector.") # The number of output qubits. @@ -564,11 +565,11 @@ def sub_state_vector( # Return the eigenvector with eigenvalue 1. evals, evec = np.linalg.eigh(rho) if np.isclose(evals, 1, atol=atol).any(): - factor_index = np.argmax(evals) - factored = evec[:, factor_index] - # Prevent accidental reliance on global phase. - random_phase = np.exp(2j * np.pi * np.random.random()) - return random_phase * factored.reshape(ret_shape) + factor_index = np.argmax(evals) + factored = evec[:, factor_index] + # Prevent accidental reliance on global phase. + random_phase = np.exp(2j * np.pi * np.random.random()) + return random_phase * factored.reshape(ret_shape) # Method did not yield a pure state. Fall back to `default` argument. if default is not None: From 5968996e52bc7f73ca52c8ab0cfe3502d535fc3f Mon Sep 17 00:00:00 2001 From: Anton Kast Date: Fri, 15 May 2026 10:47:45 -0700 Subject: [PATCH 3/4] Satisfy checks/all --changed. --- cirq-core/cirq/linalg/transformations.py | 11 +++-------- cirq-core/cirq/linalg/transformations_test.py | 6 ++++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 2d81fe45343..9af509e6c6a 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -25,7 +25,6 @@ import numpy as np from cirq import protocols -from cirq.linalg import predicates def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float) -> np.ndarray: @@ -546,20 +545,16 @@ def sub_state_vector( if any(ind >= n_qubits for ind in keep_indices): raise ValueError("keep_indices {} are an invalid subset of the input state vector.") - # The number of output qubits. - # keep_indices = np.array(keep_indices) - num_qubits_out = len(keep_indices)#.shape[0] - # The permutation moves the specified qubits to the start of the qubit order. keeps = set(keep_indices) - remainder = np.array( - [i for i in range(n_qubits) if i not in keeps], dtype=np.int64) + remainder = np.array([i for i in range(n_qubits) if i not in keeps], dtype=np.int64) permutation = np.concatenate([keep_indices, remainder]) # Permute qubits and construct the pure-state density matrix. raveled = state_vector.reshape([2] * n_qubits) raveled = np.transpose(raveled, permutation) - c_psi = raveled.reshape([2 ** num_qubits_out, -1]) + num_qubits_out = len(keep_indices) + c_psi = raveled.reshape([2**num_qubits_out, -1]) rho = c_psi @ c_psi.conj().T # Return the eigenvector with eigenvalue 1. diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index 90ab75a5f76..d23e25059c0 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -393,7 +393,8 @@ def test_sub_state_vector() -> None: # Reject factoring for very tight tolerance. assert ( - cirq.sub_state_vector(imperfect_state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY + cirq.sub_state_vector(imperfect_state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) + is _DEFAULT_ARRAY ) assert ( cirq.sub_state_vector(imperfect_state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16) @@ -407,7 +408,8 @@ def test_sub_state_vector() -> None: # Permit invalid factoring for loose tolerance. for q1 in range(9): assert ( - cirq.sub_state_vector(imperfect_state, [q1], default=_DEFAULT_ARRAY, atol=1) is not _DEFAULT_ARRAY + cirq.sub_state_vector(imperfect_state, [q1], default=_DEFAULT_ARRAY, atol=1) + is not _DEFAULT_ARRAY ) From cdb32169e699afc39fbc4886245a764c2b2019fa Mon Sep 17 00:00:00 2001 From: Anton Kast Date: Fri, 15 May 2026 10:51:06 -0700 Subject: [PATCH 4/4] nits --- cirq-core/cirq/linalg/transformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 9af509e6c6a..a0906cbe24b 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -546,7 +546,7 @@ def sub_state_vector( raise ValueError("keep_indices {} are an invalid subset of the input state vector.") # The permutation moves the specified qubits to the start of the qubit order. - keeps = set(keep_indices) + keeps = frozenset(keep_indices) remainder = np.array([i for i in range(n_qubits) if i not in keeps], dtype=np.int64) permutation = np.concatenate([keep_indices, remainder])