Skip to content

Commit 7da352c

Browse files
committed
Remove GeneralizedMatrix
1 parent a793693 commit 7da352c

File tree

4 files changed

+8
-20
lines changed

4 files changed

+8
-20
lines changed

src/torchjd/_linalg/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from ._gramian import compute_gramian, normalize, regularize
22
from ._matrix import (
3-
GeneralizedMatrix,
43
Matrix,
54
PSDGeneralizedMatrix,
65
PSDMatrix,
7-
is_generalized_matrix,
86
is_matrix,
97
is_psd_generalized_matrix,
108
is_psd_matrix,
@@ -14,11 +12,9 @@
1412
"compute_gramian",
1513
"normalize",
1614
"regularize",
17-
"GeneralizedMatrix",
1815
"Matrix",
1916
"PSDMatrix",
2017
"PSDGeneralizedMatrix",
21-
"is_generalized_matrix",
2218
"is_matrix",
2319
"is_psd_matrix",
2420
"is_psd_generalized_matrix",

src/torchjd/_linalg/_gramian.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
from typing import Literal, cast, overload
22

33
import torch
4+
from torch import Tensor
45

5-
from ._matrix import GeneralizedMatrix, PSDGeneralizedMatrix, PSDMatrix
6+
from ._matrix import PSDGeneralizedMatrix, PSDMatrix
67

78

89
@overload
9-
def compute_gramian(matrix: GeneralizedMatrix) -> PSDMatrix:
10+
def compute_gramian(matrix: Tensor) -> PSDMatrix:
1011
pass
1112

1213

1314
@overload
14-
def compute_gramian(matrix: GeneralizedMatrix, contracted_dims: Literal[-1]) -> PSDMatrix:
15+
def compute_gramian(matrix: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
1516
pass
1617

1718

1819
@overload
19-
def compute_gramian(matrix: GeneralizedMatrix, contracted_dims: int) -> PSDGeneralizedMatrix:
20+
def compute_gramian(matrix: Tensor, contracted_dims: int) -> PSDGeneralizedMatrix:
2021
pass
2122

2223

23-
def compute_gramian(matrix: GeneralizedMatrix, contracted_dims: int = -1) -> PSDGeneralizedMatrix:
24+
def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDGeneralizedMatrix:
2425
"""
2526
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.
2627

src/torchjd/_linalg/_matrix.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
from torch import Tensor
44

55

6-
class GeneralizedMatrix(Tensor):
7-
"""Tensor with a least 1 dimension."""
8-
9-
10-
class Matrix(GeneralizedMatrix):
6+
class Matrix(Tensor):
117
"""Tensor with exactly 2 dimensions."""
128

139

@@ -23,10 +19,6 @@ class PSDMatrix(PSDGeneralizedMatrix, Matrix):
2319
"""Positive semi-definite matrix."""
2420

2521

26-
def is_generalized_matrix(t: Tensor) -> TypeGuard[GeneralizedMatrix]:
27-
return t.ndim >= 1
28-
29-
3022
def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
3123
return t.ndim == 2
3224

tests/unit/linalg/test_gramian.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from utils.asserts import assert_is_psd_matrix
33
from utils.tensors import randn_
44

5-
from torchjd._linalg import compute_gramian, is_generalized_matrix, is_matrix, normalize, regularize
5+
from torchjd._linalg import compute_gramian, is_matrix, normalize, regularize
66

77

88
@mark.parametrize(
@@ -21,7 +21,6 @@
2121
)
2222
def test_gramian_is_psd(shape: list[int]):
2323
matrix = randn_(shape)
24-
assert is_generalized_matrix(matrix)
2524
gramian = compute_gramian(matrix)
2625
assert_is_psd_matrix(gramian)
2726

0 commit comments

Comments
 (0)