Skip to content

Commit 25a3980

Browse files
committed
Improve test formatting
1 parent 4e96aa4 commit 25a3980

10 files changed

Lines changed: 1 addition & 14 deletions

File tree

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,4 @@ def test_check_keys():
102102
accumulate = Accumulate()
103103

104104
output_keys = accumulate.check_keys({key})
105-
106105
assert output_keys == set()

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ def test_aggregate_matrices_check_keys():
161161
aggregate = _AggregateMatrices(Random(), [key2, key1])
162162

163163
output_keys = aggregate.check_keys({key1, key2})
164-
165164
assert output_keys == {key1, key2}
166165

167166
with raises(RequirementError):
@@ -179,7 +178,6 @@ def test_matrixify_check_keys():
179178
matrixify = _Matrixify()
180179

181180
output_keys = matrixify.check_keys({key1, key2})
182-
183181
assert output_keys == {key1, key2}
184182

185183

@@ -191,5 +189,4 @@ def test_reshape_check_keys():
191189
reshape = _Reshape()
192190

193191
output_keys = reshape.check_keys({key1, key2})
194-
195192
assert output_keys == {key1, key2}

tests/unit/autojac/_transform/test_base.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def test_composition_check_keys():
4545
t2 = FakeTransform(required_keys={a2}, output_keys={a1})
4646

4747
output_keys = (t1 << t2).check_keys({a2})
48-
4948
assert output_keys == {a1, a2}
5049

5150
# Inner Transform fails its check
@@ -71,7 +70,6 @@ def test_conjunct_check_keys_1():
7170
t3 = FakeTransform(required_keys={a2}, output_keys=set())
7271

7372
output_keys = (t1 | t2).check_keys({a1})
74-
7573
assert output_keys == set()
7674

7775
with raises(RequirementError):
@@ -95,7 +93,6 @@ def test_conjunct_check_keys_2():
9593
t3 = FakeTransform(required_keys=set(), output_keys={a2})
9694

9795
output_keys = (t2 | t3).check_keys(set())
98-
9996
assert output_keys == {a1, a2}
10097

10198
with raises(RequirementError):

tests/unit/autojac/_transform/test_diagonalize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def test_check_keys():
109109
diag = Diagonalize([key1])
110110

111111
output_keys = diag.check_keys({key1})
112-
113112
assert output_keys == {key1}
114113

115114
with raises(RequirementError):

tests/unit/autojac/_transform/test_grad.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ def test_check_keys():
297297
grad = Grad(outputs=[y], inputs=[a1, a2])
298298

299299
output_keys = grad.check_keys({y})
300-
301300
assert output_keys == {a1, a2}
302301

303302
with raises(RequirementError):

tests/unit/autojac/_transform/test_init.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def test_check_keys():
7171
init = Init([key])
7272

7373
output_keys = init.check_keys(set())
74-
7574
assert output_keys == {key}
7675

7776
with raises(RequirementError):

tests/unit/autojac/_transform/test_jac.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ def test_check_keys():
297297
jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=None)
298298

299299
output_keys = jac.check_keys({y})
300-
301300
assert output_keys == {a1, a2}
302301

303302
with raises(RequirementError):

tests/unit/autojac/_transform/test_select.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def test_check_keys():
6767
key3 = torch.tensor([3.0])
6868

6969
output_keys = Select([key1, key2]).check_keys({key1, key2, key3})
70-
7170
assert output_keys == {key1, key2}
7271

7372
with raises(RequirementError):

tests/unit/autojac/test_backward.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def test_check_create_transform():
2525
parallel_chunk_size=None,
2626
)
2727
output_keys = transform.check_keys(set())
28-
2928
assert output_keys == set()
3029

3130

tests/unit/autojac/test_mtl_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def test_check_create_transform():
2929
retain_graph=False,
3030
parallel_chunk_size=None,
3131
)
32-
output_keys = transform.check_keys(set())
3332

33+
output_keys = transform.check_keys(set())
3434
assert output_keys == set()
3535

3636

0 commit comments

Comments
 (0)