Skip to content

Commit 9d9cbf0

Browse files
authored
Merge branch 'main' into optimize_jac_to_grad
2 parents 3f9a6d1 + 0a1ecfd commit 9d9cbf0

64 files changed

Lines changed: 304 additions & 252 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/checks.yml

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
tests:
1414
# Default config: py3.14, ubuntu-latest, float32, full options.
1515
# The idea is to make each of those params vary one by one, to limit the number of tests to run.
16-
name: Tests (py${{ matrix.python-version || '3.14' }}, ${{ matrix.os || 'ubuntu-latest' }}, ${{ matrix.dtype || 'float32' }}, ${{ matrix.options || 'full' }})
16+
name: Tests (py${{ matrix.python-version || '3.14' }}, ${{ matrix.os || 'ubuntu-latest' }}, ${{ matrix.dtype || 'float32' }}, ${{ matrix.options || 'full' }}${{ matrix.extra_groups && format(', {0}', matrix.extra_groups) || '' }})
1717
runs-on: ${{ matrix.os || 'ubuntu-latest' }}
1818
strategy:
1919
fail-fast: false
@@ -32,6 +32,9 @@ jobs:
3232
- dtype: float64
3333
# Installation options variations
3434
- options: 'none'
35+
# Lower-bounds of all dependencies and Python version.
36+
- python-version: '3.10.0'
37+
extra_groups: 'lower_bounds'
3538

3639
steps:
3740
- name: Checkout repository
@@ -45,7 +48,7 @@ jobs:
4548
- uses: ./.github/actions/install-deps
4649
with:
4750
options: ${{ matrix.options || 'full' }}
48-
groups: test
51+
groups: test ${{ matrix.extra_groups }}
4952

5053
- name: Run tests
5154
run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml
@@ -106,8 +109,8 @@ jobs:
106109
# This reduces false positives due to rate limits
107110
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
108111

109-
typing:
110-
name: Typing correctness
112+
code-quality:
113+
name: Code quality (ty and ruff)
111114
runs-on: ubuntu-latest
112115
steps:
113116
- name: Checkout repository
@@ -120,37 +123,10 @@ jobs:
120123

121124
- uses: ./.github/actions/install-deps
122125
with:
123-
groups: check
126+
groups: check test plot
124127

125-
- name: Run mypy
126-
run: uv run mypy src/torchjd
128+
- name: Run ty
129+
run: uv run ty check --output-format=github
127130

128-
check-todos:
129-
name: Absence of TODOs
130-
runs-on: ubuntu-latest
131-
steps:
132-
- name: Checkout code
133-
uses: actions/checkout@v6
134-
135-
- name: Scan for TODO strings
136-
run: |
137-
echo "Scanning codebase for TODOs..."
138-
139-
git grep -nE "TODO" -- . ':(exclude).github/workflows/*' > todos_found.txt || true
140-
141-
if [ -s todos_found.txt ]; then
142-
echo "❌ ERROR: Found TODOs in the following files:"
143-
echo "-------------------------------------------"
144-
145-
while IFS=: read -r file line content; do
146-
echo "::error file=$file,line=$line::TODO found at $file:$line - must be resolved before merge:%0A$content"
147-
done < todos_found.txt
148-
149-
echo "-------------------------------------------"
150-
echo "Please resolve these TODOs or track them in an issue before merging."
151-
152-
exit 1
153-
else
154-
echo "✅ No TODOs found. Codebase is clean!"
155-
exit 0
156-
fi
131+
- name: Run ruff
132+
run: uv run ruff check --output-format=github

.pre-commit-config.yaml

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,12 @@ repos:
99
- id: check-docstring-first # Check a common error of defining a docstring after code.
1010
- id: check-merge-conflict # Check for files that contain merge conflict strings.
1111

