Skip to content

Commit a80f3f6

Browse files
committed
Add overload for compute_gramian when t is matrix and contracted_dims is 1
1 parent 55bc6f8 commit a80f3f6

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/torchjd/_linalg/_gramian.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._matrix import PSDMatrix, PSDTensor
6+
from ._matrix import Matrix, PSDMatrix, PSDTensor
77

88

99
@overload
@@ -16,6 +16,11 @@ def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
1616
pass
1717

1818

19+
@overload
20+
def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
21+
pass
22+
23+
1924
def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
2025
"""
2126
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.

0 commit comments

Comments
 (0)