Skip to content

Commit 65c1165

Browse files
typing: Replace mypy with ty (#551)
Typing changes (necessary to run ty without error): * Use positional-only arguments for methods from classes whose subclasses rename the parameters. This includes `Weighting.forward`, `JacobianComputer._compute_jacobian`, `Transform.__call__` and `Differentiate.__differentiate__`. This fixes a break of LSP. * Remove now-useless typing ignore statement in MGDA * Add necessary casts in NashMTL * Add ignore statement for subclasses of autograd.Functions * Fix name of parameters of methods in `OrderedSet` * Make ModuleFactory generic * Add some missing casts to PSDMatrix * Add ignore statement when calling .grad of BatchedTensor * Ignore type errors in the lightning example's test Structural changes: * Add ty check dependency * Remove mypy check dependency * Change section about type checking in CONTRIBUTING.md * Remove mypy badge * Add ty tool section in pyproject.toml * Change CI to run ty instead of mypy * Check typing in tests too --------- Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent 7d30162 commit 65c1165

39 files changed

+104
-83
lines changed

.github/workflows/checks.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ jobs:
120120

121121
- uses: ./.github/actions/install-deps
122122
with:
123-
groups: check
123+
groups: check test plot
124124

125-
- name: Run mypy
126-
run: uv run mypy src/torchjd
125+
- name: Run ty
126+
run: uv run ty check
127127

128128
check-todos:
129129
name: Absence of TODOs

CONTRIBUTING.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,12 @@ uv run pre-commit install
111111
uv run make clean
112112
```
113113

114-
## Running `mypy`
114+
## Type checking
115115

116-
From the root of the repo, run:
116+
We use [ty](https://docs.astral.sh/ty/) for type-checking. If you're on VSCode, we recommend using
117+
the `ty` extension. You can also run it from the root of the repo with:
117118
```bash
118-
uv run mypy src/torchjd
119+
uv run ty check
119120
```
120121
121122
## Development guidelines

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
[![Static Badge](https://img.shields.io/badge/%F0%9F%92%AC_ChatBot-chat.torchjd.org-blue?logo=%F0%9F%92%AC)](https://chat.torchjd.org)
55
[![Tests](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml/badge.svg)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml)
66
[![codecov](https://codecov.io/gh/TorchJD/torchjd/graph/badge.svg?token=8AUCZE76QH)](https://codecov.io/gh/TorchJD/torchjd)
7-
[![mypy](https://img.shields.io/github/actions/workflow/status/TorchJD/torchjd/checks.yml?label=mypy)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml)
87
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TorchJD/torchjd/main.svg)](https://results.pre-commit.ci/latest/github/TorchJD/torchjd/main)
98
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/)
109
[![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk)

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"
6666

6767
[dependency-groups]
6868
check = [
69-
"mypy>=1.16.0",
69+
"ty>=0.0.14",
7070
"pre-commit>=2.9.2", # isort doesn't work before 2.9.2
7171
]
7272

@@ -114,3 +114,7 @@ exclude_lines = [
114114
"pragma: not covered",
115115
"@overload",
116116
]
117+
118+
[tool.ty.src]
119+
include = ["src", "tests"]
120+
exclude = ["src/torchjd/aggregation/_nash_mtl.py"]

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
9999
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
100100

101-
def forward(self, gramian: PSDMatrix) -> Tensor:
101+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
102102
w = self.weighting(gramian)
103103
B = self._compute_balance_transformation(gramian, self._scale_mode)
104104
alpha = B @ w

src/torchjd/aggregation/_cagrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001):
7676
self.c = c
7777
self.norm_eps = norm_eps
7878

79-
def forward(self, gramian: PSDMatrix) -> Tensor:
79+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
8080
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
8181

8282
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/_constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self, weights: Tensor):
4545
super().__init__()
4646
self.weights = weights
4747

48-
def forward(self, matrix: Tensor) -> Tensor:
48+
def forward(self, matrix: Tensor, /) -> Tensor:
4949
self._check_matrix_shape(matrix)
5050
return self.weights
5151

src/torchjd/aggregation/_dualproj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
self.reg_eps = reg_eps
8686
self.solver: SUPPORTED_SOLVER = solver
8787

88-
def forward(self, gramian: PSDMatrix) -> Tensor:
88+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
8989
u = self.weighting(gramian)
9090
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
9191
w = project_weights(u, G, self.solver)

src/torchjd/aggregation/_imtl_g.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class IMTLGWeighting(Weighting[PSDMatrix]):
2929
:class:`~torchjd.aggregation.IMTLG`.
3030
"""
3131

32-
def forward(self, gramian: PSDMatrix) -> Tensor:
32+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
3333
d = torch.sqrt(torch.diagonal(gramian))
3434
v = torch.linalg.pinv(gramian) @ d
3535
v_sum = v.sum()

src/torchjd/aggregation/_krum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
6161
self.n_byzantine = n_byzantine
6262
self.n_selected = n_selected
6363

64-
def forward(self, gramian: PSDMatrix) -> Tensor:
64+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
6565
self._check_matrix_shape(gramian)
6666
gradient_norms_squared = torch.diagonal(gramian)
6767
distances_squared = (

0 commit comments

Comments
 (0)