Skip to content

Commit d5a36bf

Browse files
committed
Remove _compute_normalized_regularized_gramian and _compute_normalized_gramian as they are now just combinations of _compute_gramian, _regularize and _normalize
1 parent 20c2b41 commit d5a36bf

File tree

4 files changed

+6
-11
lines changed

4 files changed

+6
-11
lines changed

src/torchjd/aggregation/_gramian_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@ def _compute_gramian(matrix: Tensor) -> Tensor:
1010
return matrix @ matrix.T
1111

1212

13-
def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float):
14-
normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
15-
return _regularize(normalized_gramian, reg_eps)
16-
17-
1813
def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
1914
gramian = _compute_gramian(matrix)
2015
return _normalize(gramian, eps)

src/torchjd/aggregation/cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._gramian_utils import _compute_normalized_gramian
6+
from ._gramian_utils import _compute_gramian, _normalize
77
from .bases import _WeightedAggregator, _Weighting
88

99

@@ -72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float):
7272
self.norm_eps = norm_eps
7373

7474
def forward(self, matrix: Tensor) -> Tensor:
75-
gramian = _compute_normalized_gramian(matrix, self.norm_eps)
75+
gramian = _normalize(_compute_gramian(matrix), self.norm_eps)
7676
U, S, _ = torch.svd(gramian)
7777

7878
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44

55
from ._dual_cone_utils import _project_weights
6-
from ._gramian_utils import _compute_regularized_normalized_gramian
6+
from ._gramian_utils import _compute_gramian, _normalize, _regularize
77
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
99
from .mean import _MeanWeighting
@@ -100,6 +100,6 @@ def __init__(
100100

101101
def forward(self, matrix: Tensor) -> Tensor:
102102
u = self.weighting(matrix)
103-
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
103+
G = _regularize(_normalize(_compute_gramian(matrix), self.norm_eps), self.reg_eps)
104104
w = _project_weights(u, G, self.solver)
105105
return w

src/torchjd/aggregation/upgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._dual_cone_utils import _project_weights
7-
from ._gramian_utils import _compute_regularized_normalized_gramian
7+
from ._gramian_utils import _compute_gramian, _normalize, _regularize
88
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
1010
from .mean import _MeanWeighting
@@ -96,6 +96,6 @@ def __init__(
9696

9797
def forward(self, matrix: Tensor) -> Tensor:
9898
U = torch.diag(self.weighting(matrix))
99-
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
99+
G = _regularize(_normalize(_compute_gramian(matrix), self.norm_eps), self.reg_eps)
100100
W = _project_weights(U, G, self.solver)
101101
return torch.sum(W, dim=0)

0 commit comments

Comments
 (0)