-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspd_random.py
More file actions
200 lines (160 loc) · 7.51 KB
/
spd_random.py
File metadata and controls
200 lines (160 loc) · 7.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import torch
from geoopt import SymmetricPositiveDefinite
def sqrt_inv(p, eps=1e-8):
    """
    Compute the inverse square root p^(-1/2) of an SPD matrix or a batch of them.

    The input is diagonalised with ``torch.linalg.eigh`` (p = V diag(w) V^T)
    and rebuilt as V diag((w + eps)^(-1/2)) V^T.

    Args:
        p (torch.Tensor): A symmetric positive definite matrix of shape (n, n),
            or a batch of them of shape (..., n, n). Any number of leading
            batch dimensions is accepted.
        eps (float): Regulariser added to the eigenvalues before taking the
            inverse square root, so (near-)singular inputs do not divide by
            zero. Defaults to 1e-8, the value previously hard-coded.

    Returns:
        torch.Tensor: p^(-1/2), with the same shape as ``p``.
    """
    eigvals, eigvecs = torch.linalg.eigh(p)
    # Round-off can push eigenvalues of a (nearly) singular SPD matrix slightly
    # below zero; clamping at zero keeps the square root real instead of NaN.
    # For strictly positive eigenvalues the clamp is a no-op.
    inv_sqrt_eigvals = 1.0 / torch.sqrt(eigvals.clamp_min(0.0) + eps)
    # Reconstruct V @ D^(-1/2) @ V^T. eigh, diag_embed and @ all broadcast over
    # leading batch dims, so unbatched and batched inputs share one code path.
    inv_sqrt_eigvals_matrix = torch.diag_embed(inv_sqrt_eigvals)
    return eigvecs @ inv_sqrt_eigvals_matrix @ eigvecs.transpose(-1, -2)
def _sqrtm_spd(p):
    """Return the principal square root p^(1/2) of SPD matrix/matrices of shape (..., n, n)."""
    eigvals, eigvecs = torch.linalg.eigh(p)
    # Clamp round-off negatives so the square root stays real.
    sqrt_eigvals = torch.sqrt(eigvals.clamp_min(0.0))
    return eigvecs @ torch.diag_embed(sqrt_eigvals) @ eigvecs.transpose(-1, -2)


def Vect_inv(t, p, n):
    """
    Map coordinate vector(s) ``t`` to tangent vector(s) of the SPD manifold at ``p``.

    First ``t`` is unpacked into a symmetric matrix m whose upper triangle
    (row-major order of ``torch.triu_indices``) holds ``t``. Off-diagonal
    entries are divided by sqrt(2), so the Frobenius norm of m equals the
    Euclidean norm of ``t``. Then m is moved to the tangent space at ``p`` as
    ``p^(1/2) m p^(1/2)``.

    This is the inverse of the vectorisation
    vec_p(W) = vec_I(p^(-1/2) W p^(-1/2)). With the affine-invariant
    exponential map, exp_p(p^(1/2) m p^(1/2)) = p^(1/2) expm(m) p^(1/2).
    The earlier implementation used p^(-1/2) on both sides, which agrees with
    this only at p = I.

    Args:
        t (torch.Tensor): Coordinates of shape (..., k) with k = n * (n + 1) / 2,
            e.g. (k,) or (batch_size, k).
        p (torch.Tensor): Base SPD matrix of shape (n, n), or a batch of shape
            (..., n, n) broadcastable against the batch dimensions of ``t``.
        n (int): Size of the matrices.

    Returns:
        torch.Tensor: Symmetric tangent matrix/matrices. The shape is the
        broadcast of ``t.shape[:-1] + (n, n)`` with ``p.shape``.

    Raises:
        ValueError: If the last dimension of ``t`` is not n * (n + 1) / 2.
    """
    k = n * (n + 1) // 2
    if t.shape[-1] != k:
        raise ValueError(
            f"t has last dimension {t.shape[-1]}; expected k = n*(n+1)//2 = {k} for n={n}."
        )
    rows, cols = torch.triu_indices(row=n, col=n, device=t.device)
    # Off-diagonal coordinates appear twice in the symmetric matrix, hence 1/sqrt(2).
    scaling_factor = torch.full((k,), 2.0 ** -0.5, dtype=t.dtype, device=t.device)
    scaling_factor[rows == cols] = 1.0
    t_scaled = t * scaling_factor
    # Ellipsis indexing handles both a single vector and any batch of vectors.
    m = t.new_zeros(t.shape[:-1] + (n, n))
    m[..., rows, cols] = t_scaled
    m[..., cols, rows] = t_scaled
    sqrt_p = _sqrtm_spd(p)
    return sqrt_p @ m @ sqrt_p
def _as_tensor(value, dtype, device):
    """Convert a Python scalar/sequence or tensor to a tensor with the given dtype/device."""
    if torch.is_tensor(value):
        return value.to(dtype=dtype, device=device)
    return torch.tensor(value, dtype=dtype, device=device)