12-
- repo: https://github.com/PyCQA/flake8
13-
rev: 7.3.0
12+
- repo: https://github.com/astral-sh/ruff-pre-commit
13+
rev: v0.9.0
1414
hooks:
15-
- id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
16-
args: [
17-
'--ignore=E501,E203,W503,E402', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems, module level import not at top of file problems.
18-
]
19-
20-
- repo: https://github.com/pycqa/isort
21-
rev: 7.0.0
22-
hooks:
23-
- id: isort # Sort imports.
24-
args: [
25-
--multi-line=3,
26-
--line-length=100,
27-
--trailing-comma,
28-
--force-grid-wrap=0,
29-
--use-parentheses,
30-
--ensure-newline-before-comments,
31-
]
32-
33-
- repo: https://github.com/psf/black-pre-commit-mirror
34-
rev: 25.12.0
35-
hooks:
36-
- id: black # Format code.
37-
args: [--line-length=100]
15+
- id: ruff
16+
args: [ --fix, --ignore, FIX ]
17+
- id: ruff-format
3818

3919
ci:
4020
autoupdate_commit_msg: 'chore: Update pre-commit hooks'

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,13 @@ changelog does not include internal changes that do not affect the user.
4545
mtl_backward(losses, features)
4646
jac_to_grad(shared_module.parameters(), aggregator)
4747
```
48-
48+
4949
- Removed several unnecessary memory duplications. This should significantly improve the memory
5050
efficiency and speed of `autojac`.
51+
- Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0
52+
to 1.21.2) dependencies to reflect what really works with torchjd. We now also run torchjd's tests
53+
with the dependency lower-bounds specified in `pyproject.toml`, so we should now always accurately
54+
reflect the actual lower-bounds.
5155

5256
## [0.8.1] - 2026-01-07
5357

CONTRIBUTING.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,12 @@ uv run pre-commit install
111111
uv run make clean
112112
```
113113

114-
## Running `mypy`
114+
## Type checking
115115

