#!/usr/bin/env python3

from __future__ import annotations

import torch
from torch.autograd import Function

from linear_operator import settings
from linear_operator.utils import lanczos


class RootDecomposition(Function):
    @staticmethod
    def forward(
        ctx,
        representation_tree,
        max_iter,
        dtype,
        device,
        batch_shape,
        matrix_shape,
        root,
        inverse,
        initial_vectors,
        *matrix_args,
    ):
        r"""
        :param list matrix_args: The arguments representing the symmetric matrix A (or batch of PSD matrices A)
        :rtype: (torch.Tensor, torch.Tensor)
        :return: :attr:`R`, such that :math:`R R^T \approx A`, and :attr:`R_inv`, such that
            :math:`R_{inv} R_{inv}^T \approx A^{-1}` (will only be populated if self.inverse = True)
        """
        from linear_operator.operators import to_linear_operator

        ctx.representation_tree = representation_tree
        ctx.device = device
        ctx.dtype = dtype
        ctx.matrix_shape = matrix_shape
        ctx.max_iter = max_iter
        ctx.batch_shape = batch_shape
        ctx.root = root
        ctx.inverse = inverse
        ctx.initial_vectors = initial_vectors

        # Get closure for matmul
        linear_op = ctx.representation_tree(*matrix_args)
        matmul_closure = linear_op._matmul

        # Do Lanczos
        q_mat, t_mat = lanczos.lanczos_tridiag(
            matmul_closure,
            ctx.max_iter,
            dtype=ctx.dtype,
            device=ctx.device,
            matrix_shape=ctx.matrix_shape,
            batch_shape=ctx.batch_shape,
            init_vecs=ctx.initial_vectors,
        )
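        # Lanczos produces Q with orthonormal columns and a tridiagonal
        # T = Q^T A Q, so the operator is approximated by A ~= Q T Q^T.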
        if ctx.batch_shape is None:
            q_mat = q_mat.unsqueeze(-3)
            t_mat = t_mat.unsqueeze(-3)
        if t_mat.ndimension() == 3:  # If we only used one probe vector
            q_mat = q_mat.unsqueeze(0)
            t_mat = t_mat.unsqueeze(0)
        n_probes = t_mat.size(0)

        mins = to_linear_operator(t_mat)._diagonal().min(dim=-1, keepdim=True)[0].unsqueeze(-1)
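        # Add a small diagonal jitter to T (scaled by its smallest diagonal
        # entry) before the eigendecomposition, for numerical stability.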
        jitter_mat = (settings.tridiagonal_jitter.value() * mins) * torch.eye(
            t_mat.size(-1), device=t_mat.device, dtype=t_mat.dtype
        ).expand_as(t_mat)
        eigenvalues, eigenvectors = lanczos.lanczos_tridiag_to_diag(t_mat + jitter_mat)
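        # With T = V diag(evals) V^T, we get A ~= (Q V) diag(evals) (Q V)^T.
        # A root is therefore R = Q V diag(sqrt(evals)), and an inverse root is
        # R_inv = Q V diag(1 / sqrt(evals)); the code below forms both.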
        # Get orthogonal matrix and eigenvalue roots
        q_mat = q_mat.matmul(eigenvectors)
        root_evals = eigenvalues.sqrt()

        # Decide whether we're computing the root, the inverse root, or both
        root = torch.empty(0, dtype=q_mat.dtype, device=q_mat.device)
        inverse = torch.empty(0, dtype=q_mat.dtype, device=q_mat.device)
        if ctx.inverse:
            inverse = q_mat / root_evals.unsqueeze(-2)
        if ctx.root:
            root = q_mat * root_evals.unsqueeze(-2)
        if settings.memory_efficient.off():
            ctx._linear_op = linear_op

        if ctx.batch_shape is None:
            root = root.squeeze(1) if root.numel() else root
            q_mat = q_mat.squeeze(1)
            root_evals = root_evals.squeeze(1)
            inverse = inverse.squeeze(1) if inverse.numel() else inverse
        if n_probes == 1:
            root = root.squeeze(0) if root.numel() else root
            q_mat = q_mat.squeeze(0)
            root_evals = root_evals.squeeze(0)
            inverse = inverse.squeeze(0) if inverse.numel() else inverse

        to_save = list(matrix_args) + [q_mat, root_evals, inverse]
        ctx.save_for_backward(*to_save)
        return root, inverse

    @staticmethod
    def backward(ctx, root_grad_output, inverse_grad_output):
        # Taken from http://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf
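        # Notation: from forward, R = Q V diag(sqrt(evals)) and
        # R_inv = Q V diag(1 / sqrt(evals)). Given dL/dR and dL/dR_inv, we build
        # left/right factors whose bilinear product against dA implements the
        # chain rule through the root decomposition (see the note cited above).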
        if any(ctx.needs_input_grad):

            def is_empty(tensor):
                return tensor.numel() == 0 or (tensor.numel() == 1 and tensor[0] == 0)

            # Fix outputs and gradients
            if is_empty(root_grad_output):
                root_grad_output = None
            if is_empty(inverse_grad_output):
                inverse_grad_output = None

            # Get saved tensors
            matrix_args = ctx.saved_tensors[:-3]
            q_mat = ctx.saved_tensors[-3]
            root_evals = ctx.saved_tensors[-2]
            inverse = ctx.saved_tensors[-1]

            # Add a leading probe-vector dimension to the incoming gradients
            # whenever the saved tensors carry one that the gradients lack
            is_batch = False
            if root_grad_output is not None:
                if root_grad_output.ndimension() == 2 and q_mat.ndimension() > 2:
                    root_grad_output = root_grad_output.unsqueeze(0)
                    is_batch = True
                if root_grad_output.ndimension() == 3 and q_mat.ndimension() > 3:
                    root_grad_output = root_grad_output.unsqueeze(0)
                    is_batch = True
            if inverse_grad_output is not None:
                if inverse_grad_output.ndimension() == 2 and q_mat.ndimension() > 2:
                    inverse_grad_output = inverse_grad_output.unsqueeze(0)
                    is_batch = True
                if inverse_grad_output.ndimension() == 3 and q_mat.ndimension() > 3:
                    inverse_grad_output = inverse_grad_output.unsqueeze(0)
                    is_batch = True

            # Get closure for matmul
            if hasattr(ctx, "_linear_op"):
                linear_op = ctx._linear_op
            else:
                linear_op = ctx.representation_tree(*matrix_args)

            # Get root inverse
            if not ctx.inverse:
                inverse = q_mat / root_evals.unsqueeze(-2)

            # Left factor
            left_factor = torch.zeros_like(inverse)
            if root_grad_output is not None:
                left_factor.add_(root_grad_output)
            if inverse_grad_output is not None:
                # -root^-T grad_output.T root^-T
                left_factor.sub_(torch.matmul(inverse, inverse_grad_output.mT).matmul(inverse))

            # Right factor
            right_factor = inverse.div(2.0)
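            # Informally, dL/dA is assembled from left_factor @ right_factor^T
            # (summed over probe vectors); _bilinear_derivative propagates that
            # bilinear form to the tensors in matrix_args.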
            # Fix batches
            if is_batch:
                left_factor = left_factor.permute(1, 0, 2, 3).contiguous()
                left_factor = left_factor.view(inverse.size(1), -1, left_factor.size(-1))
                right_factor = right_factor.permute(1, 0, 2, 3).contiguous()
                right_factor = right_factor.view(inverse.size(1), -1, right_factor.size(-1))
            else:
                left_factor = left_factor.contiguous()
                right_factor = right_factor.contiguous()
            res = linear_op._bilinear_derivative(left_factor, right_factor)
            return tuple([None] * 9 + list(res))
        else:
            # No input requires gradients: return None for each argument to forward
            return tuple([None] * (9 + len(ctx.saved_tensors) - 3))
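

if __name__ == "__main__":
    # Illustrative smoke test -- a minimal sketch, not part of the original
    # module. It exercises this Function indirectly through the public
    # `root_decomposition` API, assuming the standard linear_operator
    # conventions (`to_linear_operator`, `settings.max_cholesky_size`).
    # Setting the max Cholesky size to zero should route the call through the
    # Lanczos path implemented above rather than a dense Cholesky.
    from linear_operator.operators import to_linear_operator

    torch.manual_seed(0)
    mat = torch.randn(50, 50)
    mat = mat @ mat.mT + 50 * torch.eye(50)  # a symmetric positive definite A

    with settings.max_cholesky_size(0):
        R = to_linear_operator(mat).root_decomposition().root.to_dense()

    # R R^T should approximate A
    print("max |R R^T - A| =", (R @ R.mT - mat).abs().max().item())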