diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 1acc9b4c08f..a0906cbe24b 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -25,14 +25,6 @@ import numpy as np from cirq import protocols -from cirq.linalg import predicates - -# This is a special indicator value used by the `sub_state_vector` method to -# determine whether or not the caller provided a 'default' argument. It must be -# of type np.ndarray to ensure the method has the correct type signature in that -# case. It is checked for using `is`, so it won't have a false positive if the -# user provides a different np.array([]) value. -RaiseValueErrorIfNotProvided: np.ndarray = np.array([]) def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float) -> np.ndarray: @@ -479,7 +471,7 @@ def sub_state_vector( state_vector: np.ndarray, keep_indices: list[int], *, - default: np.ndarray = RaiseValueErrorIfNotProvided, + default: np.ndarray | None = None, atol: float = 1e-6, ) -> np.ndarray: r"""Attempts to factor a state vector into two parts and return one of them. @@ -503,7 +495,7 @@ def sub_state_vector( If the provided `state_vector` cannot be factored into a pure state over `keep_indices`, the method will fall back to return `default`. If `default` - is not provided, the method will fail and raise `ValueError`. + is not provided, the method will fail and raise EntangledStateError. Args: state_vector: The target state_vector. @@ -511,8 +503,8 @@ def sub_state_vector( `state_vector` on. default: Determines the fallback behavior when `state_vector` doesn't have a pure state factorization. If the factored state is not pure - and `default` is not set, a ValueError is raised. If default is set - to a value, that value is returned. + and `default` is not set, an EntangledStateError is raised. If + default is set to a value, that value is returned. atol: The minimum tolerance for comparing the output state's coherence measure to 1. @@ -540,13 +532,12 @@ def sub_state_vector( ret_shape: tuple[int] | tuple[int, ...] if state_vector.shape == (state_vector.size,): ret_shape = (keep_dims,) - state_vector = state_vector.reshape((2,) * n_qubits) elif state_vector.shape == (2,) * n_qubits: + state_vector = state_vector.reshape(-1) ret_shape = tuple(2 for _ in range(len(keep_indices))) else: raise ValueError("Input state_vector must be shaped like (2 ** n,) or (2,) * n") - keep_dims = 1 << len(keep_indices) if not np.isclose(np.linalg.norm(state_vector), 1): raise ValueError("Input state must be normalized.") if len(set(keep_indices)) != len(keep_indices): @@ -554,22 +545,29 @@ def sub_state_vector( if any(ind >= n_qubits for ind in keep_indices): raise ValueError("keep_indices {} are an invalid subset of the input state vector.") - other_qubits = sorted(set(range(n_qubits)) - set(keep_indices)) - candidates = [ - state_vector[predicates.slice_for_qubits_equal_to(other_qubits, k)].reshape(keep_dims) - for k in range(1 << len(other_qubits)) - ] - # The coherence measure is computed using unnormalized candidates. - best_candidate = max(candidates, key=lambda c: float(np.linalg.norm(c, 2))) - best_candidate = best_candidate / np.linalg.norm(best_candidate) - left = np.conj(best_candidate.reshape((keep_dims,))).T - coherence_measure = sum(abs(np.dot(left, c.reshape((keep_dims,)))) ** 2 for c in candidates) - - if protocols.approx_eq(coherence_measure, 1, atol=atol): - return np.exp(2j * np.pi * np.random.random()) * best_candidate.reshape(ret_shape) + # The permutation moves the specified qubits to the start of the qubit order. + keeps = frozenset(keep_indices) + remainder = np.array([i for i in range(n_qubits) if i not in keeps], dtype=np.int64) + permutation = np.concatenate([keep_indices, remainder]) + + # Permute qubits and construct the pure-state density matrix. + raveled = state_vector.reshape([2] * n_qubits) + raveled = np.transpose(raveled, permutation) + num_qubits_out = len(keep_indices) + c_psi = raveled.reshape([2**num_qubits_out, -1]) + rho = c_psi @ c_psi.conj().T + + # Return the eigenvector with eigenvalue 1. + evals, evec = np.linalg.eigh(rho) + if np.isclose(evals, 1, atol=atol).any(): + factor_index = np.argmax(evals) + factored = evec[:, factor_index] + # Prevent accidental reliance on global phase. + random_phase = np.exp(2j * np.pi * np.random.random()) + return random_phase * factored.reshape(ret_shape) # Method did not yield a pure state. Fall back to `default` argument. - if default is not RaiseValueErrorIfNotProvided: + if default is not None: return default raise EntangledStateError( diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index c7f82a7ce2e..d23e25059c0 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -385,23 +385,31 @@ def test_sub_state_vector() -> None: cirq.sub_state_vector(reshaped_state, [5, 6, 7, 8], atol=1e-15), c ) + # Make an imperfect product state to probe tolerances. + rng = np.random.default_rng(0) + noise = rng.uniform(-1, 1, size=state.size) + 1j * rng.uniform(-1, 1, state.size) + imperfect_state = state + 1e-2 * noise.reshape((2,) * 9) + imperfect_state /= np.linalg.norm(imperfect_state) + # Reject factoring for very tight tolerance. assert ( - cirq.sub_state_vector(state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY + cirq.sub_state_vector(imperfect_state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) + is _DEFAULT_ARRAY ) assert ( - cirq.sub_state_vector(state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16) + cirq.sub_state_vector(imperfect_state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY ) assert ( - cirq.sub_state_vector(state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16) + cirq.sub_state_vector(imperfect_state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY ) # Permit invalid factoring for loose tolerance. for q1 in range(9): assert ( - cirq.sub_state_vector(state, [q1], default=_DEFAULT_ARRAY, atol=1) is not _DEFAULT_ARRAY + cirq.sub_state_vector(imperfect_state, [q1], default=_DEFAULT_ARRAY, atol=1) + is not _DEFAULT_ARRAY )