Skip to content

Commit 84c5a09

Browse files
ci: Add more flags to ruff checker (#561)
* Add C4, TID, SIM, PERF, FURB, RUF, PYI, PIE, COM. * Add RUF010, RUF012, RUF022 and RET504 to exceptions. * Apply changes induced by these new checks. --------- Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent e767e40 commit 84c5a09

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

53 files changed

+255
-197
lines changed

pyproject.toml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,32 @@ target-version = "py310"
129129

130130
[tool.ruff.lint]
131131
select = [
132-
"E", # pycodestyle Error
133-
"F", # Pyflakes
134-
"W", # pycodestyle Warning
135-
"I", # isort
136-
"UP", # pyupgrade
137-
"B", # flake8-bugbear
132+
"E", # pycodestyle Error
133+
"F", # Pyflakes
134+
"W", # pycodestyle Warning
135+
"I", # isort
136+
"UP", # pyupgrade
137+
"B", # flake8-bugbear
138+
"C4", # flake8-comprehensions
138139
"FIX", # flake8-fixme
140+
"TID", # flake8-tidy-imports
141+
"SIM", # flake8-simplify
142+
"RET", # flake8-return
143+
"PYI", # flake8-pyi
144+
"PIE", # flake8-pie
145+
"COM", # flake8-commas
146+
"PERF", # Perflint
147+
"FURB", # refurb
148+
"RUF", # Ruff-specific rules
139149
]
140150

141151
ignore = [
142-
"E501", # line-too-long (handled by the formatter)
143-
"E402", # module-import-not-at-top-of-file
152+
"E501", # line-too-long (handled by the formatter)
153+
"E402", # module-import-not-at-top-of-file
154+
"RUF022", # __all__ not sorted
155+
"RUF010", # Use explicit conversion flag
156+
"RUF012", # Mutable default value for class attribute (a bit tedious to fix)
157+
"RET504", # Unnecessary assignment return statement
144158
]
145159

146160
[tool.ruff.lint.isort]

src/torchjd/_linalg/_gramian.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
3535
first dimension).
3636
"""
3737

38-
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
38+
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
3939
indices_source = list(range(t.ndim - contracted_dims))
4040
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
4141
transposed = t.movedim(indices_source, indices_dest)
@@ -70,7 +70,9 @@ def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
7070
"""
7171

7272
regularization_matrix = eps * torch.eye(
73-
gramian.shape[0], dtype=gramian.dtype, device=gramian.device
73+
gramian.shape[0],
74+
dtype=gramian.dtype,
75+
device=gramian.device,
7476
)
7577
output = gramian + regularization_matrix
7678
return cast(PSDMatrix, output)

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _check_is_matrix(matrix: Tensor) -> None:
2121
if not is_matrix(matrix):
2222
raise ValueError(
2323
"Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
24-
f"{matrix.shape}`."
24+
f"{matrix.shape}`.",
2525
)
2626

2727
@abstractmethod

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
107107

108108
@staticmethod
109109
def _compute_balance_transformation(
110-
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
110+
M: Tensor,
111+
scale_mode: SUPPORTED_SCALE_MODE = "min",
111112
) -> Tensor:
112113
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
113114
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
@@ -130,7 +131,7 @@ def _compute_balance_transformation(
130131
scale = lambda_.mean()
131132
else:
132133
raise ValueError(
133-
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
134+
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.",
134135
)
135136

136137
B = scale.sqrt() * V @ sigma_inv @ V.T

src/torchjd/aggregation/_constant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, weights: Tensor):
3939
if weights.dim() != 1:
4040
raise ValueError(
4141
"Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "
42-
f"{weights.shape}`."
42+
f"{weights.shape}`.",
4343
)
4444

4545
super().__init__()
@@ -53,5 +53,5 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:
5353
if matrix.shape[0] != len(self.weights):
5454
raise ValueError(
5555
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
56-
f"weights). Found `matrix` with {matrix.shape[0]} rows."
56+
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
5757
)

src/torchjd/aggregation/_dualproj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self._solver: SUPPORTED_SOLVER = solver
4141

4242
super().__init__(
43-
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
43+
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
4444
)
4545

4646
# This prevents considering the computed weights as constant w.r.t. the matrix.

src/torchjd/aggregation/_graddrop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
3030
if leak is not None and leak.dim() != 1:
3131
raise ValueError(
3232
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
33-
f"{leak.shape}`."
33+
f"{leak.shape}`.",
3434
)
3535

3636
super().__init__()
@@ -64,7 +64,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
6464
if self.leak is not None and n_rows != len(self.leak):
6565
raise ValueError(
6666
f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the "
67-
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`."
67+
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`.",
6868
)
6969

7070
def __repr__(self) -> str:

src/torchjd/aggregation/_imtl_g.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
3434
v = torch.linalg.pinv(gramian) @ d
3535
v_sum = v.sum()
3636

37-
if v_sum.abs() < 1e-12:
38-
weights = torch.zeros_like(v)
39-
else:
40-
weights = v / v_sum
37+
weights = torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum
4138

4239
return weights

src/torchjd/aggregation/_krum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
4949
if n_byzantine < 0:
5050
raise ValueError(
5151
"Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
52-
f"{n_byzantine}`."
52+
f"{n_byzantine}`.",
5353
)
5454

5555
if n_selected < 1:
5656
raise ValueError(
5757
"Parameter `n_selected` should be a positive integer. Found `n_selected = "
58-
f"{n_selected}`."
58+
f"{n_selected}`.",
5959
)
6060

6161
self.n_byzantine = n_byzantine
@@ -85,11 +85,11 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
8585
if gramian.shape[0] < min_rows:
8686
raise ValueError(
8787
f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found"
88-
f" `gramian` with {gramian.shape[0]} rows."
88+
f" `gramian` with {gramian.shape[0]} rows.",
8989
)
9090

9191
if gramian.shape[0] < self.n_selected:
9292
raise ValueError(
9393
f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). "
94-
f"Found `gramian` with {gramian.shape[0]} rows."
94+
f"Found `gramian` with {gramian.shape[0]} rows.",
9595
)

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
max_norm=max_norm,
8585
update_weights_every=update_weights_every,
8686
optim_niter=optim_niter,
87-
)
87+
),
8888
)
8989
self._n_tasks = n_tasks
9090
self._max_norm = max_norm
@@ -144,7 +144,7 @@ def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool:
144144
return bool(
145145
(self.alpha_param.value is None)
146146
or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
147-
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6)
147+
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6),
148148
)
149149

150150
def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray:
@@ -189,12 +189,10 @@ def _init_optim_problem(self) -> None:
189189
self.phi_alpha = self._calc_phi_alpha_linearization()
190190

191191
G_alpha = self.G_param @ self.alpha_param
192-
constraint = []
193-
for i in range(self.n_tasks):
194-
constraint.append(
195-
-cp.log(self.alpha_param[i] * self.normalization_factor_param) - cp.log(G_alpha[i])
196-
<= 0
197-
)
192+
constraint = [
193+
-cp.log(a * self.normalization_factor_param) - cp.log(G_a) <= 0
194+
for a, G_a in zip(self.alpha_param, G_alpha, strict=True)
195+
]
198196
obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param)
199197
self.prob = cp.Problem(obj, constraint)
200198

0 commit comments

Comments
 (0)