2525import numpy as np
2626
2727from cirq import protocols
28- from cirq .linalg import predicates
29-
30- # This is a special indicator value used by the `sub_state_vector` method to
31- # determine whether or not the caller provided a 'default' argument. It must be
32- # of type np.ndarray to ensure the method has the correct type signature in that
33- # case. It is checked for using `is`, so it won't have a false positive if the
34- # user provides a different np.array([]) value.
35- RaiseValueErrorIfNotProvided : np .ndarray = np .array ([])
3628
3729
3830def reflection_matrix_pow (reflection_matrix : np .ndarray , exponent : float ) -> np .ndarray :
@@ -479,7 +471,7 @@ def sub_state_vector(
479471 state_vector : np .ndarray ,
480472 keep_indices : list [int ],
481473 * ,
482- default : np .ndarray = RaiseValueErrorIfNotProvided ,
474+ default : np .ndarray | None = None ,
483475 atol : float = 1e-6 ,
484476) -> np .ndarray :
485477 r"""Attempts to factor a state vector into two parts and return one of them.
@@ -503,16 +495,16 @@ def sub_state_vector(
503495
504496 If the provided `state_vector` cannot be factored into a pure state over
505497 `keep_indices`, the method will fall back to return `default`. If `default`
506- is not provided, the method will fail and raise `ValueError` .
498+ is not provided, the method will fail and raise EntangledStateError .
507499
508500 Args:
509501 state_vector: The target state_vector.
510502 keep_indices: Which indices to attempt to get the separable part of the
511503 `state_vector` on.
512504 default: Determines the fallback behavior when `state_vector` doesn't
513505 have a pure state factorization. If the factored state is not pure
514- and `default` is not set, a ValueError is raised. If default is set
515- to a value, that value is returned.
506+ and `default` is not set, an EntangledStateError is raised. If
507+ default is set to a value, that value is returned.
516508 atol: The minimum tolerance for comparing the output state's coherence
517509 measure to 1.
518510
@@ -540,36 +532,42 @@ def sub_state_vector(
540532 ret_shape : tuple [int ] | tuple [int , ...]
541533 if state_vector .shape == (state_vector .size ,):
542534 ret_shape = (keep_dims ,)
543- state_vector = state_vector .reshape ((2 ,) * n_qubits )
544535 elif state_vector .shape == (2 ,) * n_qubits :
536+ state_vector = state_vector .reshape (- 1 )
545537 ret_shape = tuple (2 for _ in range (len (keep_indices )))
546538 else :
547539 raise ValueError ("Input state_vector must be shaped like (2 ** n,) or (2,) * n" )
548540
549- keep_dims = 1 << len (keep_indices )
550541 if not np .isclose (np .linalg .norm (state_vector ), 1 ):
551542 raise ValueError ("Input state must be normalized." )
552543 if len (set (keep_indices )) != len (keep_indices ):
553544 raise ValueError (f"keep_indices were { keep_indices } but must be unique." )
554545 if any (ind >= n_qubits for ind in keep_indices ):
555546 raise ValueError ("keep_indices {} are an invalid subset of the input state vector." )
556547
557- other_qubits = sorted (set (range (n_qubits )) - set (keep_indices ))
558- candidates = [
559- state_vector [predicates .slice_for_qubits_equal_to (other_qubits , k )].reshape (keep_dims )
560- for k in range (1 << len (other_qubits ))
561- ]
562- # The coherence measure is computed using unnormalized candidates.
563- best_candidate = max (candidates , key = lambda c : float (np .linalg .norm (c , 2 )))
564- best_candidate = best_candidate / np .linalg .norm (best_candidate )
565- left = np .conj (best_candidate .reshape ((keep_dims ,))).T
566- coherence_measure = sum (abs (np .dot (left , c .reshape ((keep_dims ,)))) ** 2 for c in candidates )
567-
568- if protocols .approx_eq (coherence_measure , 1 , atol = atol ):
569- return np .exp (2j * np .pi * np .random .random ()) * best_candidate .reshape (ret_shape )
548+ # The permutation moves the specified qubits to the start of the qubit order.
549+ keeps = frozenset (keep_indices )
550+ remainder = np .array ([i for i in range (n_qubits ) if i not in keeps ], dtype = np .int64 )
551+ permutation = np .concatenate ([keep_indices , remainder ])
552+
553+ # Permute qubits and construct the pure-state density matrix.
554+ raveled = state_vector .reshape ([2 ] * n_qubits )
555+ raveled = np .transpose (raveled , permutation )
556+ num_qubits_out = len (keep_indices )
557+ c_psi = raveled .reshape ([2 ** num_qubits_out , - 1 ])
558+ rho = c_psi @ c_psi .conj ().T
559+
560+ # Return the eigenvector with eigenvalue 1.
561+ evals , evec = np .linalg .eigh (rho )
562+ if np .isclose (evals , 1 , atol = atol ).any ():
563+ factor_index = np .argmax (evals )
564+ factored = evec [:, factor_index ]
565+ # Prevent accidental reliance on global phase.
566+ random_phase = np .exp (2j * np .pi * np .random .random ())
567+ return random_phase * factored .reshape (ret_shape )
570568
571569 # Method did not yield a pure state. Fall back to `default` argument.
572- if default is not RaiseValueErrorIfNotProvided :
570+ if default is not None :
573571 return default
574572
575573 raise EntangledStateError (
0 commit comments