Skip to content

Commit 481bf67

Browse files
committed
Make GramianWeightedAggregator generic
1 parent 86fe403 commit 481bf67

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from typing import Generic, TypeVar
23

34
from torch import Tensor, nn
45

@@ -68,7 +69,10 @@ def forward(self, matrix: Matrix, /) -> Tensor:
6869
return vector
6970

7071

71-
class GramianWeightedAggregator(WeightedAggregator):
72+
_T = TypeVar("_T", bound=Weighting[PSDMatrix])
73+
74+
75+
class GramianWeightedAggregator(WeightedAggregator, Generic[_T]):
7276
"""
7377
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
7478
Weighting to it.
@@ -77,6 +81,6 @@ class GramianWeightedAggregator(WeightedAggregator):
7781
gramian.
7882
"""
7983

80-
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
84+
def __init__(self, gramian_weighting: _T) -> None:
8185
super().__init__(gramian_weighting << compute_gramian)
8286
self.gramian_weighting = gramian_weighting

0 commit comments

Comments
 (0)