Skip to content

Commit 61ff55b

Browse files
committed
rename tests with check_and_get_keys to check_keys
1 parent 64372d5 commit 61ff55b

File tree

9 files changed

+26
-26
lines changed

9 files changed

+26
-26
lines changed

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def test_no_leaf_and_no_retains_grad_fails():
9595
accumulate(input)
9696

9797

98-
def test_check_and_get_keys():
99-
"""Tests that the `check_and_get_keys` method works correctly."""
98+
def test_check_keys():
99+
"""Tests that the `check_keys` method works correctly."""
100100

101101
key = torch.tensor([1.0], requires_grad=True)
102102
accumulate = Accumulate()

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def test_reshape():
144144
assert_tensor_dicts_are_close(output, expected_output)
145145

146146

147-
def test_aggregate_matrices_check_and_get_keys():
148-
"""Tests that the `check_and_get_keys` method works correctly."""
147+
def test_aggregate_matrices_check_keys():
148+
"""Tests that the `check_keys` method works correctly."""
149149

150150
key1 = torch.tensor([1.0])
151151
key2 = torch.tensor([2.0])
@@ -156,8 +156,8 @@ def test_aggregate_matrices_check_and_get_keys():
156156
assert output_keys == {key1, key2}
157157

158158

159-
def test_matrixify_check_and_get_keys():
160-
"""Tests that the `check_and_get_keys` method works correctly."""
159+
def test_matrixify_check_keys():
160+
"""Tests that the `check_keys` method works correctly."""
161161

162162
key1 = torch.tensor([1.0])
163163
key2 = torch.tensor([2.0])
@@ -168,8 +168,8 @@ def test_matrixify_check_and_get_keys():
168168
assert output_keys == {key1, key2}
169169

170170

171-
def test_reshape_check_and_get_keys():
172-
"""Tests that the `check_and_get_keys` method works correctly."""
171+
def test_reshape_check_keys():
172+
"""Tests that the `check_keys` method works correctly."""
173173

174174
key1 = torch.tensor([1.0])
175175
key2 = torch.tensor([2.0])

tests/unit/autojac/_transform/test_base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
3333
return self._output_keys
3434

3535

36-
def test_composition_check_and_get_keys():
36+
def test_composition_check_keys():
3737
"""
38-
Tests that `check_and_get_keys` works correctly for a composition of transforms: the inner
38+
Tests that `check_keys` works correctly for a composition of transforms: the inner
3939
transform's `output_keys` has to match with the outer transform's `required_keys`.
4040
"""
4141

@@ -52,9 +52,9 @@ def test_composition_check_and_get_keys():
5252
(t2 << t1).check_keys({a1, a2})
5353

5454

55-
def test_conjunct_check_and_get_keys_1():
55+
def test_conjunct_check_keys_1():
5656
"""
57-
Tests that `check_and_get_keys` works correctly for a conjunction of transforms: all transforms
57+
Tests that `check_keys` works correctly for a conjunction of transforms: all transforms
5858
should successfully check their keys.
5959
"""
6060

@@ -76,9 +76,9 @@ def test_conjunct_check_and_get_keys_1():
7676
(t1 | t2 | t3).check_keys({a1, a2})
7777

7878

79-
def test_conjunct_check_and_get_keys_2():
79+
def test_conjunct_check_keys_2():
8080
"""
81-
Tests that `check_and_get_keys` works correctly for a conjunction of transforms: their
81+
Tests that `check_keys` works correctly for a conjunction of transforms: their
8282
`output_keys` should be disjoint.
8383
"""
8484

tests/unit/autojac/_transform/test_diagonalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def test_permute_order():
9797
assert_tensor_dicts_are_close(output, expected_output)
9898

9999

100-
def test_check_and_get_keys():
101-
"""Tests that the `check_and_get_keys` method works correctly."""
100+
def test_check_keys():
101+
"""Tests that the `check_keys` method works correctly."""
102102

103103
key = torch.tensor([1.0])
104104
diag = Diagonalize([key])

tests/unit/autojac/_transform/test_grad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ def test_create_graph():
283283
assert gradients[a].requires_grad
284284

285285

286-
def test_check_and_get_keys():
287-
"""Tests that the `check_and_get_keys` method works correctly."""
286+
def test_check_keys():
287+
"""Tests that the `check_keys` method works correctly."""
288288

289289
x = torch.tensor(5.0)
290290
a1 = torch.tensor(2.0, requires_grad=True)

tests/unit/autojac/_transform/test_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def test_conjunction_of_inits_is_init():
6363
assert_tensor_dicts_are_close(output, expected_output)
6464

6565

66-
def test_check_and_get_keys():
67-
"""Tests that the `check_and_get_keys` method works correctly."""
66+
def test_check_keys():
67+
"""Tests that the `check_keys` method works correctly."""
6868

6969
key = torch.tensor([1.0])
7070
init = Init([key])

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,9 @@ def test_equivalence_jac_grads():
251251
assert_close(jac_c, torch.stack([grad_1_c, grad_2_c]))
252252

253253

254-
def test_stack_check_and_get_keys():
254+
def test_stack_check_keys():
255255
"""
256-
Tests that the `check_and_get_keys` method works correctly for a stack of transforms: all of
256+
Tests that the `check_keys` method works correctly for a stack of transforms: all of
257257
them should have the same `required_keys`.
258258
"""
259259

tests/unit/autojac/_transform/test_jac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ def test_create_graph():
283283
assert jacobians[a2].requires_grad
284284

285285

286-
def test_check_and_get_keys():
287-
"""Tests that the `check_and_get_keys` method works correctly."""
286+
def test_check_keys():
287+
"""Tests that the `check_keys` method works correctly."""
288288

289289
x = torch.tensor(5.0)
290290
a1 = torch.tensor(2.0, requires_grad=True)

tests/unit/autojac/_transform/test_select.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ def test_conjunction_of_selects_is_select():
5656
assert_tensor_dicts_are_close(output, expected_output)
5757

5858

59-
def test_check_and_get_keys():
59+
def test_check_keys():
6060
"""
61-
Tests that the `check_and_get_keys` method works correctly: the set of keys to select should
61+
Tests that the `check_keys` method works correctly: the set of keys to select should
6262
be a subset of the set of required_keys.
6363
"""
6464

0 commit comments

Comments
 (0)