Skip to content

Commit a2eca24

Browse files
committed
Add PYI
1 parent a80b4bc commit a2eca24

4 files changed

Lines changed: 6 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ select = [
141141
"SIM", # flake8-simplify
142142
"ARG", # flake8-unused-arguments
143143
"RET", # flake8-return
144+
"PYI", # flake8-pyi
144145
"PERF", # Perflint
145146
"FURB", # refurb
146147
"RUF", # Ruff-specific rules

src/torchjd/aggregation/_weighting_bases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
_FnOutputT = TypeVar("_FnOutputT", bound=Tensor)
1414

1515

16-
class Weighting(Generic[_T], nn.Module, ABC):
16+
class Weighting(nn.Module, ABC, Generic[_T]):
1717
r"""
1818
Abstract base class for all weighting methods. It has the role of extracting a vector of weights
1919
of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`,

src/torchjd/autojac/_transform/_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Set
1+
from collections.abc import Set as AbstractSet
22

33
import torch
44
from torch import Tensor
@@ -13,7 +13,7 @@ class Init(Transform):
1313
:param values: Tensors for which Gradients must be returned.
1414
"""
1515

16-
def __init__(self, values: Set[Tensor]):
16+
def __init__(self, values: AbstractSet[Tensor]):
1717
self.values = values
1818

1919
def __call__(self, _: TensorDict, /) -> TensorDict:

src/torchjd/autojac/_transform/_select.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Set
1+
from collections.abc import Set as AbstractSet
22

33
from torch import Tensor
44

@@ -12,7 +12,7 @@ class Select(Transform):
1212
:param keys: The keys that should be included in the returned subset.
1313
"""
1414

15-
def __init__(self, keys: Set[Tensor]):
15+
def __init__(self, keys: AbstractSet[Tensor]):
1616
self.keys = keys
1717

1818
def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:

0 commit comments

Comments
 (0)