Commit 4803ec3

Fix tests and add some mandatory exceptions
1 parent 6f594aa commit 4803ec3

10 files changed: 17 additions & 14 deletions

pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -151,8 +151,9 @@ select = [
 ]
 
 ignore = [
-    "E501", # line-too-long (handled by the formatter)
-    "E402", # module-import-not-at-top-of-file
+    "E501", # line-too-long (handled by the formatter)
+    "E402", # module-import-not-at-top-of-file
+    "COM812",
 ]
 
 [tool.ruff.lint.isort]
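Note: COM812 is Ruff's missing-trailing-comma rule (from the flake8-commas plugin). Ruff itself warns that this rule conflicts with its formatter, which already manages trailing commas, so ignoring it is the documented recommendation; this is presumably one of the "mandatory exceptions" the commit message refers to.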

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ def _init_optim_problem(self) -> None:
 
         G_alpha = self.G_param @ self.alpha_param
         constraint = [
-            -cp.log(self.a * self.normalization_factor_param) - cp.log(G_a) <= 0
+            -cp.log(a * self.normalization_factor_param) - cp.log(G_a) <= 0
             for a, G_a in zip(self.alpha_param, G_alpha, strict=True)
         ]
         obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param)
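This one is a real bug fix rather than lint appeasement: the comprehension binds `a` to each entry of `self.alpha_param`, so the constraint must use the loop variable. `self.a` names an attribute the surrounding code never defines and would raise an AttributeError as soon as `_init_optim_problem` runs; this is presumably part of the "fix tests" half of the commit message.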

src/torchjd/autogram/_engine.py

Lines changed: 3 additions & 1 deletion
@@ -288,7 +288,9 @@ def compute_gramian(self, output: Tensor) -> Tensor:
         self._module_hook_manager.gramian_accumulation_phase.value = True
 
         try:
-            square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim)
+            square_gramian = self._compute_square_gramian(
+                reshaped_output, has_non_batch_dim=has_non_batch_dim
+            )
         finally:
             # Reset everything that has a state, even if the previous call raised an exception
             self._module_hook_manager.gramian_accumulation_phase.value = False
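This diff, and most of those below, applies one pattern: arguments that are easy to confuse at the call site, booleans above all, are now passed by keyword. That is exactly what Ruff's flake8-boolean-trap (FBT) rules push toward. A minimal sketch of the problem, using a hypothetical function:

    def run(data, shuffle: bool = False, drop_last: bool = False):
        ...

    run(batch, True, False)                    # opaque: which flag is which?
    run(batch, shuffle=True, drop_last=False)  # self-documenting, transposition-proof

Writing `has_non_batch_dim=has_non_batch_dim` costs a few characters but makes the call unambiguous, and a swapped argument order can no longer pass silently.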

src/torchjd/autojac/_backward.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def _create_transform(
     diag = Diagonalize(tensors)
 
     # Transform that computes the required Jacobians.
-    jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)
+    jac = Jac(tensors, inputs, chunk_size=parallel_chunk_size, retain_graph=retain_graph)
 
     # Transform that accumulates the result in the .jac field of the inputs.
     accumulate = AccumulateJac()
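The same keyword-argument treatment; it also surfaces that the caller's `parallel_chunk_size` feeds `Jac`'s `chunk_size` parameter, a mapping the positional call hid.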

src/torchjd/autojac/_jac.py

Lines changed: 1 addition & 1 deletion
@@ -148,6 +148,6 @@ def _create_transform(
     diag = Diagonalize(outputs)
 
     # Transform that computes the required Jacobians.
-    jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph)
+    jac = Jac(outputs, inputs, chunk_size=parallel_chunk_size, retain_graph=retain_graph)
 
     return jac << diag << init
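Identical to the `_backward.py` change, for the `jac` entry point. The `<<` chaining appears to compose the transforms right to left: `init`, then `diag`, then `jac`.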

src/torchjd/autojac/_mtl_backward.py

Lines changed: 3 additions & 3 deletions
@@ -132,7 +132,7 @@ def _create_transform(
             features,
             task_params,
             OrderedSet([loss]),
-            retain_graph,
+            retain_graph=retain_graph,
         )
         for task_params, loss in zip(tasks_params, losses, strict=True)
     ]
@@ -142,7 +142,7 @@ def _create_transform(
     stack = Stack(task_transforms)
 
     # Transform that computes the Jacobians of the losses w.r.t. the shared parameters.
-    jac = Jac(features, shared_params, parallel_chunk_size, retain_graph)
+    jac = Jac(features, shared_params, chunk_size=parallel_chunk_size, retain_graph=retain_graph)
 
     # Transform that accumulates the result in the .jac field of the shared parameters.
     accumulate = AccumulateJac()
@@ -165,7 +165,7 @@ def _create_task_transform(
 
     # Transform that computes the gradients of the loss w.r.t. the task-specific parameters and
     # the features.
-    grad = Grad(loss, to_differentiate, retain_graph)
+    grad = Grad(loss, to_differentiate, retain_graph=retain_graph)
 
     # Transform that accumulates the gradients w.r.t. the task-specific parameters into their
     # .grad fields.
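Three call sites in the multi-task backward path get the same treatment: the per-task transform construction, the shared-parameter `Jac`, and the task-specific `Grad` now all name their `retain_graph` (and, for `Jac`, `chunk_size`) arguments.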

src/torchjd/autojac/_transform/_grad.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def __init__(
         retain_graph: bool = False,
         create_graph: bool = False,
     ):
-        super().__init__(outputs, inputs, retain_graph, create_graph)
+        super().__init__(outputs, inputs, retain_graph=retain_graph, create_graph=create_graph)
 
     def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]:
         """
@@ -54,4 +54,4 @@ def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]:
         if len(self.outputs) == 0:
             return tuple(torch.zeros_like(input) for input in self.inputs)
 
-        return self._get_vjp(grad_outputs, self.retain_graph)
+        return self._get_vjp(grad_outputs, retain_graph=self.retain_graph)
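`Grad.__init__` forwards two consecutive booleans (`retain_graph`, `create_graph`) to its parent, which is precisely the case FBT guards against: swapping them would still type-check and only change behavior silently. The keyword form in both `super().__init__` and `_get_vjp` removes that failure mode.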

src/torchjd/autojac/_transform/_jac.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def __init__(
         retain_graph: bool = False,
         create_graph: bool = False,
     ):
-        super().__init__(outputs, inputs, retain_graph, create_graph)
+        super().__init__(outputs, inputs, retain_graph=retain_graph, create_graph=create_graph)
         self.chunk_size = chunk_size
 
     def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]:
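The same `super().__init__` fix as in `Grad`; `Jac` additionally stores `chunk_size`, which presumably bounds how many outputs are differentiated at once in `_differentiate`.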

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def pytest_collection_modifyitems(config, items):
            item.add_marker(xfail_cuda)
 
 
-def pytest_make_parametrize_id(_, val, __):
+def pytest_make_parametrize_id(config, val, argname): # noqa: ARG001
    MAX_SIZE = 40
    optional_string = None # Returning None means using pytest's way of making the string
 
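This is likely a test fix rather than a style change: pytest (via pluggy) matches hook-implementation arguments to the hook specification by name, so a `pytest_make_parametrize_id` taking `_` and `__` fails plugin validation. The parameters must be named `config`, `val`, and `argname` even when unused, and `# noqa: ARG001` then silences Ruff's unused-argument rule, another of the commit's "mandatory exceptions". A minimal sketch of a conforming hook:

    def pytest_make_parametrize_id(config, val, argname):  # noqa: ARG001
        # Short strings become the test ID directly; None defers to pytest's default.
        if isinstance(val, str) and len(val) <= 40:
            return val
        return None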

tests/unit/autojac/test_backward.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def test_jac_is_populated():
 @mark.parametrize("chunk_size", [1, 2, None])
 def test_value_is_correct(
     shape: tuple[int, int],
-    *manually_specify_inputs: bool,
+    manually_specify_inputs: bool, # noqa: FBT001
     chunk_size: int | None,
 ):
     """
