|
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | 17 | from math import lgamma |
| 18 | +from math import log |
| 19 | +from math import pi |
18 | 20 |
|
19 | 21 | import numpy as np |
20 | 22 |
|
21 | 23 | from cgpm.primitives.distribution import DistributionGpm |
22 | 24 | from cgpm.utils import general as gu |
23 | 25 |
|
24 | 26 |
|
| 27 | +LOG2 = log(2) |
| 28 | +LOGPI = log(pi) |
| 29 | +LOG2PI = LOG2 + LOGPI |
| 30 | + |
| 31 | + |
25 | 32 | class Normal(DistributionGpm): |
26 | 33 | """Normal distribution with normal prior on mean and gamma prior on |
27 | 34 | precision. Collapsed. |
@@ -157,39 +164,40 @@ def is_numeric(): |
157 | 164 |
|
158 | 165 | @staticmethod |
159 | 166 | def calc_predictive_logp(x, N, sum_x, sum_x_sq, m, r, s, nu): |
160 | | - mn, rn, sn, nun = Normal.posterior_hypers( |
| 167 | + _mn, rn, sn, nun = Normal.posterior_hypers( |
161 | 168 | N, sum_x, sum_x_sq, m, r, s, nu) |
162 | | - mm, rm, sm, num = Normal.posterior_hypers( |
| 169 | + _mm, rm, sm, num = Normal.posterior_hypers( |
163 | 170 | N+1, sum_x+x, sum_x_sq+x*x, m, r, s, nu) |
164 | 171 | ZN = Normal.calc_log_Z(rn, sn, nun) |
165 | 172 | ZM = Normal.calc_log_Z(rm, sm, num) |
166 | | - return -.5 * np.log(2*np.pi) + ZM - ZN |
| 173 | + return -.5 * LOG2PI + ZM - ZN |
167 | 174 |
|
168 | 175 | @staticmethod |
169 | 176 | def calc_logpdf_marginal(N, sum_x, sum_x_sq, m, r, s, nu): |
170 | | - mn, rn, sn, nun = Normal.posterior_hypers( |
| 177 | + _mn, rn, sn, nun = Normal.posterior_hypers( |
171 | 178 | N, sum_x, sum_x_sq, m, r, s, nu) |
172 | 179 | Z0 = Normal.calc_log_Z(r, s, nu) |
173 | 180 | ZN = Normal.calc_log_Z(rn, sn, nun) |
174 | | - return -(N/2.) * np.log(2*np.pi) + ZN - Z0 |
| 181 | + return -(N/2.) * LOG2PI + ZN - Z0 |
175 | 182 |
|
176 | 183 | @staticmethod |
177 | 184 | def posterior_hypers(N, sum_x, sum_x_sq, m, r, s, nu): |
178 | 185 | rn = r + float(N) |
179 | 186 | nun = nu + float(N) |
180 | 187 | mn = (r*m + sum_x)/rn |
181 | 188 | sn = s + sum_x_sq + r*m*m - rn*mn*mn |
182 | | - if sn == 0: sn = s |
| 189 | + if sn == 0: |
| 190 | + sn = s |
183 | 191 | return mn, rn, sn, nun |
184 | 192 |
|
185 | 193 | @staticmethod |
186 | 194 | def calc_log_Z(r, s, nu): |
187 | 195 | return ( |
188 | | - ((nu + 1.) / 2.) * np.log(2) |
189 | | - + .5 * np.log(np.pi) |
190 | | - - .5 * np.log(r) |
191 | | - - (nu/2.) * np.log(s) |
192 | | - + lgamma(nu/2.0)) |
| 196 | + ((nu + 1.) / 2.) * LOG2 |
| 197 | + + .5 * LOGPI |
| 198 | + - .5 * log(r) |
| 199 | + - (nu/2.) * log(s) |
| 200 | + + lgamma(nu/2.)) |
193 | 201 |
|
194 | 202 | @staticmethod |
195 | 203 | def sample_parameters(m, r, s, nu, rng): |
|
0 commit comments