Skip to content

Commit f2d0d1b

Browse files
committed
Improve style
1 parent 3d9742c commit f2d0d1b

1 file changed

Lines changed: 20 additions & 12 deletions

File tree

src/torchjd/autogram/_gramian_utils.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@
66
from torchjd._linalg import PSDGeneralizedMatrix, PSDMatrix
77

88

9+
def flatten(gramian: PSDGeneralizedMatrix) -> PSDMatrix:
10+
"""
11+
Flattens a generalized Gramian into a square matrix. The first half of the dimensions are
12+
flattened into the first dimension, and the second half are flattened into the second.
13+
14+
:param gramian: Gramian to flatten. Can be a generalized Gramian.
15+
"""
16+
17+
# Example: `gramian` of shape [2, 3, 4, 4, 3, 2]:
18+
# [2, 3, 4, 4, 3, 2] yields a gramian of shape [24, 24]
19+
20+
k = gramian.ndim // 2
21+
shape = gramian.shape[:k]
22+
m = prod(shape)
23+
square_gramian = reshape(gramian, [m])
24+
return cast(PSDMatrix, square_gramian)
25+
26+
927
def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGeneralizedMatrix:
1028
"""
1129
Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions
@@ -23,18 +41,8 @@ def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGenerali
2341
# Example 2: `gramian` of shape [24, 24] and `half_shape` of [4, 3, 2]:
2442
# [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4]
2543

26-
reshaped_gramian = _revert_last_dims(
27-
_revert_last_dims(gramian).reshape(half_shape + half_shape)
28-
)
29-
return cast(PSDGeneralizedMatrix, reshaped_gramian)
30-
31-
32-
def flatten(gramian: PSDGeneralizedMatrix) -> PSDMatrix:
33-
k = gramian.ndim // 2
34-
half_shape = gramian.shape[:k]
35-
m = prod(half_shape)
36-
square_gramian = reshape(gramian, [m])
37-
return cast(PSDMatrix, square_gramian)
44+
result = _revert_last_dims(_revert_last_dims(gramian).reshape(half_shape + half_shape))
45+
return cast(PSDGeneralizedMatrix, result)
3846

3947

4048
def _revert_last_dims(t: Tensor) -> Tensor:

0 commit comments

Comments
 (0)