@@ -54,9 +54,8 @@ def __init__(self, dimension=10, b=0.1, V=100):
5454 self ._V = float (V )
5555
5656 # Create phi
57- self ._mean = np .zeros (self ._dimension )
58- self ._cov = np .eye (self ._dimension )
59- self ._phi = scipy .stats .multivariate_normal (self ._mean , self ._cov )
57+ self ._phi = scipy .stats .multivariate_normal (
58+ np .zeros (self ._dimension ), np .eye (self ._dimension ))
6059
6160 def __call__ (self , x ):
6261 y = np .array (x , copy = True )
@@ -98,17 +97,18 @@ def kl_divergence(self, samples):
9897 # - k
9998 # )
10099 #
101- # using s1 = real sigma, as this needs to be inverted and the real one
102- # is more likely to be invertible than the sample one
100+ # For this distribution, s1 is the identify matrix, and m1 is zero,
101+ # so it simplifies to
102+ #
103+ # dkl = 0.5 * (trace(s0) + m0.dot(m0) - log(det(s0)) - k))
104+ #
103105 m0 = np .mean (y , axis = 0 )
104- m1 = self ._mean
105106 s0 = np .cov (y .T )
106- s1 = self ._cov
107- cov_inv = np .linalg .inv (s1 )
108- dkl1 = np .trace (cov_inv .dot (s0 ))
109- dkl2 = np .dot ((m1 - m0 ).T , cov_inv ).dot (m1 - m0 )
110- dkl3 = np .log (np .linalg .det (s1 ) / np .linalg .det (s0 ))
111- return 0.5 * (dkl1 + dkl2 + dkl3 - self ._dimension )
107+ s1 = np .eye (self ._dimension )
108+ cov_inv = s1
109+ return 0.5 * (
110+ np .trace (s0 ) + m0 .dot (m0 )
111+ - np .log (np .linalg .det (s0 )) - self ._dimension )
112112
113113 def n_parameters (self ):
114114 """ See :meth:`pints.LogPDF.n_parameters()`. """
@@ -121,10 +121,8 @@ def sample(self, n):
121121 if n < 0 :
122122 raise ValueError ('Number of samples cannot be negative.' )
123123
124- x = np . random . randn ( n , 2 )
124+ x = self . _phi . rvs ( n )
125125 x [:, 0 ] *= np .sqrt (self ._V )
126126 x [:, 1 ] -= self ._b * (x [:, 0 ] ** 2 - self ._V )
127- if self ._dimension > 2 :
128- x = np .hstack ((x , np .random .randn (n , self ._dimension - 2 )))
129127 return x
130128
0 commit comments