Skip to content

Commit 20e9cf6

Browse files
committed
Make Transform subclasses also have pos-only arguments
Similarly to the previous commit, this is not strictly necessary according to LSP, but I think it's weird that subclasses don't enforce pos-only arguments if the parent class enforces that.
1 parent 3f9f95a commit 20e9cf6

9 files changed

Lines changed: 11 additions & 11 deletions

File tree

src/torchjd/autojac/_transform/_accumulate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class AccumulateGrad(Transform):
1313
should not be used elsewhere.
1414
"""
1515

16-
def __call__(self, gradients: TensorDict) -> TensorDict:
16+
def __call__(self, gradients: TensorDict, /) -> TensorDict:
1717
accumulate_grads(gradients.keys(), gradients.values())
1818
return {}
1919

@@ -30,7 +30,7 @@ class AccumulateJac(Transform):
3030
should not be used elsewhere.
3131
"""
3232

33-
def __call__(self, jacobians: TensorDict) -> TensorDict:
33+
def __call__(self, jacobians: TensorDict, /) -> TensorDict:
3434
accumulate_jacs(jacobians.keys(), jacobians.values())
3535
return {}
3636

src/torchjd/autojac/_transform/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, outer: Transform, inner: Transform):
7676
def __str__(self) -> str:
7777
return str(self.outer) + " ∘ " + str(self.inner)
7878

79-
def __call__(self, input: TensorDict) -> TensorDict:
79+
def __call__(self, input: TensorDict, /) -> TensorDict:
8080
intermediate = self.inner(input)
8181
return self.outer(intermediate)
8282

@@ -107,7 +107,7 @@ def __str__(self) -> str:
107107
strings.append(s)
108108
return "(" + " | ".join(strings) + ")"
109109

110-
def __call__(self, tensor_dict: TensorDict) -> TensorDict:
110+
def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
111111
union: TensorDict = {}
112112
for transform in self.transforms:
113113
union |= transform(tensor_dict)

src/torchjd/autojac/_transform/_diagonalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, key_order: OrderedSet[Tensor]):
6060
self.indices.append((begin, end))
6161
begin = end
6262

63-
def __call__(self, tensors: TensorDict) -> TensorDict:
63+
def __call__(self, tensors: TensorDict, /) -> TensorDict:
6464
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order]
6565
diagonal_matrix = torch.cat(flattened_considered_values).diag()
6666
diagonalized_tensors = {

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
self.retain_graph = retain_graph
3838
self.create_graph = create_graph
3939

40-
def __call__(self, tensors: TensorDict) -> TensorDict:
40+
def __call__(self, tensors: TensorDict, /) -> TensorDict:
4141
tensor_outputs = [tensors[output] for output in self.outputs]
4242

4343
differentiated_tuple = self._differentiate(tensor_outputs)

src/torchjd/autojac/_transform/_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Init(Transform):
1616
def __init__(self, values: Set[Tensor]):
1717
self.values = values
1818

19-
def __call__(self, input: TensorDict) -> TensorDict:
19+
def __call__(self, input: TensorDict, /) -> TensorDict:
2020
return {value: torch.ones_like(value) for value in self.values}
2121

2222
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:

src/torchjd/autojac/_transform/_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class Select(Transform):
1515
def __init__(self, keys: Set[Tensor]):
1616
self.keys = keys
1717

18-
def __call__(self, tensor_dict: TensorDict) -> TensorDict:
18+
def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
1919
output = {key: tensor_dict[key] for key in self.keys}
2020
return type(tensor_dict)(output)
2121

src/torchjd/autojac/_transform/_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Stack(Transform):
2323
def __init__(self, transforms: Sequence[Transform]):
2424
self.transforms = transforms
2525

26-
def __call__(self, input: TensorDict) -> TensorDict:
26+
def __call__(self, input: TensorDict, /) -> TensorDict:
2727
results = [transform(input) for transform in self.transforms]
2828
result = _stack(results)
2929
return result

tests/unit/autojac/_transform/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]):
1717
def __str__(self):
1818
return "T"
1919

20-
def __call__(self, input: TensorDict) -> TensorDict:
20+
def __call__(self, input: TensorDict, /) -> TensorDict:
2121
# Ignore the input, create a dictionary with the right keys as an output.
2222
output_dict = {key: empty_(0) for key in self._output_keys}
2323
return output_dict

tests/unit/autojac/_transform/test_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class FakeGradientsTransform(Transform):
1515
def __init__(self, keys: Iterable[Tensor]):
1616
self.keys = set(keys)
1717

18-
def __call__(self, input: TensorDict) -> TensorDict:
18+
def __call__(self, input: TensorDict, /) -> TensorDict:
1919
return {key: torch.ones_like(key) for key in self.keys}
2020

2121
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:

0 commit comments

Comments
 (0)