116-
From the root of the repo, run:
116+
We use [ty](https://docs.astral.sh/ty/) for type-checking. If you're on VSCode, we recommend using
117+
the `ty` extension. You can also run it from the root of the repo with:
117118
```bash
118-
uv run mypy src/torchjd
119+
uv run ty check
119120
```
120121
121122
## Development guidelines

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
[![Static Badge](https://img.shields.io/badge/%F0%9F%92%AC_ChatBot-chat.torchjd.org-blue?logo=%F0%9F%92%AC)](https://chat.torchjd.org)
55
[![Tests](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml/badge.svg)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml)
66
[![codecov](https://codecov.io/gh/TorchJD/torchjd/graph/badge.svg?token=8AUCZE76QH)](https://codecov.io/gh/TorchJD/torchjd)
7-
[![mypy](https://img.shields.io/github/actions/workflow/status/TorchJD/torchjd/checks.yml?label=mypy)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml)
87
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TorchJD/torchjd/main.svg)](https://results.pre-commit.ci/latest/github/TorchJD/torchjd/main)
98
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/)
109
[![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk)

pyproject.toml

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ authors = [
1313
]
1414
requires-python = ">=3.10"
1515
dependencies = [
16-
"torch>=2.0.0",
16+
"torch>=2.3.0", # Problems before 2.4.0, especially with autogram.
1717
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
18-
"numpy>=1.21.0", # Does not work before 1.21
18+
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
1919
"qpsolvers>=1.0.1", # Does not work before 1.0.1
2020
]
2121
classifiers = [
@@ -66,7 +66,8 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"
6666

6767
[dependency-groups]
6868
check = [
69-
"mypy>=1.16.0",
69+
"ruff>=0.14.14",
70+
"ty>=0.0.14",
7071
"pre-commit>=2.9.2", # isort doesn't work before 2.9.2
7172
]
7273

@@ -83,7 +84,7 @@ test = [
8384
"pytest>=7.3", # Before version 7.3, not all tests are run
8485
"pytest-cov>=6.0.0", # Recent version to avoid problems, could be relaxed
8586
"lightning>=2.0.9", # No OptimizerLRScheduler public type before 2.0.9
86-
"torchvision>=0.22.1" # Recent version to avoid problems, could be relaxed
87+
"torchvision>=0.18.0"
8788
]
8889

8990
plot = [
@@ -92,6 +93,13 @@ plot = [
9293
"kaleido==0.2.1", # Only works with locked version
9394
"matplotlib>=3.10.0", # Recent version to avoid problems, could be relaxed
9495
]
96+
# Dependency group allowing to easily resolve version of the core dependencies to the lower bound.
97+
lower_bounds = [
98+
"torch==2.3.0",
99+
"numpy==1.21.2",
100+
"quadprog==0.1.9",
101+
"qpsolvers==1.0.1",
102+
]
95103

96104
[project.optional-dependencies]
97105
nash_mtl = [
@@ -114,3 +122,33 @@ exclude_lines = [
114122
"pragma: not covered",
115123
"@overload",
116124
]
125+
126+
[tool.ruff]
127+
line-length = 100
128+
target-version = "py310"
129+
130+
[tool.ruff.lint]
131+
select = [
132+
"E", # pycodestyle Error
133+
"F", # Pyflakes
134+
"W", # pycodestyle Warning
135+
"I", # isort
136+
"UP", # pyupgrade
137+
"B", # flake8-bugbear
138+
"FIX", # flake8-fixme
139+
]
140+
141+
ignore = [
142+
"E501", # line-too-long (handled by the formatter)
143+
"E402", # module-import-not-at-top-of-file
144+
]
145+
146+
[tool.ruff.lint.isort]
147+
combine-as-imports = true
148+
149+
[tool.ruff.format]
150+
quote-style = "double"
151+
152+
[tool.ty.src]
153+
include = ["src", "tests"]
154+
exclude = ["src/torchjd/aggregation/_nash_mtl.py"]

src/torchjd/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from collections.abc import Callable
22
from warnings import warn as _warn
33

4-
from .autojac import backward as _backward
5-
from .autojac import mtl_backward as _mtl_backward
4+
from .autojac import backward as _backward, mtl_backward as _mtl_backward
65

76
_deprecated_items: dict[str, tuple[str, Callable]] = {
87
"backward": ("autojac", _backward),

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# SOFTWARE.
2626

2727

28-
from typing import Literal
28+
from typing import Literal, TypeAlias
2929

3030
import torch
3131
from torch import Tensor
@@ -37,6 +37,8 @@
3737
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
3838
from ._weighting_bases import Weighting
3939

40+
SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]
41+
4042

4143
class AlignedMTL(GramianWeightedAggregator):
4244
r"""
@@ -58,10 +60,10 @@ class AlignedMTL(GramianWeightedAggregator):
5860
def __init__(
5961
self,
6062
pref_vector: Tensor | None = None,
61-
scale_mode: Literal["min", "median", "rmse"] = "min",
63+
scale_mode: SUPPORTED_SCALE_MODE = "min",
6264
):
6365
self._pref_vector = pref_vector
64-
self._scale_mode = scale_mode
66+
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
6567
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
6668

6769
def __repr__(self) -> str:
@@ -89,14 +91,14 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
8991
def __init__(
9092
self,
9193
pref_vector: Tensor | None = None,
92-
scale_mode: Literal["min", "median", "rmse"] = "min",
94+
scale_mode: SUPPORTED_SCALE_MODE = "min",
9395
):
9496
super().__init__()
9597
self._pref_vector = pref_vector
96-
self._scale_mode = scale_mode
98+
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
9799
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
98100

99-
def forward(self, gramian: PSDMatrix) -> Tensor:
101+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
100102
w = self.weighting(gramian)
101103
B = self._compute_balance_transformation(gramian, self._scale_mode)
102104
alpha = B @ w
@@ -105,7 +107,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
105107

106108
@staticmethod
107109
def _compute_balance_transformation(
108-
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
110+
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
109111
) -> Tensor:
110112
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
111113
tol = torch.max(lambda_) * len(M) * torch.finfo().eps

src/torchjd/aggregation/_cagrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001):
7676
self.c = c
7777
self.norm_eps = norm_eps
7878

79-
def forward(self, gramian: PSDMatrix) -> Tensor:
79+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
8080
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
8181

8282
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/_constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self, weights: Tensor):
4545
super().__init__()
4646
self.weights = weights
4747

48-
def forward(self, matrix: Tensor) -> Tensor:
48+
def forward(self, matrix: Tensor, /) -> Tensor:
4949
self._check_matrix_shape(matrix)
5050
return self.weights
5151

0 commit comments

Comments
 (0)