@@ -45,8 +45,8 @@ def test_einsum(
4545 b_indices : list [int ],
4646 output_indices : list [int ],
4747):
48- a = DiagonalSparseTensor (torch . randn (a_pshape ), a_v_to_ps )
49- b = DiagonalSparseTensor (torch . randn (b_pshape ), b_v_to_ps )
48+ a = DiagonalSparseTensor (randn_ (a_pshape ), a_v_to_ps )
49+ b = DiagonalSparseTensor (randn_ (b_pshape ), b_v_to_ps )
5050
5151 res = einsum ((a , a_indices ), (b , b_indices ), output = output_indices )
5252
@@ -212,7 +212,7 @@ def test_fix_ungrouped_dims(
212212 expected_physical_shape : list [int ],
213213 expected_v_to_ps : list [list [int ]],
214214):
215- physical = torch . randn (physical_shape )
215+ physical = randn_ (physical_shape )
216216 fixed_physical , fixed_v_to_ps = fix_ungrouped_dims (physical , v_to_ps )
217217
218218 assert list (fixed_physical .shape ) == expected_physical_shape
@@ -240,7 +240,7 @@ def test_unsquash_pdim(
240240 expected_physical_shape : list [int ],
241241 expected_new_encoding : list [list [int ]],
242242):
243- physical = torch . randn (physical_shape )
243+ physical = randn_ (physical_shape )
244244 new_physical , new_encoding = unsquash_pdim (physical , pdim , new_pdim_shape )
245245
246246 assert list (new_physical .shape ) == expected_physical_shape
0 commit comments