Skip to content

Commit a80b4bc

Browse files
committed
Add RET (flake8 returns)
1 parent f0f576a commit a80b4bc

34 files changed

Lines changed: 70 additions & 146 deletions

docs/source/conf.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
9696
line_str = _get_line_str(obj)
9797
version_str = _get_version_str()
9898

99-
link = f"https://github.com/TorchJD/torchjd/blob/{version_str}/{file_name}{line_str}"
100-
return link
99+
return f"https://github.com/TorchJD/torchjd/blob/{version_str}/{file_name}{line_str}"
101100

102101

103102
def _get_obj(_info: dict[str, str]):
@@ -108,8 +107,7 @@ def _get_obj(_info: dict[str, str]):
108107
for part in full_name.split("."):
109108
obj = getattr(obj, part)
110109
# strip decorators, which would resolve to the source of the decorator
111-
obj = inspect.unwrap(obj)
112-
return obj
110+
return inspect.unwrap(obj)
113111

114112

115113
def _get_file_name(obj) -> str | None:
@@ -124,8 +122,7 @@ def _get_file_name(obj) -> str | None:
124122
def _get_line_str(obj) -> str:
125123
source, start = inspect.getsourcelines(obj)
126124
end = start + len(source) - 1
127-
line_str = f"#L{start}-L{end}"
128-
return line_str
125+
return f"#L{start}-L{end}"
129126

130127

131128
def _get_version_str() -> str:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ select = [
140140
"TID", # flake8-tidy-imports
141141
"SIM", # flake8-simplify
142142
"ARG", # flake8-unused-arguments
143+
"RET", # flake8-return
143144
"PERF", # Perflint
144145
"FURB", # refurb
145146
"RUF", # Ruff-specific rules

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,11 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor:
5959
weights.
6060
"""
6161

62-
vector = weights @ matrix
63-
return vector
62+
return weights @ matrix
6463

6564
def forward(self, matrix: Matrix) -> Tensor:
6665
weights = self.weighting(matrix)
67-
vector = self.combine(matrix, weights)
68-
return vector
66+
return self.combine(matrix, weights)
6967

7068

7169
class GramianWeightedAggregator(WeightedAggregator):

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def __init__(
101101
def forward(self, gramian: PSDMatrix, /) -> Tensor:
102102
w = self.weighting(gramian)
103103
B = self._compute_balance_transformation(gramian, self._scale_mode)
104-
alpha = B @ w
105-
106-
return alpha
104+
return B @ w
107105

108106
@staticmethod
109107
def _compute_balance_transformation(
@@ -114,8 +112,7 @@ def _compute_balance_transformation(
114112
rank = sum(lambda_ > tol)
115113

116114
if rank == 0:
117-
identity = torch.eye(len(M), dtype=M.dtype, device=M.device)
118-
return identity
115+
return torch.eye(len(M), dtype=M.dtype, device=M.device)
119116

120117
order = torch.argsort(lambda_, dim=-1, descending=True)
121118
lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]
@@ -133,5 +130,4 @@ def _compute_balance_transformation(
133130
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
134131
)
135132

136-
B = scale.sqrt() * V @ sigma_inv @ V.T
137-
return B
133+
return scale.sqrt() * V @ sigma_inv @ V.T

src/torchjd/aggregation/_cagrad.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,4 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
101101
# We are approximately on the pareto front
102102
weight_array = np.zeros(dimension)
103103

104-
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
105-
106-
return weights
104+
return torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

src/torchjd/aggregation/_dualproj.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,4 @@ def __init__(
8888
def forward(self, gramian: PSDMatrix, /) -> Tensor:
8989
u = self.weighting(gramian)
9090
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
91-
w = project_weights(u, G, self.solver)
92-
return w
91+
return project_weights(u, G, self.solver)

src/torchjd/aggregation/_flattening.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,4 @@ def forward(self, generalized_gramian: PSDTensor) -> Tensor:
2929
shape = generalized_gramian.shape[:k]
3030
square_gramian = flatten(generalized_gramian)
3131
weights_vector = self.weighting(square_gramian)
32-
weights = weights_vector.reshape(shape)
33-
return weights
32+
return weights_vector.reshape(shape)

src/torchjd/aggregation/_imtl_g.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,4 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
3434
v = torch.linalg.pinv(gramian) @ d
3535
v_sum = v.sum()
3636

37-
weights = torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum
38-
39-
return weights
37+
return torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum

src/torchjd/aggregation/_krum.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
7676

7777
_, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
7878
one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0])
79-
weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected
80-
81-
return weights
79+
return one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected
8280

8381
def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
8482
min_rows = self.n_byzantine + 3

src/torchjd/aggregation/_mean.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,4 @@ def forward(self, matrix: Tensor, /) -> Tensor:
2828
device = matrix.device
2929
dtype = matrix.dtype
3030
m = matrix.shape[0]
31-
weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
32-
return weights
31+
return torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)

0 commit comments

Comments (0)