Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit 1f87b04

Browse files
mganahlChase Roberts
andauthored
set seed in tests (#517)
Co-authored-by: Chase Roberts <chaseriley@google.com>
1 parent 55a7fcc commit 1f87b04

1 file changed

Lines changed: 41 additions & 42 deletions

File tree

examples/wavefunctions/wavefunctions_test.py

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@pytest.mark.parametrize("num_sites", [2, 3, 4])
2323
def test_expval(num_sites):
24-
op = np.kron(np.array([[1.0, 0.0], [0.0, -1.0]]), np.eye(2)).reshape([2]*4)
24+
op = np.kron(np.array([[1.0, 0.0], [0.0, -1.0]]), np.eye(2)).reshape([2] * 4)
2525
op = tf.convert_to_tensor(op)
2626
for j in range(num_sites):
2727
psi = np.zeros([2] * num_sites)
@@ -30,7 +30,7 @@ def test_expval(num_sites):
3030
psi = tf.convert_to_tensor(psi)
3131
for i in range(num_sites):
3232
res = wavefunctions.expval(psi, op, i, pbc=True)
33-
if i == num_sites-1-j:
33+
if i == num_sites - 1 - j:
3434
np.testing.assert_allclose(res, -1.0)
3535
else:
3636
np.testing.assert_allclose(res, 1.0)
@@ -50,68 +50,67 @@ def test_apply_op(num_sites):
5050
psi2 = tf.convert_to_tensor(psi2)
5151

5252
opX = tf.convert_to_tensor(np.array([[0.0, 1.0], [1.0, 0.0]]))
53-
psi2 = wavefunctions.apply_op(psi2, opX, num_sites-1-j)
53+
psi2 = wavefunctions.apply_op(psi2, opX, num_sites - 1 - j)
5454

5555
res = wavefunctions.inner(psi1, psi2)
5656
np.testing.assert_allclose(res, 1.0)
5757

5858

59-
@pytest.mark.parametrize(
60-
"num_sites,phys_dim,graph",
61-
[(2, 3, False), (2, 3, True), (5, 2, False)])
59+
@pytest.mark.parametrize("num_sites,phys_dim,graph",
60+
[(2, 3, False), (2, 3, True), (5, 2, False)])
6261
def test_evolve_trotter(num_sites, phys_dim, graph):
63-
psi = tf.complex(
62+
tf.random.set_seed(10)
63+
psi = tf.complex(
6464
tf.random.normal([phys_dim] * num_sites, dtype=tf.float64),
6565
tf.random.normal([phys_dim] * num_sites, dtype=tf.float64))
66-
h = tf.complex(
66+
h = tf.complex(
6767
tf.random.normal((phys_dim**2, phys_dim**2), dtype=tf.float64),
6868
tf.random.normal((phys_dim**2, phys_dim**2), dtype=tf.float64))
69-
h = 0.5 * (h + tf.linalg.adjoint(h))
70-
h = tf.reshape(h, (phys_dim, phys_dim, phys_dim, phys_dim))
71-
H = [h] * (num_sites - 1)
69+
h = 0.5 * (h + tf.linalg.adjoint(h))
70+
h = tf.reshape(h, (phys_dim, phys_dim, phys_dim, phys_dim))
71+
H = [h] * (num_sites - 1)
7272

73-
norm1 = wavefunctions.inner(psi, psi)
74-
en1 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
73+
norm1 = wavefunctions.inner(psi, psi)
74+
en1 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
7575

76-
if graph:
77-
psi, t = wavefunctions.evolve_trotter_defun(psi, H, 0.001, 10)
78-
else:
79-
psi, t = wavefunctions.evolve_trotter(psi, H, 0.001, 10)
76+
if graph:
77+
psi, t = wavefunctions.evolve_trotter_defun(psi, H, 0.001, 10)
78+
else:
79+
psi, t = wavefunctions.evolve_trotter(psi, H, 0.001, 10)
8080

81-
norm2 = wavefunctions.inner(psi, psi)
82-
en2 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
81+
norm2 = wavefunctions.inner(psi, psi)
82+
en2 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
8383

84-
np.testing.assert_allclose(t, 0.01)
85-
np.testing.assert_almost_equal(norm1/norm2, 1.0)
86-
np.testing.assert_almost_equal(en1/en2, 1.0, decimal=2)
84+
np.testing.assert_allclose(t, 0.01)
85+
np.testing.assert_almost_equal(norm1 / norm2, 1.0)
86+
np.testing.assert_almost_equal(en1 / en2, 1.0, decimal=2)
8787

8888

89-
@pytest.mark.parametrize(
90-
"num_sites,phys_dim,graph",
91-
[(2, 3, False), (2, 3, True), (5, 2, False)])
89+
@pytest.mark.parametrize("num_sites,phys_dim,graph",
90+
[(2, 3, False), (2, 3, True), (5, 2, False)])
9291
def test_evolve_trotter_euclidean(num_sites, phys_dim, graph):
93-
psi = tf.complex(
92+
tf.random.set_seed(10)
93+
psi = tf.complex(
9494
tf.random.normal([phys_dim] * num_sites, dtype=tf.float64),
9595
tf.random.normal([phys_dim] * num_sites, dtype=tf.float64))
96-
h = tf.complex(
96+
h = tf.complex(
9797
tf.random.normal((phys_dim**2, phys_dim**2), dtype=tf.float64),
9898
tf.random.normal((phys_dim**2, phys_dim**2), dtype=tf.float64))
99-
h = 0.5 * (h + tf.linalg.adjoint(h))
100-
h = tf.reshape(h, (phys_dim, phys_dim, phys_dim, phys_dim))
101-
H = [h] * (num_sites - 1)
99+
h = 0.5 * (h + tf.linalg.adjoint(h))
100+
h = tf.reshape(h, (phys_dim, phys_dim, phys_dim, phys_dim))
101+
H = [h] * (num_sites - 1)
102102

103-
norm1 = wavefunctions.inner(psi, psi)
104-
en1 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
103+
norm1 = wavefunctions.inner(psi, psi)
104+
en1 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
105105

106-
if graph:
107-
psi, t = wavefunctions.evolve_trotter_defun(
108-
psi, H, 0.1, 10, euclidean=True)
109-
else:
110-
psi, t = wavefunctions.evolve_trotter(psi, H, 0.1, 10, euclidean=True)
106+
if graph:
107+
psi, t = wavefunctions.evolve_trotter_defun(psi, H, 0.1, 10, euclidean=True)
108+
else:
109+
psi, t = wavefunctions.evolve_trotter(psi, H, 0.1, 10, euclidean=True)
111110

112-
norm2 = wavefunctions.inner(psi, psi)
113-
en2 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
111+
norm2 = wavefunctions.inner(psi, psi)
112+
en2 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))
114113

115-
np.testing.assert_allclose(t, 1.0)
116-
np.testing.assert_almost_equal(norm2, 1.0)
117-
assert en2.numpy()/norm2.numpy() < en1.numpy()/norm1.numpy()
114+
np.testing.assert_allclose(t, 1.0)
115+
np.testing.assert_almost_equal(norm2, 1.0)
116+
assert en2.numpy() / norm2.numpy() < en1.numpy() / norm1.numpy()

0 commit comments

Comments
 (0)