File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed
Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change 88
99from torchjd ._linalg import PSDTensor , is_psd_tensor
1010
11- _T = TypeVar ("_T" , contravariant = True )
12- _FnInputT = TypeVar ("_FnInputT" )
13- _FnOutputT = TypeVar ("_FnOutputT" )
11+ _T = TypeVar ("_T" , contravariant = True , bound = Tensor )
12+ _FnInputT = TypeVar ("_FnInputT" , bound = Tensor )
13+ _FnOutputT = TypeVar ("_FnOutputT" , bound = Tensor )
1414
1515
1616class Weighting (Generic [_T ], nn .Module , ABC ):
@@ -27,9 +27,11 @@ def __init__(self):
2727 def forward (self , stat : _T ) -> Tensor :
2828 """Computes the vector of weights from the input stat."""
2929
30- def __call__ (self , stat : _T ) -> Tensor :
30+ def __call__ (self , stat : Tensor ) -> Tensor :
3131 """Computes the vector of weights from the input stat and applies all registered hooks."""
3232
33+ # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of
34+ # stat to be Tensor.
3335 return super ().__call__ (stat )
3436
3537 def _compose (self , fn : Callable [[_FnInputT ], _T ]) -> Weighting [_FnInputT ]:
You can’t perform that action at this time.
0 commit comments