-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSVGD.py
More file actions
38 lines (30 loc) · 1.36 KB
/
SVGD.py
File metadata and controls
38 lines (30 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
class SVGD:
    """Stein Variational Gradient Descent over a fixed-size set of particles.

    Each call to :meth:`step` moves every particle ``x`` along the empirical
    Stein direction::

        phi(x) = mean_j [ k(x_j, x) * grad log p(x_j) + grad_{x_j} k(x_j, x) ]

    where the ``x_j`` are the current particles.
    """

    def __init__(self, n_particles, n_dim, init_scale=5.0):
        """Draw the initial particles from a zero-mean isotropic Gaussian.

        :param n_particles: int, number of particles.
        :param n_dim: int, dimensionality of each particle.
        :param init_scale: float, multiplier applied to standard-normal draws
            (i.e. the initial standard deviation). Defaults to 5.0, the value
            previously hard-coded.
        """
        self.n_particles = n_particles
        self.n_dim = n_dim
        # One row per particle: shape (n_particles, n_dim).
        self.particles = torch.randn(self.n_particles, n_dim) * init_scale

    def step(self, target_distribution_logprob, kernel, learning_rate=0.1):
        """Apply one SVGD update and store the result in ``self.particles``.

        :param target_distribution_logprob: pytorch/pyro callable; given the
            (n_particles, n_dim) particle tensor, returns the log-probability
            of each row.
        :param kernel: pytorch callable ``kernel(particles, x)`` returning one
            kernel value per particle for a single point ``x``; it must
            broadcast over the particle axis.
        :param learning_rate: float step size.
        :return: torch.Tensor, the updated particles (the same object now held
            in ``self.particles``).
        """
        particles = self.particles

        def grad_of_sum(fn):
            # Gradient of sum(fn(p)) w.r.t. a fresh leaf copy of the particles.
            # Row j of the result equals the gradient of fn's j-th output
            # w.r.t. x_j, provided fn treats rows independently (assumed for
            # both the log-prob and the kernel callables).
            leaf = particles.clone().detach().requires_grad_()
            (grad,) = torch.autograd.grad(torch.sum(fn(leaf)), leaf)
            return grad

        # grad log p(x_j) does not depend on the point being updated, so it is
        # computed once per step rather than once per particle.
        score = grad_of_sum(target_distribution_logprob)

        def phi(x):
            # Kernel-weighted attraction towards high density, plus the kernel
            # gradient term that pushes particles apart.
            weights = kernel(particles, x).reshape(-1, 1)
            repulsion = grad_of_sum(lambda p: kernel(p, x))
            return torch.mean(weights * score + repulsion, dim=0)

        phis = torch.stack([phi(x) for x in particles], dim=0)
        self.particles = particles + learning_rate * phis
        return self.particles