Skip to content

Commit 3961225

Browse files
committed
Add COM (commas)
1 parent da54f6f commit 3961225

38 files changed

Lines changed: 148 additions & 72 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ select = [
143143
"RET", # flake8-return
144144
"PYI", # flake8-pyi
145145
"PIE", # flake8-pie
146+
"COM", # flake8-commas
146147
"PERF", # Perflint
147148
"FURB", # refurb
148149
"RUF", # Ruff-specific rules

src/torchjd/_linalg/_gramian.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
6969
"""
7070

7171
regularization_matrix = eps * torch.eye(
72-
gramian.shape[0], dtype=gramian.dtype, device=gramian.device
72+
gramian.shape[0],
73+
dtype=gramian.dtype,
74+
device=gramian.device,
7375
)
7476
output = gramian + regularization_matrix
7577
return cast(PSDMatrix, output)

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _check_is_matrix(matrix: Tensor) -> None:
2121
if not is_matrix(matrix):
2222
raise ValueError(
2323
"Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
24-
f"{matrix.shape}`."
24+
f"{matrix.shape}`.",
2525
)
2626

2727
@abstractmethod

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
105105

106106
@staticmethod
107107
def _compute_balance_transformation(
108-
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
108+
M: Tensor,
109+
scale_mode: SUPPORTED_SCALE_MODE = "min",
109110
) -> Tensor:
110111
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
111112
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
@@ -127,7 +128,7 @@ def _compute_balance_transformation(
127128
scale = lambda_.mean()
128129
else:
129130
raise ValueError(
130-
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
131+
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.",
131132
)
132133

133134
return scale.sqrt() * V @ sigma_inv @ V.T

src/torchjd/aggregation/_constant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, weights: Tensor):
3939
if weights.dim() != 1:
4040
raise ValueError(
4141
"Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "
42-
f"{weights.shape}`."
42+
f"{weights.shape}`.",
4343
)
4444

4545
super().__init__()
@@ -53,5 +53,5 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:
5353
if matrix.shape[0] != len(self.weights):
5454
raise ValueError(
5555
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
56-
f"weights). Found `matrix` with {matrix.shape[0]} rows."
56+
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
5757
)

src/torchjd/aggregation/_dualproj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self._solver: SUPPORTED_SOLVER = solver
4141

4242
super().__init__(
43-
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
43+
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
4444
)
4545

4646
# This prevents considering the computed weights as constant w.r.t. the matrix.

src/torchjd/aggregation/_graddrop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
3030
if leak is not None and leak.dim() != 1:
3131
raise ValueError(
3232
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
33-
f"{leak.shape}`."
33+
f"{leak.shape}`.",
3434
)
3535

3636
super().__init__()
@@ -64,7 +64,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
6464
if self.leak is not None and n_rows != len(self.leak):
6565
raise ValueError(
6666
f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the "
67-
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`."
67+
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`.",
6868
)
6969

7070
def __repr__(self) -> str:

src/torchjd/aggregation/_krum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
4949
if n_byzantine < 0:
5050
raise ValueError(
5151
"Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
52-
f"{n_byzantine}`."
52+
f"{n_byzantine}`.",
5353
)
5454

5555
if n_selected < 1:
5656
raise ValueError(
5757
"Parameter `n_selected` should be a positive integer. Found `n_selected = "
58-
f"{n_selected}`."
58+
f"{n_selected}`.",
5959
)
6060

6161
self.n_byzantine = n_byzantine
@@ -83,11 +83,11 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
8383
if gramian.shape[0] < min_rows:
8484
raise ValueError(
8585
f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found"
86-
f" `gramian` with {gramian.shape[0]} rows."
86+
f" `gramian` with {gramian.shape[0]} rows.",
8787
)
8888

8989
if gramian.shape[0] < self.n_selected:
9090
raise ValueError(
9191
f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). "
92-
f"Found `gramian` with {gramian.shape[0]} rows."
92+
f"Found `gramian` with {gramian.shape[0]} rows.",
9393
)

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
max_norm=max_norm,
8585
update_weights_every=update_weights_every,
8686
optim_niter=optim_niter,
87-
)
87+
),
8888
)
8989
self._n_tasks = n_tasks
9090
self._max_norm = max_norm
@@ -144,7 +144,7 @@ def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool:
144144
return bool(
145145
(self.alpha_param.value is None)
146146
or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
147-
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6)
147+
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6),
148148
)
149149

150150
def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray:

src/torchjd/aggregation/_trimmed_mean.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, trim_number: int):
2020
if trim_number < 0:
2121
raise ValueError(
2222
"Parameter `trim_number` should be a non-negative integer. Found `trim_number` = "
23-
f"{trim_number}`."
23+
f"{trim_number}`.",
2424
)
2525
self.trim_number = trim_number
2626

@@ -40,7 +40,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
4040
if n_rows < min_rows:
4141
raise ValueError(
4242
f"Parameter `matrix` should be a matrix of at least {min_rows} rows "
43-
f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`."
43+
f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`.",
4444
)
4545

4646
def __repr__(self) -> str:

0 commit comments

Comments (0)