def _expand_param(value, name, num_matrices, t_dims):
    """Broadcast a per-coordinate parameter tensor (mean or std) to the sampling shape.

    Accepts shape (), (k,) or, when num_matrices > 1, (num_matrices, k). Returns
    shape (num_matrices, k) if num_matrices > 1, else (k,). Raises ValueError
    mentioning ``name`` for any other shape.
    """
    if num_matrices > 1:
        if value.dim() == 0:
            return value.expand(num_matrices, t_dims)
        if value.shape == (t_dims,):
            return value.expand(num_matrices, -1)
        if value.shape == (num_matrices, t_dims):
            return value
        raise ValueError(f"{name} has incompatible shape {tuple(value.shape)}; expected (), (k,), or (num_matrices, k) with k={t_dims}.")
    if value.dim() == 0:
        return value.expand(t_dims)
    if value.shape == (t_dims,):
        return value
    raise ValueError(f"{name} has incompatible shape {tuple(value.shape)}; expected () or (k,) with k={t_dims} for num_matrices=1.")


def random(num_matrices, p, mu, sigma):
    """
    Generate one or more random Symmetric Positive Definite (SPD) matrices.

    Tangent coordinates t ~ N(mu, sigma) of length k = n*(n+1)//2 are sampled,
    turned into tangent vectors at ``p`` by ``Vect_inv``, and pushed onto the
    manifold with geoopt's ``SymmetricPositiveDefinite.expmap``.

    Args:
        num_matrices (int): Number of matrices to generate; must be >= 1.
        p (torch.Tensor): Base point matrix of shape (n, n), or a batch of base
            points of shape (num_matrices, n, n). Its dtype and device are used
            for all sampling.
        mu (float | torch.Tensor): Mean of the normal distribution in the tangent space.
            - Scalar: same mean for all tangent coordinates
            - Shape (k,): per-coordinate mean, where k = n*(n+1)//2
            - Shape (num_matrices, k): per-sample mean
        sigma (float | torch.Tensor): Dispersion for the normal distribution.
            - Scalar, shape (k,) or (num_matrices, k): STANDARD DEVIATION(s)
              of independent coordinates.
            - Shape (k, k) or (num_matrices, k, k): full COVARIANCE
              matrix/matrices. A trailing (k, k) block always takes precedence,
              so when num_matrices == k a (num_matrices, k) sigma is read as a
              covariance.

    Returns:
        torch.Tensor: The generated SPD matrices. The shape is (n, n) when
        num_matrices is 1 and ``p`` is (n, n). It is (num_matrices, n, n) when
        num_matrices > 1. Batched ``p`` or batched covariances contribute their
        leading dimension.

    Raises:
        ValueError: If num_matrices < 1, or if mu / sigma have an unsupported shape.
    """
    if num_matrices < 1:
        # Previously 0 or negative values silently produced a single matrix.
        raise ValueError(f"num_matrices must be a positive integer, got {num_matrices}.")
    n = p.shape[-1]
    M = SymmetricPositiveDefinite()
    # Tangent space dimension (upper-triangular including diagonal)
    t_dims = (n * (n + 1)) // 2
    device, dtype = p.device, p.dtype

    mean = _expand_param(_as_tensor(mu, dtype, device), "mu", num_matrices, t_dims)

    sig_t = _as_tensor(sigma, dtype, device)
    is_cov = sig_t.dim() >= 2 and sig_t.shape[-2:] == (t_dims, t_dims)
    if is_cov:
        # Full covariance; MultivariateNormal broadcasts loc and covariance batch dims.
        mvn = torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=sig_t)
        t = mvn.rsample()
    else:
        # Independent coordinates with the given standard deviation(s).
        std = _expand_param(sig_t, "sigma", num_matrices, t_dims)
        t = torch.normal(mean=mean, std=std)

    # A single base point is shared across all samples.
    p_base = p.expand(num_matrices, -1, -1) if num_matrices > 1 and p.dim() == 2 else p
    tangent_vectors = Vect_inv(t, p_base, n)
    return M.expmap(p_base, tangent_vectors)
if __name__ == '__main__':
    # Demo: draw samples around the identity and sanity-check the results.
    dim = 4
    n_batch = 10
    identity = torch.eye(dim, dtype=torch.float64)

    sample = random(num_matrices=1, p=identity, mu=0.0, sigma=1.0)
    print("Shape of the generated matrix:", sample.shape)
    symmetric = torch.allclose(sample, sample.T)
    print("Is it symmetric?", symmetric)
    spectrum = torch.linalg.eigh(sample).eigenvalues
    all_positive = bool(torch.all(spectrum > 0))
    print("Are its eigenvalues positive?", all_positive)
    print("\n")

    print("--- Generating a batch of matrices ---")
    batch = random(num_matrices=n_batch, p=identity, mu=0.0, sigma=0.5)
    print("Shape of the generated batch:", batch.shape)