66from torchjd ._linalg import PSDGeneralizedMatrix , PSDMatrix
77
88
9+ def flatten (gramian : PSDGeneralizedMatrix ) -> PSDMatrix :
10+ """
11+ Flattens a generalized Gramian into a square matrix. The first half of the dimensions are
12+ flattened into the first dimension, and the second half are flattened into the second.
13+
14+ :param gramian: Gramian to flatten. Can be a generalized Gramian.
15+ """
16+
17+ # Example: `gramian` of shape [2, 3, 4, 4, 3, 2]:
18+ # [2, 3, 4, 4, 3, 2] yields a gramian of shape [24, 24]
19+
20+ k = gramian .ndim // 2
21+ shape = gramian .shape [:k ]
22+ m = prod (shape )
23+ square_gramian = reshape (gramian , [m ])
24+ return cast (PSDMatrix , square_gramian )
25+
26+
927def reshape (gramian : PSDGeneralizedMatrix , half_shape : list [int ]) -> PSDGeneralizedMatrix :
1028 """
1129 Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions
@@ -23,18 +41,8 @@ def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGenerali
2341 # Example 2: `gramian` of shape [24, 24] and `half_shape` of [4, 3, 2]:
2442 # [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4]
2543
26- reshaped_gramian = _revert_last_dims (
27- _revert_last_dims (gramian ).reshape (half_shape + half_shape )
28- )
29- return cast (PSDGeneralizedMatrix , reshaped_gramian )
30-
31-
32- def flatten (gramian : PSDGeneralizedMatrix ) -> PSDMatrix :
33- k = gramian .ndim // 2
34- half_shape = gramian .shape [:k ]
35- m = prod (half_shape )
36- square_gramian = reshape (gramian , [m ])
37- return cast (PSDMatrix , square_gramian )
44+ result = _revert_last_dims (_revert_last_dims (gramian ).reshape (half_shape + half_shape ))
45+ return cast (PSDGeneralizedMatrix , result )
3846
3947
4048def _revert_last_dims (t : Tensor ) -> Tensor :
0 commit comments