Skip to content

Commit 049c2d0

Browse files
refactor: Move autogram._gramian_utils to _linalg.generalized_gramian (#627)
1 parent 76941a1 commit 049c2d0

File tree

7 files changed

+10
-11
lines changed

7 files changed

+10
-11
lines changed

src/torchjd/_linalg/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ._generalized_gramian import flatten, movedim, reshape
12
from ._gramian import compute_gramian, normalize, regularize
23
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
34

@@ -11,4 +12,7 @@
1112
"is_matrix",
1213
"is_psd_matrix",
1314
"is_psd_tensor",
15+
"flatten",
16+
"reshape",
17+
"movedim",
1418
]

src/torchjd/autogram/_gramian_utils.py renamed to src/torchjd/_linalg/_generalized_gramian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from torch import Tensor
55

6-
from torchjd._linalg import PSDMatrix, PSDTensor
6+
from torchjd._linalg._matrix import PSDMatrix, PSDTensor
77

88

99
def flatten(gramian: PSDTensor) -> PSDMatrix:

src/torchjd/aggregation/_flattening.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from torch import Tensor
22

3-
from torchjd._linalg import PSDTensor
3+
from torchjd._linalg import PSDTensor, flatten
44
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
5-
from torchjd.autogram._gramian_utils import flatten
65

76

87
class Flattening(GeneralizedWeighting):

src/torchjd/autogram/_engine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
from torch import Tensor, nn, vmap
55
from torch.autograd.graph import get_gradient_edge
66

7-
from torchjd._linalg import PSDMatrix
7+
from torchjd._linalg import PSDMatrix, movedim, reshape
88

99
from ._edge_registry import EdgeRegistry
1010
from ._gramian_accumulator import GramianAccumulator
1111
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
12-
from ._gramian_utils import movedim, reshape
1312
from ._jacobian_computer import (
1413
AutogradJacobianComputer,
1514
FunctionalJacobianComputer,

tests/unit/autogram/test_engine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@
7979
)
8080
from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_
8181

82-
from torchjd._linalg import PSDMatrix, compute_gramian
82+
from torchjd._linalg import PSDMatrix, compute_gramian, movedim, reshape
8383
from torchjd.aggregation import UPGradWeighting
8484
from torchjd.autogram._engine import Engine
85-
from torchjd.autogram._gramian_utils import movedim, reshape
8685

8786
PARAMETRIZATIONS = [
8887
(ModuleFactory(OverlyNested), 32),

tests/unit/autogram/test_gramian_utils.py renamed to tests/unit/linalg/test_generalized_gramian.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from utils.asserts import assert_is_psd_matrix, assert_is_psd_tensor
44
from utils.tensors import randn_
55

6-
from torchjd._linalg import compute_gramian, is_psd_matrix
7-
from torchjd.autogram._gramian_utils import flatten, movedim, reshape
6+
from torchjd._linalg import compute_gramian, flatten, is_psd_matrix, movedim, reshape
87

98

109
@mark.parametrize(

tests/utils/asserts.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from torch import Tensor
55
from torch.testing import assert_close
66

7-
from torchjd._linalg import is_psd_matrix, is_psd_tensor
8-
from torchjd.autogram._gramian_utils import flatten
7+
from torchjd._linalg import flatten, is_psd_matrix, is_psd_tensor
98
from torchjd.autojac._accumulation import is_tensor_with_jac
109

1110

0 commit comments

Comments
 (0)