Skip to content

Commit 91cd8f7

Browse files
author
FA Saad
committed
Merge branch '20170904-fsaad-log-constants-store'
2 parents 69c089d + 43669e1 commit 91cd8f7

1 file changed

Lines changed: 19 additions & 11 deletions

File tree

src/primitives/normal.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
# limitations under the License.
1616

1717
from math import lgamma
18+
from math import log
19+
from math import pi
1820

1921
import numpy as np
2022

2123
from cgpm.primitives.distribution import DistributionGpm
2224
from cgpm.utils import general as gu
2325

2426

27+
LOG2 = log(2)
28+
LOGPI = log(pi)
29+
LOG2PI = LOG2 + LOGPI
30+
31+
2532
class Normal(DistributionGpm):
2633
"""Normal distribution with normal prior on mean and gamma prior on
2734
precision. Collapsed.
@@ -157,39 +164,40 @@ def is_numeric():
157164

158165
@staticmethod
159166
def calc_predictive_logp(x, N, sum_x, sum_x_sq, m, r, s, nu):
160-
mn, rn, sn, nun = Normal.posterior_hypers(
167+
_mn, rn, sn, nun = Normal.posterior_hypers(
161168
N, sum_x, sum_x_sq, m, r, s, nu)
162-
mm, rm, sm, num = Normal.posterior_hypers(
169+
_mm, rm, sm, num = Normal.posterior_hypers(
163170
N+1, sum_x+x, sum_x_sq+x*x, m, r, s, nu)
164171
ZN = Normal.calc_log_Z(rn, sn, nun)
165172
ZM = Normal.calc_log_Z(rm, sm, num)
166-
return -.5 * np.log(2*np.pi) + ZM - ZN
173+
return -.5 * LOG2PI + ZM - ZN
167174

168175
@staticmethod
169176
def calc_logpdf_marginal(N, sum_x, sum_x_sq, m, r, s, nu):
170-
mn, rn, sn, nun = Normal.posterior_hypers(
177+
_mn, rn, sn, nun = Normal.posterior_hypers(
171178
N, sum_x, sum_x_sq, m, r, s, nu)
172179
Z0 = Normal.calc_log_Z(r, s, nu)
173180
ZN = Normal.calc_log_Z(rn, sn, nun)
174-
return -(N/2.) * np.log(2*np.pi) + ZN - Z0
181+
return -(N/2.) * LOG2PI + ZN - Z0
175182

176183
@staticmethod
177184
def posterior_hypers(N, sum_x, sum_x_sq, m, r, s, nu):
178185
rn = r + float(N)
179186
nun = nu + float(N)
180187
mn = (r*m + sum_x)/rn
181188
sn = s + sum_x_sq + r*m*m - rn*mn*mn
182-
if sn == 0: sn = s
189+
if sn == 0:
190+
sn = s
183191
return mn, rn, sn, nun
184192

185193
@staticmethod
186194
def calc_log_Z(r, s, nu):
187195
return (
188-
((nu + 1.) / 2.) * np.log(2)
189-
+ .5 * np.log(np.pi)
190-
- .5 * np.log(r)
191-
- (nu/2.) * np.log(s)
192-
+ lgamma(nu/2.0))
196+
((nu + 1.) / 2.) * LOG2
197+
+ .5 * LOGPI
198+
- .5 * log(r)
199+
- (nu/2.) * log(s)
200+
+ lgamma(nu/2.))
193201

194202
@staticmethod
195203
def sample_parameters(m, r, s, nu, rng):

0 commit comments

Comments
 (0)