
Commit ef0516f

Add FBT (boolean traps)
1 parent 3961225 commit ef0516f

20 files changed

Lines changed: 28 additions & 9 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ select = [
     "W", # pycodestyle Warning
     "I", # isort
     "UP", # pyupgrade
+    "FBT", # flake8-boolean-trap
     "B", # flake8-bugbear
     "C4", # flake8-comprehensions
     "FIX", # flake8-fixme

src/torchjd/autogram/_engine.py

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ def compute_gramian(self, output: Tensor) -> Tensor:
 
         return gramian
 
-    def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> PSDMatrix:
+    def _compute_square_gramian(self, output: Tensor, *, has_non_batch_dim: bool) -> PSDMatrix:
         leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
 
         def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ def __init__(
     ):
         self._target_edges = target_edges
         self._gramian_accumulator = gramian_accumulator
-        self.gramian_accumulation_phase = BoolRef(False)
+        self.gramian_accumulation_phase = BoolRef(value=False)
         self._handles: list[TorchRemovableHandle] = []
 
         # When the ModuleHookManager is not referenced anymore, there is no reason to keep the hooks
@@ -79,7 +79,7 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
 class BoolRef:
     """Class wrapping a boolean value, acting as a reference to this boolean value."""
 
-    def __init__(self, value: bool):
+    def __init__(self, *, value: bool):
         self.value = value
 
     def __bool__(self) -> bool:
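
For orientation, a small usage sketch of the now keyword-only constructor; the toggling shown here is illustrative and not taken from the codebase:

    from torchjd.autogram._module_hook_manager import BoolRef  # private module shown above

    phase = BoolRef(value=False)  # BoolRef(False) would now raise a TypeError (and be flagged by FBT)
    assert not phase              # __bool__ returns the wrapped value
    phase.value = True            # mutate in place; everything holding this reference sees the change
    assert phase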

src/torchjd/autojac/_backward.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 def backward(
     tensors: Sequence[Tensor] | Tensor,
     inputs: Iterable[Tensor] | None = None,
+    *,
     retain_graph: bool = False,
     parallel_chunk_size: int | None = None,
 ) -> None:
@@ -86,6 +87,7 @@ def backward(
 def _create_transform(
     tensors: OrderedSet[Tensor],
     inputs: OrderedSet[Tensor],
+    *,
     retain_graph: bool,
     parallel_chunk_size: int | None,
 ) -> Transform:
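
With the bare "*" in place, retain_graph and parallel_chunk_size can only be passed by keyword. A hedged call-site sketch, assuming backward is re-exported from torchjd.autojac (only the signature above comes from the diff):

    import torch
    from torchjd.autojac import backward  # assumed re-export of the function defined in _backward.py

    p = torch.randn(3, requires_grad=True)
    losses = [(p ** 2).sum(), p.abs().sum()]

    backward(losses, retain_graph=False)  # OK: the boolean flag is named
    # backward(losses, None, False)       # TypeError after this commit: retain_graph is keyword-only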

src/torchjd/autojac/_jac.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
 def jac(
     outputs: Sequence[Tensor] | Tensor,
     inputs: Iterable[Tensor] | None = None,
+    *,
     retain_graph: bool = False,
     parallel_chunk_size: int | None = None,
 ) -> tuple[Tensor, ...]:
@@ -136,6 +137,7 @@ def jac(
 def _create_transform(
     outputs: OrderedSet[Tensor],
     inputs: OrderedSet[Tensor],
+    *,
     retain_graph: bool,
     parallel_chunk_size: int | None,
 ) -> Transform:

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 def jac_to_grad(
     tensors: Iterable[Tensor],
     aggregator: Aggregator,
+    *,
     retain_jac: bool = False,
 ) -> None:
     r"""

src/torchjd/autojac/_mtl_backward.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@ def mtl_backward(
     features: Sequence[Tensor] | Tensor,
     tasks_params: Sequence[Iterable[Tensor]] | None = None,
     shared_params: Iterable[Tensor] | None = None,
+    *,
     retain_graph: bool = False,
     parallel_chunk_size: int | None = None,
 ) -> None:
@@ -113,6 +114,7 @@ def _create_transform(
     features: OrderedSet[Tensor],
     tasks_params: list[OrderedSet[Tensor]],
     shared_params: OrderedSet[Tensor],
+    *,
     retain_graph: bool,
     parallel_chunk_size: int | None,
 ) -> Transform:
@@ -152,6 +154,7 @@ def _create_task_transform(
     features: OrderedSet[Tensor],
     task_params: OrderedSet[Tensor],
     loss: OrderedSet[Tensor],  # contains a single scalar loss
+    *,
     retain_graph: bool,
 ) -> Transform:
     # Tensors with respect to which we compute the gradients.

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 2 additions & 1 deletion
@@ -29,6 +29,7 @@ def __init__(
         self,
         outputs: OrderedSet[Tensor],
         inputs: OrderedSet[Tensor],
+        *,
         retain_graph: bool,
         create_graph: bool,
     ):
@@ -64,7 +65,7 @@ def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
         )
         return set(self.inputs)
 
-    def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
+    def _get_vjp(self, grad_outputs: Sequence[Tensor], *, retain_graph: bool) -> tuple[Tensor, ...]:
         optional_grads = torch.autograd.grad(
             self.outputs,
             self.inputs,

src/torchjd/autojac/_transform/_grad.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def __init__(
         self,
         outputs: OrderedSet[Tensor],
         inputs: OrderedSet[Tensor],
+        *,
         retain_graph: bool = False,
         create_graph: bool = False,
     ):

src/torchjd/autojac/_transform/_jac.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ def __init__(
         self,
         outputs: OrderedSet[Tensor],
         inputs: OrderedSet[Tensor],
+        *,
         chunk_size: int | None,
         retain_graph: bool = False,
         create_graph: bool = False,

0 commit comments