Skip to content

Commit f0f576a

Browse files
committed
Add ARG (flake8-unused-arguments). A few fixes from there, such as making check_keys take a positional-only argument and removing jacobians from _disunite.
1 parent 20fc697 commit f0f576a

14 files changed

Lines changed: 26 additions & 27 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ select = [
139139
"FIX", # flake8-fixme
140140
"TID", # flake8-tidy-imports
141141
"SIM", # flake8-simplify
142+
"ARG", # flake8-unused-arguments
142143
"PERF", # Perflint
143144
"FURB", # refurb
144145
"RUF", # Ruff-specific rules

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101

102102
def __call__(
103103
self,
104-
module: nn.Module,
104+
_: nn.Module,
105105
args: tuple[PyTree, ...],
106106
kwargs: dict[str, PyTree],
107107
outputs: PyTree,
@@ -157,11 +157,11 @@ class AutogramNode(torch.autograd.Function):
157157

158158
@staticmethod
159159
def forward(
160-
gramian_accumulation_phase: BoolRef,
161-
gramian_computer: GramianComputer,
162-
args: tuple[PyTree, ...],
163-
kwargs: dict[str, PyTree],
164-
gramian_accumulator: GramianAccumulator,
160+
_: BoolRef,
161+
__: GramianComputer,
162+
___: tuple[PyTree, ...],
163+
____: dict[str, PyTree],
164+
_____: GramianAccumulator,
165165
*rg_tensors: Tensor,
166166
) -> tuple[Tensor, ...]:
167167
return tuple(t.detach() for t in rg_tensors)

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def jac_to_grad(
7373

7474
jacobian_matrix = _unite_jacobians(jacobians)
7575
gradient_vector = aggregator(jacobian_matrix)
76-
gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
76+
gradients = _disunite_gradient(gradient_vector, tensors_)
7777
accumulate_grads(tensors_, gradients)
7878

7979

@@ -83,9 +83,7 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
8383
return jacobian_matrix
8484

8585

86-
def _disunite_gradient(
87-
gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
88-
) -> list[Tensor]:
86+
def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]:
8987
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
9088
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
9189
return gradients

src/torchjd/autojac/_transform/_accumulate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __call__(self, gradients: TensorDict, /) -> TensorDict:
1818
accumulate_grads(gradients.keys(), gradients.values())
1919
return {}
2020

21-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
21+
def check_keys(self, _: set[Tensor], /) -> set[Tensor]:
2222
return set()
2323

2424

@@ -35,5 +35,5 @@ def __call__(self, jacobians: TensorDict, /) -> TensorDict:
3535
accumulate_jacs(jacobians.keys(), jacobians.values())
3636
return {}
3737

38-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
38+
def check_keys(self, _: set[Tensor], /) -> set[Tensor]:
3939
return set()

src/torchjd/autojac/_transform/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict:
4545
"""Applies the transform to the input."""
4646

4747
@abstractmethod
48-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
48+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
4949
"""
5050
Checks that the provided input_keys satisfy the transform's requirements and returns the
5151
corresponding output keys for recursion.
@@ -80,7 +80,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict:
8080
intermediate = self.inner(input)
8181
return self.outer(intermediate)
8282

83-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
83+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
8484
intermediate_keys = self.inner.check_keys(input_keys)
8585
output_keys = self.outer.check_keys(intermediate_keys)
8686
return output_keys
@@ -113,7 +113,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
113113
union |= transform(tensor_dict)
114114
return union
115115

116-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
116+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
117117
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
118118
output_keys = set(output_keys_list)
119119

src/torchjd/autojac/_transform/_diagonalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict:
6969
}
7070
return diagonalized_tensors
7171

72-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
72+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
7373
if not set(self.key_order) == input_keys:
7474
raise RequirementError(
7575
f"The input_keys must match the key_order. Found input_keys {input_keys} and"

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor], /) -> tuple[Tensor, .
5555
tensor_outputs should be.
5656
"""
5757

58-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
58+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
5959
outputs = set(self.outputs)
6060
if not outputs == input_keys:
6161
raise RequirementError(

src/torchjd/autojac/_transform/_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ class Init(Transform):
1616
def __init__(self, values: Set[Tensor]):
1717
self.values = values
1818

19-
def __call__(self, input: TensorDict, /) -> TensorDict:
19+
def __call__(self, _: TensorDict, /) -> TensorDict:
2020
return {value: torch.ones_like(value) for value in self.values}
2121

22-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
22+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
2323
if not input_keys == set():
2424
raise RequirementError(
2525
f"The input_keys should be the empty set. Found input_keys {input_keys}."

src/torchjd/autojac/_transform/_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
1919
output = {key: tensor_dict[key] for key in self.keys}
2020
return type(tensor_dict)(output)
2121

22-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
22+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
2323
keys = set(self.keys)
2424
if not keys.issubset(input_keys):
2525
raise RequirementError(

src/torchjd/autojac/_transform/_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict:
2828
result = _stack(results)
2929
return result
3030

31-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
31+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
3232
return {key for transform in self.transforms for key in transform.check_keys(input_keys)}
3333

3434

0 commit comments

Comments
 (0)