Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit ce21161

Browse files
MichaelMarienChase Roberts
andcommitted
Backend test (#448)
* added test for mps switch backend * added switch backend method to MPS * added test for network operations switch backend * make sure switch_backend not only fixes tensor but also node property * added switch_backend to init * missing test for backend contextmanager * notimplemented tests for base backend * added subtraction test notimplemented * added jax backend index_update test * first missing tests for numpy * actually catched an error in numpy_backend eigs method! * more eigs tests * didnt catch an error, unexpected convention * more tests for eigsh_lancszos * added missing pytorch backend tests * added missing tf backend tests * pytype * suppress pytype Co-authored-by: Chase Roberts <chaseriley@google.com>
1 parent 7087127 commit ce21161

6 files changed

Lines changed: 442 additions & 3 deletions

File tree

tensornetwork/backends/backend_test.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import numpy as np
66
from tensornetwork import connect, contract, Node
7+
from tensornetwork.backends.base_backend import BaseBackend
78

89

910
def clean_tensornetwork_modules():
@@ -146,3 +147,200 @@ def test_basic_network_without_backends_raises_error():
146147
Node(np.ones((2, 2)), backend="tensorflow")
147148
with pytest.raises(ImportError):
148149
Node(np.ones((2, 2)), backend="pytorch")
150+
[]
151+
152+
def test_base_backend_name():
153+
backend = BaseBackend()
154+
assert backend.name == "base backend"
155+
156+
157+
def test_base_backend_tensordot_not_implemented():
158+
backend = BaseBackend()
159+
with pytest.raises(NotImplementedError):
160+
backend.tensordot(np.ones((2, 2)), np.ones((2, 2)), axes=[[0], [0]])
161+
162+
163+
def test_base_backend_reshape_not_implemented():
164+
backend = BaseBackend()
165+
with pytest.raises(NotImplementedError):
166+
backend.reshape(np.ones((2, 2)), (4, 1))
167+
168+
169+
def test_base_backend_transpose_not_implemented():
170+
backend = BaseBackend()
171+
with pytest.raises(NotImplementedError):
172+
backend.transpose(np.ones((2, 2)), [0, 1])
173+
174+
175+
def test_base_backend_svd_decompositon_not_implemented():
176+
backend = BaseBackend()
177+
with pytest.raises(NotImplementedError):
178+
backend.svd_decomposition(np.ones((2, 2)), 0)
179+
180+
181+
def test_base_backend_qr_decompositon_not_implemented():
182+
backend = BaseBackend()
183+
with pytest.raises(NotImplementedError):
184+
backend.qr_decomposition(np.ones((2, 2)), 0)
185+
186+
187+
def test_base_backend_rq_decompositon_not_implemented():
188+
backend = BaseBackend()
189+
with pytest.raises(NotImplementedError):
190+
backend.rq_decomposition(np.ones((2, 2)), 0)
191+
192+
193+
def test_base_backend_shape_concat_not_implemented():
194+
backend = BaseBackend()
195+
with pytest.raises(NotImplementedError):
196+
backend.shape_concat([np.ones((2, 2)), np.ones((2, 2))], 0)
197+
198+
199+
def test_base_backend_shape_tensor_not_implemented():
200+
backend = BaseBackend()
201+
with pytest.raises(NotImplementedError):
202+
backend.shape_tensor(np.ones((2, 2)))
203+
204+
205+
def test_base_backend_shape_tuple_not_implemented():
206+
backend = BaseBackend()
207+
with pytest.raises(NotImplementedError):
208+
backend.shape_tuple(np.ones((2, 2)))
209+
210+
211+
def test_base_backend_shape_prod_not_implemented():
212+
backend = BaseBackend()
213+
with pytest.raises(NotImplementedError):
214+
backend.shape_prod(np.ones((2, 2)))
215+
216+
217+
def test_base_backend_sqrt_not_implemented():
218+
backend = BaseBackend()
219+
with pytest.raises(NotImplementedError):
220+
backend.sqrt(np.ones((2, 2)))
221+
222+
223+
def test_base_backend_diag_not_implemented():
224+
backend = BaseBackend()
225+
with pytest.raises(NotImplementedError):
226+
backend.diag(np.ones((2, 2)))
227+
228+
229+
def test_base_backend_convert_to_tensor_not_implemented():
230+
backend = BaseBackend()
231+
with pytest.raises(NotImplementedError):
232+
backend.convert_to_tensor(np.ones((2, 2)))
233+
234+
235+
def test_base_backend_trace_not_implemented():
236+
backend = BaseBackend()
237+
with pytest.raises(NotImplementedError):
238+
backend.trace(np.ones((2, 2)))
239+
240+
241+
def test_base_backend_outer_product_not_implemented():
242+
backend = BaseBackend()
243+
with pytest.raises(NotImplementedError):
244+
backend.outer_product(np.ones((2, 2)), np.ones((2, 2)))
245+
246+
247+
def test_base_backend_einsul_not_implemented():
248+
backend = BaseBackend()
249+
with pytest.raises(NotImplementedError):
250+
backend.einsum("ii", np.ones((2, 2)))
251+
252+
253+
def test_base_backend_norm_not_implemented():
254+
backend = BaseBackend()
255+
with pytest.raises(NotImplementedError):
256+
backend.norm(np.ones((2, 2)))
257+
258+
259+
def test_base_backend_eye_not_implemented():
260+
backend = BaseBackend()
261+
with pytest.raises(NotImplementedError):
262+
backend.eye(2, dtype=np.float64)
263+
264+
265+
def test_base_backend_ones_not_implemented():
266+
backend = BaseBackend()
267+
with pytest.raises(NotImplementedError):
268+
backend.ones((2, 2), dtype=np.float64)
269+
270+
271+
def test_base_backend_zeros_not_implemented():
272+
backend = BaseBackend()
273+
with pytest.raises(NotImplementedError):
274+
backend.zeros((2, 2), dtype=np.float64)
275+
276+
277+
def test_base_backend_randn_not_implemented():
278+
backend = BaseBackend()
279+
with pytest.raises(NotImplementedError):
280+
backend.randn((2, 2))
281+
282+
283+
def test_base_backend_random_uniforl_not_implemented():
284+
backend = BaseBackend()
285+
with pytest.raises(NotImplementedError):
286+
backend.random_uniform((2, 2))
287+
288+
289+
def test_base_backend_conj_not_implemented():
290+
backend = BaseBackend()
291+
with pytest.raises(NotImplementedError):
292+
backend.conj(np.ones((2, 2)))
293+
294+
295+
def test_base_backend_eigh_not_implemented():
296+
backend = BaseBackend()
297+
with pytest.raises(NotImplementedError):
298+
backend.eigh(np.ones((2, 2)))
299+
300+
301+
def test_base_backend_eigs_not_implemented():
302+
backend = BaseBackend()
303+
with pytest.raises(NotImplementedError):
304+
backend.eigs(np.ones((2, 2)))
305+
306+
307+
def test_base_backend_eigs_lanczos_not_implemented():
308+
backend = BaseBackend()
309+
with pytest.raises(NotImplementedError):
310+
backend.eigsh_lanczos(np.ones((2, 2)))
311+
312+
313+
def test_base_backend_addition_not_implemented():
314+
backend = BaseBackend()
315+
with pytest.raises(NotImplementedError):
316+
backend.addition(np.ones((2, 2)), np.ones((2, 2)))
317+
318+
319+
def test_base_backend_subtraction_not_implemented():
320+
backend = BaseBackend()
321+
with pytest.raises(NotImplementedError):
322+
backend.subtraction(np.ones((2, 2)), np.ones((2, 2)))
323+
324+
325+
def test_base_backend_multiply_not_implemented():
326+
backend = BaseBackend()
327+
with pytest.raises(NotImplementedError):
328+
backend.multiply(np.ones((2, 2)), np.ones((2, 2)))
329+
330+
331+
def test_base_backend_divide_not_implemented():
332+
backend = BaseBackend()
333+
with pytest.raises(NotImplementedError):
334+
backend.divide(np.ones((2, 2)), np.ones((2, 2)))
335+
336+
337+
def test_base_backend_index_update_not_implemented():
338+
backend = BaseBackend()
339+
with pytest.raises(NotImplementedError):
340+
backend.index_update(np.ones((2, 2)), np.ones((2, 2)), np.ones((2, 2)))
341+
342+
343+
def test_base_backend_inv_not_implemented():
344+
backend = BaseBackend()
345+
with pytest.raises(NotImplementedError):
346+
backend.inv(np.ones((2, 2)))

tensornetwork/backends/jax/jax_backend_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,27 @@ def index_update(dtype):
271271
tensor = np.array(tensor)
272272
tensor[tensor > 0.1] = 0.0
273273
np.testing.assert_allclose(tensor, out)
274+
275+
276+
def test_base_backend_eigs_not_implemented():
277+
backend = jax_backend.JaxBackend()
278+
tensor = backend.randn((4, 2, 3), dtype=np.float64)
279+
with pytest.raises(NotImplementedError):
280+
backend.eigs(tensor)
281+
282+
283+
def test_base_backend_eigsh_lanczos_not_implemented():
284+
backend = jax_backend.JaxBackend()
285+
tensor = backend.randn((4, 2, 3), dtype=np.float64)
286+
with pytest.raises(NotImplementedError):
287+
backend.eigsh_lanczos(tensor)
288+
289+
290+
@pytest.mark.parametrize("dtype", np_dtypes)
291+
def test_index_update(dtype):
292+
backend = jax_backend.JaxBackend()
293+
tensor = backend.randn((4, 2, 3), dtype=dtype, seed=10)
294+
out = backend.index_update(tensor, tensor > 0.1, 0.0)
295+
np_tensor = np.array(tensor)
296+
np_tensor[np_tensor > 0.1] = 0.0
297+
np.testing.assert_allclose(out, np_tensor)

0 commit comments

Comments
 (0)