Skip to content

Commit b6880cd

Browse files
ci: Switch from flake8, isort and black to ruff (#554)

Structural changes:
* Add ruff dependency, configuration and pre-commit hook
* Add ruff in the typing check and rename it code-quality
* Remove check-todos (now done by ruff in CI)

Code changes to adapt to ruff:
* Use strict=True when zipping
* Use | None instead of Optional
* Use | in isinstance instead of tuple
* Rename unused for-loop counters to _
* Stop using .format in strings
* Stop importing Iterable and Callable from typing
* Group "as" imports
* Stop opening files explicitly with "r" mode

---------

Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent 970b8e4 commit b6880cd

33 files changed

+93
-116
lines changed

.github/workflows/checks.yml

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ jobs:
109109
# This reduces false positives due to rate limits
110110
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
111111

112-
typing:
113-
name: Typing correctness
112+
code-quality:
113+
name: Code quality (ty and ruff)
114114
runs-on: ubuntu-latest
115115
steps:
116116
- name: Checkout repository
@@ -128,32 +128,5 @@ jobs:
128128
- name: Run ty
129129
run: uv run ty check --output-format=github
130130

131-
check-todos:
132-
name: Absence of TODOs
133-
runs-on: ubuntu-latest
134-
steps:
135-
- name: Checkout code
136-
uses: actions/checkout@v6
137-
138-
- name: Scan for TODO strings
139-
run: |
140-
echo "Scanning codebase for TODOs..."
141-
142-
git grep -nE "TODO" -- . ':(exclude).github/workflows/*' > todos_found.txt || true
143-
144-
if [ -s todos_found.txt ]; then
145-
echo "❌ ERROR: Found TODOs in the following files:"
146-
echo "-------------------------------------------"
147-
148-
while IFS=: read -r file line content; do
149-
echo "::error file=$file,line=$line::TODO found at $file:$line - must be resolved before merge:%0A$content"
150-
done < todos_found.txt
151-
152-
echo "-------------------------------------------"
153-
echo "Please resolve these TODOs or track them in an issue before merging."
154-
155-
exit 1
156-
else
157-
echo "✅ No TODOs found. Codebase is clean!"
158-
exit 0
159-
fi
131+
- name: Run ruff
132+
run: uv run ruff check --output-format=github

.pre-commit-config.yaml

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,12 @@ repos:
99
- id: check-docstring-first # Check a common error of defining a docstring after code.
1010
- id: check-merge-conflict # Check for files that contain merge conflict strings.
1111

12-
- repo: https://github.com/PyCQA/flake8
13-
rev: 7.3.0
12+
- repo: https://github.com/astral-sh/ruff-pre-commit
13+
rev: v0.9.0
1414
hooks:
15-
- id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
16-
args: [
17-
'--ignore=E501,E203,W503,E402', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems, module level import not at top of file problems.
18-
]
19-
20-
- repo: https://github.com/pycqa/isort
21-
rev: 7.0.0
22-
hooks:
23-
- id: isort # Sort imports.
24-
args: [
25-
--multi-line=3,
26-
--line-length=100,
27-
--trailing-comma,
28-
--force-grid-wrap=0,
29-
--use-parentheses,
30-
--ensure-newline-before-comments,
31-
]
32-
33-
- repo: https://github.com/psf/black-pre-commit-mirror
34-
rev: 25.12.0
35-
hooks:
36-
- id: black # Format code.
37-
args: [--line-length=100]
15+
- id: ruff
16+
args: [ --fix, --ignore, FIX ]
17+
- id: ruff-format
3818

3919
ci:
4020
autoupdate_commit_msg: 'chore: Update pre-commit hooks'

pyproject.toml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"
6666

6767
[dependency-groups]
6868
check = [
69+
"ruff>=0.14.14",
6970
"ty>=0.0.14",
7071
"pre-commit>=2.9.2", # isort doesn't work before 2.9.2
7172
]
@@ -122,6 +123,32 @@ exclude_lines = [
122123
"@overload",
123124
]
124125

126+
[tool.ruff]
127+
line-length = 100
128+
target-version = "py310"
129+
130+
[tool.ruff.lint]
131+
select = [
132+
"E", # pycodestyle Error
133+
"F", # Pyflakes
134+
"W", # pycodestyle Warning
135+
"I", # isort
136+
"UP", # pyupgrade
137+
"B", # flake8-bugbear
138+
"FIX", # flake8-fixme
139+
]
140+
141+
ignore = [
142+
"E501", # line-too-long (handled by the formatter)
143+
"E402", # module-import-not-at-top-of-file
144+
]
145+
146+
[tool.ruff.lint.isort]
147+
combine-as-imports = true
148+
149+
[tool.ruff.format]
150+
quote-style = "double"
151+
125152
[tool.ty.src]
126153
include = ["src", "tests"]
127154
exclude = ["src/torchjd/aggregation/_nash_mtl.py"]

src/torchjd/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from collections.abc import Callable
22
from warnings import warn as _warn
33

4-
from .autojac import backward as _backward
5-
from .autojac import mtl_backward as _mtl_backward
4+
from .autojac import backward as _backward, mtl_backward as _mtl_backward
65

76
_deprecated_items: dict[str, tuple[str, Callable]] = {
87
"backward": ("autojac", _backward),

src/torchjd/aggregation/_graddrop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,5 @@ def __str__(self) -> str:
7474
if self.leak is None:
7575
leak_str = ""
7676
else:
77-
leak_str = f"([{', '.join(['{:.2f}'.format(l_).rstrip('0') for l_ in self.leak])}])"
77+
leak_str = f"([{', '.join([f'{l_:.2f}'.rstrip('0') for l_ in self.leak])}])"
7878
return f"GradDrop{leak_str}"

src/torchjd/aggregation/_mgda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
5353
dtype = gramian.dtype
5454

5555
alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
56-
for i in range(self.max_iters):
56+
for _ in range(self.max_iters):
5757
t = torch.argmin(gramian @ alpha)
5858
e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
5959
e_t[t] = 1.0

src/torchjd/aggregation/_utils/str.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ def vector_to_str(vector: Tensor) -> str:
77
`1.23, 1., ...`.
88
"""
99

10-
weights_str = ", ".join(["{:.2f}".format(value).rstrip("0") for value in vector])
10+
weights_str = ", ".join([f"{value:.2f}".rstrip("0") for value in vector])
1111
return weights_str

src/torchjd/autogram/_gramian_accumulator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Optional
2-
31
from torchjd._linalg import PSDMatrix
42

53

@@ -13,7 +11,7 @@ class GramianAccumulator:
1311
"""
1412

1513
def __init__(self) -> None:
16-
self._gramian: Optional[PSDMatrix] = None
14+
self._gramian: PSDMatrix | None = None
1715

1816
def reset(self) -> None:
1917
self._gramian = None
@@ -25,7 +23,7 @@ def accumulate_gramian(self, gramian: PSDMatrix) -> None:
2523
self._gramian = gramian
2624

2725
@property
28-
def gramian(self) -> Optional[PSDMatrix]:
26+
def gramian(self) -> PSDMatrix | None:
2927
"""
3028
Get the Gramian matrix accumulated so far.
3129

src/torchjd/autogram/_gramian_computer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Optional, cast
2+
from typing import cast
33

44
from torch import Tensor
55
from torch.utils._pytree import PyTree
@@ -16,12 +16,14 @@ def __call__(
1616
grad_outputs: tuple[Tensor, ...],
1717
args: tuple[PyTree, ...],
1818
kwargs: dict[str, PyTree],
19-
) -> Optional[PSDMatrix]:
19+
) -> PSDMatrix | None:
2020
"""Compute what we can for a module and optionally return the gramian if it's ready."""
2121

22+
@abstractmethod
2223
def track_forward_call(self) -> None:
2324
"""Track that the module's forward was called. Necessary in some implementations."""
2425

26+
@abstractmethod
2527
def reset(self) -> None:
2628
"""Reset state if any. Necessary in some implementations."""
2729

@@ -40,7 +42,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
4042
def __init__(self, jacobian_computer: JacobianComputer):
4143
super().__init__(jacobian_computer)
4244
self.remaining_counter = 0
43-
self.summed_jacobian: Optional[Matrix] = None
45+
self.summed_jacobian: Matrix | None = None
4446

4547
def reset(self) -> None:
4648
self.remaining_counter = 0
@@ -55,7 +57,7 @@ def __call__(
5557
grad_outputs: tuple[Tensor, ...],
5658
args: tuple[PyTree, ...],
5759
kwargs: dict[str, PyTree],
58-
) -> Optional[PSDMatrix]:
60+
) -> PSDMatrix | None:
5961
"""Compute what we can for a module and optionally return the gramian if it's ready."""
6062

6163
jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __call__(
141141
*rg_outputs,
142142
)
143143

144-
for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs):
144+
for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs, strict=True):
145145
flat_outputs[idx] = output
146146

147147
return tree_unflatten(flat_outputs, output_spec)

0 commit comments

Comments (0)