@@ -479,18 +479,20 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None):
479 479         input_size = shape[0]
480 480         output_size = prod(shape[1:])
481 481
482     -       model = Linear(input_size, output_size).to(device=DEVICE)
483     -       engine1 = Engine(model, batch_dim=None)
484     -       engine2 = Engine(model, batch_dim=None)
    482 +       torch.manual_seed(0)
    483 +       model1 = Linear(input_size, output_size).to(device=DEVICE)
    484 +       torch.manual_seed(0)
    485 +       model2 = Linear(input_size, output_size).to(device=DEVICE)
485 486
486     -       input = randn_([input_size])
487     -       output = model(input)
    487 +       engine1 = Engine(model1, batch_dim=None)
    488 +       engine2 = Engine(model2, batch_dim=None)
488 489
489     -       reshaped_output = output.reshape(shape[1:])
    490 +       input = randn_([input_size])
    491 +       output = model1(input)
    492 +       reshaped_output = model2(input).reshape(shape[1:])
490 493
491 494         gramian = engine1.compute_gramian(output)
492 495         reshaped_gramian = engine2.compute_gramian(reshaped_output)
493     -
494 496         expected_reshaped_gramian = reshape_gramian(gramian, shape[1:])
495 497
496 498         assert_close(reshaped_gramian, expected_reshaped_gramian)
@@ -519,18 +521,20 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
519 521         input_size = shape[0]
520 522         output_size = prod(shape[1:])
521 523
522     -       model = Linear(input_size, output_size).to(device=DEVICE)
523     -       engine1 = Engine(model, batch_dim=None)
524     -       engine2 = Engine(model, batch_dim=None)
    524 +       torch.manual_seed(0)
    525 +       model1 = Linear(input_size, output_size).to(device=DEVICE)
    526 +       torch.manual_seed(0)
    527 +       model2 = Linear(input_size, output_size).to(device=DEVICE)
525 528
526     -       input = randn_([input_size])
527     -       output = model(input).reshape(shape[1:])
    529 +       engine1 = Engine(model1, batch_dim=None)
    530 +       engine2 = Engine(model2, batch_dim=None)
528 531
529     -       moved_output = output.movedim(source, destination)
    532 +       input = randn_([input_size])
    533 +       output = model1(input).reshape(shape[1:])
    534 +       moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
530 535
531 536         gramian = engine1.compute_gramian(output)
532 537         moved_gramian = engine2.compute_gramian(moved_output)
533     -
534 538         expected_moved_gramian = movedim_gramian(gramian, source, destination)
535 539
536 540         assert_close(moved_gramian, expected_moved_gramian)
@@ -562,17 +566,20 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
562 566         batch_size = shape[batch_dim]
563 567         output_size = input_size
564 568
565     -       model = Linear(input_size, output_size).to(device=DEVICE)
566     -       engine1 = Engine(model, batch_dim=batch_dim)
567     -       engine2 = Engine(model, batch_dim=None)
    569 +       torch.manual_seed(0)
    570 +       model1 = Linear(input_size, output_size).to(device=DEVICE)
    571 +       torch.manual_seed(0)
    572 +       model2 = Linear(input_size, output_size).to(device=DEVICE)
    573 +
    574 +       engine1 = Engine(model1, batch_dim=batch_dim)
    575 +       engine2 = Engine(model2, batch_dim=None)
568 576
569 577         input = randn_([batch_size, input_size])
570     -       output = model(input)
571     -       output = output.reshape([batch_size] + non_batched_shape)
572     -       output = output.movedim(0, batch_dim)
    578 +       output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
    579 +       output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
573 580
574     -       gramian1 = engine1.compute_gramian(output)
575     -       gramian2 = engine2.compute_gramian(output)
    581 +       gramian1 = engine1.compute_gramian(output1)
    582 +       gramian2 = engine2.compute_gramian(output2)
576 583
577 584         assert_close(gramian1, gramian2)
578 585
0 commit comments