File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -141,6 +141,7 @@ select = [
141141 " SIM" , # flake8-simplify
142142 " ARG" , # flake8-unused-arguments
143143 " RET" , # flake8-return
144+ " PYI" , # flake8-pyi
144145 " PERF" , # Perflint
145146 " FURB" , # refurb
146147 " RUF" , # Ruff-specific rules
Original file line number Diff line number Diff line change 1313_FnOutputT = TypeVar ("_FnOutputT" , bound = Tensor )
1414
1515
16- class Weighting (Generic [ _T ], nn .Module , ABC ):
16+ class Weighting (nn .Module , ABC , Generic [ _T ] ):
1717 r"""
1818 Abstract base class for all weighting methods. It has the role of extracting a vector of weights
1919 of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`,
Original file line number Diff line number Diff line change 1- from collections .abc import Set
1+ from collections .abc import Set as AbstractSet
22
33import torch
44from torch import Tensor
@@ -13,7 +13,7 @@ class Init(Transform):
1313 :param values: Tensors for which Gradients must be returned.
1414 """
1515
16- def __init__ (self , values : Set [Tensor ]):
16+ def __init__ (self , values : AbstractSet [Tensor ]):
1717 self .values = values
1818
1919 def __call__ (self , _ : TensorDict , / ) -> TensorDict :
Original file line number Diff line number Diff line change 1- from collections .abc import Set
1+ from collections .abc import Set as AbstractSet
22
33from torch import Tensor
44
@@ -12,7 +12,7 @@ class Select(Transform):
1212 :param keys: The keys that should be included in the returned subset.
1313 """
1414
15- def __init__ (self , keys : Set [Tensor ]):
15+ def __init__ (self , keys : AbstractSet [Tensor ]):
1616 self .keys = keys
1717
1818 def __call__ (self , tensor_dict : TensorDict , / ) -> TensorDict :
You can’t perform that action at this time.
0 commit comments