Skip to content

Commit 7228447

Browse files
authored
Clean up recent changes in sub_state_vector (#8081)
- Verify the maximum eigenvalue is indeed close to 1. - Enable arbitrary order of `keep_indices` as before. - Use the same tolerance check as before #8077 with rtol=0. This allows to restore previous tests instead of having to do them with larger state vector deviation. - Put back RaiseValueErrorIfNotProvided as a default for `sub_state_vector`. This allows using `None` for the fallback return value. - Add check for negative values in `keep_indices`. Follow-up to #8077
1 parent 02e927f commit 7228447

2 files changed

Lines changed: 31 additions & 23 deletions

File tree

cirq-core/cirq/linalg/transformations.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,21 @@
2020
import functools
2121
from collections.abc import Sequence
2222
from types import EllipsisType
23-
from typing import Any
23+
from typing import Any, TypeVar
2424

2525
import numpy as np
2626

2727
from cirq import protocols
2828

29+
# This is a special indicator value used by the `sub_state_vector` method to
30+
# determine whether or not the caller provided a 'default' argument. It must be
31+
# of type np.ndarray to ensure the method has the correct type signature in that
32+
# case. It is checked for using `is`, so it won't have a false positive if the
33+
# user provides a different np.array([]) value.
34+
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])
35+
36+
TDefault = TypeVar('TDefault')
37+
2938

3039
def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float) -> np.ndarray:
3140
"""Raises a matrix with two opposing eigenvalues to a power.
@@ -471,9 +480,9 @@ def sub_state_vector(
471480
state_vector: np.ndarray,
472481
keep_indices: list[int],
473482
*,
474-
default: np.ndarray | None = None,
483+
default: np.ndarray | TDefault = RaiseValueErrorIfNotProvided,
475484
atol: float = 1e-6,
476-
) -> np.ndarray:
485+
) -> np.ndarray | TDefault:
477486
r"""Attempts to factor a state vector into two parts and return one of them.
478487
479488
The input `state_vector` must have shape ``(2,) * n`` or ``(2 ** n)`` where
@@ -542,32 +551,30 @@ def sub_state_vector(
542551
raise ValueError("Input state must be normalized.")
543552
if len(set(keep_indices)) != len(keep_indices):
544553
raise ValueError(f"keep_indices were {keep_indices} but must be unique.")
545-
if any(ind >= n_qubits for ind in keep_indices):
554+
if any(ind < 0 or ind >= n_qubits for ind in keep_indices):
546555
raise ValueError("keep_indices {} are an invalid subset of the input state vector.")
547556

548557
# The permutation moves the specified qubits to the start of the qubit order.
549558
keeps = frozenset(keep_indices)
550-
remainder = np.array([i for i in range(n_qubits) if i not in keeps], dtype=np.int64)
551-
permutation = np.concatenate([keep_indices, remainder])
559+
permutation = [*sorted(keep_indices), *(i for i in range(n_qubits) if not i in keeps)]
552560

553561
# Permute qubits and construct the pure-state density matrix.
554562
raveled = state_vector.reshape([2] * n_qubits)
555563
raveled = np.transpose(raveled, permutation)
556-
num_qubits_out = len(keep_indices)
557-
c_psi = raveled.reshape([2**num_qubits_out, -1])
564+
c_psi = raveled.reshape(keep_dims, -1)
558565
rho = c_psi @ c_psi.conj().T
559566

560567
# Return the eigenvector with eigenvalue 1.
561568
evals, evec = np.linalg.eigh(rho)
562-
if np.isclose(evals, 1, atol=atol).any():
563-
factor_index = np.argmax(evals)
569+
factor_index = np.argmax(evals)
570+
if np.isclose(evals[factor_index], 1, atol=atol, rtol=0):
564571
factored = evec[:, factor_index]
565572
# Prevent accidental reliance on global phase.
566573
random_phase = np.exp(2j * np.pi * np.random.random())
567574
return random_phase * factored.reshape(ret_shape)
568575

569576
# Method did not yield a pure state. Fall back to `default` argument.
570-
if default is not None:
577+
if default is not RaiseValueErrorIfNotProvided:
571578
return default
572579

573580
raise EntangledStateError(

cirq-core/cirq/linalg/transformations_test.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,32 +384,31 @@ def test_sub_state_vector() -> None:
384384
assert cirq.equal_up_to_global_phase(
385385
cirq.sub_state_vector(reshaped_state, [5, 6, 7, 8], atol=1e-15), c
386386
)
387-
388-
# Make an imperfect product state to probe tolerances.
389-
rng = np.random.default_rng(0)
390-
noise = rng.uniform(-1, 1, size=state.size) + 1j * rng.uniform(-1, 1, state.size)
391-
imperfect_state = state + 1e-2 * noise.reshape((2,) * 9)
392-
imperfect_state /= np.linalg.norm(imperfect_state)
387+
# Output state vector is independent of the order of keep_indices
388+
assert cirq.equal_up_to_global_phase(
389+
cirq.sub_state_vector(reshaped_state, [8, 5, 7, 6], atol=1e-15), c
390+
)
393391

394392
# Reject factoring for very tight tolerance.
395393
assert (
396-
cirq.sub_state_vector(imperfect_state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16)
397-
is _DEFAULT_ARRAY
394+
cirq.sub_state_vector(state, [0, 1], default=_DEFAULT_ARRAY, atol=1e-16) is _DEFAULT_ARRAY
398395
)
399396
assert (
400-
cirq.sub_state_vector(imperfect_state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16)
397+
cirq.sub_state_vector(state, [2, 3, 4], default=_DEFAULT_ARRAY, atol=1e-16)
401398
is _DEFAULT_ARRAY
402399
)
403400
assert (
404-
cirq.sub_state_vector(imperfect_state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16)
401+
cirq.sub_state_vector(state, [5, 6, 7, 8], default=_DEFAULT_ARRAY, atol=1e-16)
405402
is _DEFAULT_ARRAY
406403
)
407404

405+
# Ensure None can be passed as the `default` argument
406+
assert cirq.sub_state_vector(state, [0, 1], default=None, atol=1e-16) is None
407+
408408
# Permit invalid factoring for loose tolerance.
409409
for q1 in range(9):
410410
assert (
411-
cirq.sub_state_vector(imperfect_state, [q1], default=_DEFAULT_ARRAY, atol=1)
412-
is not _DEFAULT_ARRAY
411+
cirq.sub_state_vector(state, [q1], default=_DEFAULT_ARRAY, atol=1) is not _DEFAULT_ARRAY
413412
)
414413

415414

@@ -482,6 +481,8 @@ def test_sub_state_vector_invalid_inputs() -> None:
482481
cirq.sub_state_vector(state, [1, 2, 2], atol=1e-8)
483482

484483
state = np.array([1, 0, 0, 0]).reshape((2, 2))
484+
with pytest.raises(ValueError, match='invalid'):
485+
cirq.sub_state_vector(state, [-1], atol=1e-8)
485486
with pytest.raises(ValueError, match='invalid'):
486487
cirq.sub_state_vector(state, [5], atol=1e-8)
487488
with pytest.raises(ValueError, match='invalid'):

0 commit comments

Comments
 (0)