Skip to content

Commit aad6f27

Browse files
authored
test(autogram): Stop creating two engines for one model (#454)
1 parent 63d9d84 commit aad6f27

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

tests/unit/autogram/test_engine.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -479,18 +479,20 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None):
479479
input_size = shape[0]
480480
output_size = prod(shape[1:])
481481

482-
model = Linear(input_size, output_size).to(device=DEVICE)
483-
engine1 = Engine(model, batch_dim=None)
484-
engine2 = Engine(model, batch_dim=None)
482+
torch.manual_seed(0)
483+
model1 = Linear(input_size, output_size).to(device=DEVICE)
484+
torch.manual_seed(0)
485+
model2 = Linear(input_size, output_size).to(device=DEVICE)
485486

486-
input = randn_([input_size])
487-
output = model(input)
487+
engine1 = Engine(model1, batch_dim=None)
488+
engine2 = Engine(model2, batch_dim=None)
488489

489-
reshaped_output = output.reshape(shape[1:])
490+
input = randn_([input_size])
491+
output = model1(input)
492+
reshaped_output = model2(input).reshape(shape[1:])
490493

491494
gramian = engine1.compute_gramian(output)
492495
reshaped_gramian = engine2.compute_gramian(reshaped_output)
493-
494496
expected_reshaped_gramian = reshape_gramian(gramian, shape[1:])
495497

496498
assert_close(reshaped_gramian, expected_reshaped_gramian)
@@ -519,18 +521,20 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
519521
input_size = shape[0]
520522
output_size = prod(shape[1:])
521523

522-
model = Linear(input_size, output_size).to(device=DEVICE)
523-
engine1 = Engine(model, batch_dim=None)
524-
engine2 = Engine(model, batch_dim=None)
524+
torch.manual_seed(0)
525+
model1 = Linear(input_size, output_size).to(device=DEVICE)
526+
torch.manual_seed(0)
527+
model2 = Linear(input_size, output_size).to(device=DEVICE)
525528

526-
input = randn_([input_size])
527-
output = model(input).reshape(shape[1:])
529+
engine1 = Engine(model1, batch_dim=None)
530+
engine2 = Engine(model2, batch_dim=None)
528531

529-
moved_output = output.movedim(source, destination)
532+
input = randn_([input_size])
533+
output = model1(input).reshape(shape[1:])
534+
moved_output = model2(input).reshape(shape[1:]).movedim(source, destination)
530535

531536
gramian = engine1.compute_gramian(output)
532537
moved_gramian = engine2.compute_gramian(moved_output)
533-
534538
expected_moved_gramian = movedim_gramian(gramian, source, destination)
535539

536540
assert_close(moved_gramian, expected_moved_gramian)
@@ -562,17 +566,20 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
562566
batch_size = shape[batch_dim]
563567
output_size = input_size
564568

565-
model = Linear(input_size, output_size).to(device=DEVICE)
566-
engine1 = Engine(model, batch_dim=batch_dim)
567-
engine2 = Engine(model, batch_dim=None)
569+
torch.manual_seed(0)
570+
model1 = Linear(input_size, output_size).to(device=DEVICE)
571+
torch.manual_seed(0)
572+
model2 = Linear(input_size, output_size).to(device=DEVICE)
573+
574+
engine1 = Engine(model1, batch_dim=batch_dim)
575+
engine2 = Engine(model2, batch_dim=None)
568576

569577
input = randn_([batch_size, input_size])
570-
output = model(input)
571-
output = output.reshape([batch_size] + non_batched_shape)
572-
output = output.movedim(0, batch_dim)
578+
output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
579+
output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim)
573580

574-
gramian1 = engine1.compute_gramian(output)
575-
gramian2 = engine2.compute_gramian(output)
581+
gramian1 = engine1.compute_gramian(output1)
582+
gramian2 = engine2.compute_gramian(output2)
576583

577584
assert_close(gramian1, gramian2)
578585

0 commit comments

Comments
 (0)