Skip to content

Commit a781602

Browse files
committed
Make functions in _gramian_utils public (for the same package)
1 parent 0a6244f commit a781602

File tree

5 files changed

+11
-11
lines changed

5 files changed

+11
-11
lines changed

src/torchjd/aggregation/_gramian_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from torch import Tensor
33

44

5-
def _compute_gramian(matrix: Tensor) -> Tensor:
5+
def compute_gramian(matrix: Tensor) -> Tensor:
66
"""
77
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
88
"""
99

1010
return matrix @ matrix.T
1111

1212

13-
def _normalize(gramian: Tensor, eps: float) -> Tensor:
13+
def normalize(gramian: Tensor, eps: float) -> Tensor:
1414
"""
1515
Normalizes the gramian with respect to the Frobenius norm.
1616
@@ -25,7 +25,7 @@ def _normalize(gramian: Tensor, eps: float) -> Tensor:
2525
return gramian / squared_frobenius_norm
2626

2727

28-
def _regularize(gramian: Tensor, eps: float) -> Tensor:
28+
def regularize(gramian: Tensor, eps: float) -> Tensor:
2929
"""
3030
Adds a regularization term to the gramian to enforce positive definiteness.
3131

src/torchjd/aggregation/cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._gramian_utils import _compute_gramian, _normalize
6+
from ._gramian_utils import compute_gramian, normalize
77
from .bases import _WeightedAggregator, _Weighting
88

99

@@ -72,7 +72,7 @@ def __init__(self, c: float, norm_eps: float):
7272
self.norm_eps = norm_eps
7373

7474
def forward(self, matrix: Tensor) -> Tensor:
75-
gramian = _normalize(_compute_gramian(matrix), self.norm_eps)
75+
gramian = normalize(compute_gramian(matrix), self.norm_eps)
7676
U, S, _ = torch.svd(gramian)
7777

7878
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44

55
from ._dual_cone_utils import _project_weights
6-
from ._gramian_utils import _compute_gramian, _normalize, _regularize
6+
from ._gramian_utils import compute_gramian, normalize, regularize
77
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
99
from .mean import _MeanWeighting
@@ -100,6 +100,6 @@ def __init__(
100100

101101
def forward(self, matrix: Tensor) -> Tensor:
102102
u = self.weighting(matrix)
103-
G = _regularize(_normalize(_compute_gramian(matrix), self.norm_eps), self.reg_eps)
103+
G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps)
104104
w = _project_weights(u, G, self.solver)
105105
return w

src/torchjd/aggregation/mgda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch import Tensor
33

4-
from ._gramian_utils import _compute_gramian
4+
from ._gramian_utils import compute_gramian
55
from .bases import _WeightedAggregator, _Weighting
66

77

@@ -57,7 +57,7 @@ def __init__(self, epsilon: float, max_iters: int):
5757
self.max_iters = max_iters
5858

5959
def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
60-
gramian = _compute_gramian(matrix)
60+
gramian = compute_gramian(matrix)
6161
device = matrix.device
6262
dtype = matrix.dtype
6363

src/torchjd/aggregation/upgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._dual_cone_utils import _project_weights
7-
from ._gramian_utils import _compute_gramian, _normalize, _regularize
7+
from ._gramian_utils import compute_gramian, normalize, regularize
88
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
1010
from .mean import _MeanWeighting
@@ -96,6 +96,6 @@ def __init__(
9696

9797
def forward(self, matrix: Tensor) -> Tensor:
9898
U = torch.diag(self.weighting(matrix))
99-
G = _regularize(_normalize(_compute_gramian(matrix), self.norm_eps), self.reg_eps)
99+
G = regularize(normalize(compute_gramian(matrix), self.norm_eps), self.reg_eps)
100100
W = _project_weights(U, G, self.solver)
101101
return torch.sum(W, dim=0)

0 commit comments

Comments
 (0)