Skip to content

Commit 20fc697

Browse files
committed
Add RUFF
1 parent ddcb440 commit 20fc697

19 files changed

Lines changed: 38 additions & 39 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ select = [
141141
"SIM", # flake8-simplify
142142
"PERF", # Perflint
143143
"FURB", # refurb
144+
"RUF", # Ruff-specific rules
144145
]
145146

146147
ignore = [

src/torchjd/_linalg/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
33

44
__all__ = [
5-
"compute_gramian",
6-
"normalize",
7-
"regularize",
85
"Matrix",
96
"PSDMatrix",
107
"PSDTensor",
8+
"compute_gramian",
119
"is_matrix",
1210
"is_psd_matrix",
1311
"is_psd_tensor",
12+
"normalize",
13+
"regularize",
1414
]

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
from ._weighting_bases import GeneralizedWeighting, Weighting
8282

8383
__all__ = [
84+
"IMTLG",
85+
"MGDA",
8486
"Aggregator",
8587
"AlignedMTL",
8688
"AlignedMTLWeighting",
@@ -92,14 +94,12 @@
9294
"Flattening",
9395
"GeneralizedWeighting",
9496
"GradDrop",
95-
"IMTLG",
9697
"IMTLGWeighting",
9798
"Krum",
9899
"KrumWeighting",
100+
"MGDAWeighting",
99101
"Mean",
100102
"MeanWeighting",
101-
"MGDA",
102-
"MGDAWeighting",
103103
"PCGrad",
104104
"PCGradWeighting",
105105
"Random",

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def __init__(
6868

6969
def __repr__(self) -> str:
7070
return (
71-
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
72-
f"scale_mode={repr(self._scale_mode)})"
71+
f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, "
72+
f"scale_mode={self._scale_mode!r})"
7373
)
7474

7575
def __str__(self) -> str:

src/torchjd/aggregation/_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def forward(self, matrix: Matrix) -> Tensor:
7070
return length * unit_target_vector
7171

7272
def __repr__(self) -> str:
73-
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
73+
return f"{self.__class__.__name__}(pref_vector={self._pref_vector!r})"
7474

7575
def __str__(self) -> str:
7676
return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"

src/torchjd/aggregation/_constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, weights: Tensor):
2020
self._weights = weights
2121

2222
def __repr__(self) -> str:
23-
return f"{self.__class__.__name__}(weights={repr(self._weights)})"
23+
return f"{self.__class__.__name__}(weights={self._weights!r})"
2424

2525
def __str__(self) -> str:
2626
weights_str = vector_to_str(self._weights)

src/torchjd/aggregation/_dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __init__(
4848

4949
def __repr__(self) -> str:
5050
return (
51-
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
52-
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
51+
f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, norm_eps="
52+
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={self._solver!r})"
5353
)
5454

5555
def __str__(self) -> str:

src/torchjd/aggregation/_graddrop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
6868
)
6969

7070
def __repr__(self) -> str:
71-
return f"{self.__class__.__name__}(f={repr(self.f)}, leak={repr(self.leak)})"
71+
return f"{self.__class__.__name__}(f={self.f!r}, leak={self.leak!r})"
7272

7373
def __str__(self) -> str:
7474
if self.leak is None:

src/torchjd/aggregation/_upgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def __init__(
4949

5050
def __repr__(self) -> str:
5151
return (
52-
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
53-
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
52+
f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, norm_eps="
53+
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={self._solver!r})"
5454
)
5555

5656
def __str__(self) -> str:

src/torchjd/autogram/_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def compute_gramian(self, output: Tensor) -> Tensor:
278278
target_shape = []
279279

280280
if has_non_batch_dim:
281-
target_shape = [-1] + target_shape
281+
target_shape = [-1, *target_shape]
282282

283283
reshaped_output = ordered_output.reshape(target_shape)
284284
# There are four different cases for the shape of reshaped_output:

0 commit comments

Comments (0)