Skip to content

Commit f09f6b6

Browse files
Add Poisson and NegativeBinomial to pymc.dims (#8305)
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1 parent 853149c commit f09f6b6

2 files changed

Lines changed: 46 additions & 0 deletions

File tree

pymc/dims/distributions/scalar.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
halfflat,
3737
truncated_normal,
3838
)
39+
from pymc.distributions.discrete import NegativeBinomial as RegularNegativeBinomial
3940
from pymc.util import UNSET
4041

4142

@@ -295,3 +296,22 @@ def xrv_op(self, alpha, beta, core_dims=None, extra_dims=None, rng=None, **kwarg
295296
core_rv = WeibullBetaRV.rv_op(alpha=alpha.values, beta=beta.values).owner.op
296297
xop = ptxr.as_xrv(core_rv)
297298
return xop(alpha, beta, core_dims=core_dims, extra_dims=extra_dims, rng=rng, **kwargs)
299+
300+
301+
@copy_docstring(regular_dists.Poisson)
302+
class Poisson(DimDistribution):
303+
xrv_op = ptxr.poisson
304+
305+
@classmethod
306+
def dist(cls, mu, **kwargs):
307+
return super().dist([mu], **kwargs)
308+
309+
310+
@copy_docstring(regular_dists.NegativeBinomial)
311+
class NegativeBinomial(DimDistribution):
312+
xrv_op = ptxr.nbinom
313+
314+
@classmethod
315+
def dist(cls, mu=None, alpha=None, *, p=None, n=None, **kwargs):
316+
n, p = RegularNegativeBinomial.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
317+
return super().dist([n, p], **kwargs)

tests/dims/distributions/test_scalar.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
InverseGamma,
3232
Laplace,
3333
LogNormal,
34+
NegativeBinomial,
3435
Normal,
36+
Poisson,
3537
StudentT,
3638
TruncatedNormal,
3739
Uniform,
@@ -311,3 +313,27 @@ def test_weibull():
311313

312314
assert_equivalent_random_graph(model, reference_model)
313315
assert_equivalent_logp_graph(model, reference_model)
316+
317+
318+
def test_poisson():
319+
coords = {"a": range(3)}
320+
with Model(coords=coords) as model:
321+
Poisson("x", mu=2.0, dims="a")
322+
323+
with Model(coords=coords) as reference_model:
324+
regular_distributions.Poisson("x", mu=2.0, dims="a")
325+
326+
assert_equivalent_random_graph(model, reference_model)
327+
assert_equivalent_logp_graph(model, reference_model)
328+
329+
330+
def test_negative_binomial():
331+
coords = {"a": range(3)}
332+
with Model(coords=coords) as model:
333+
NegativeBinomial("x", mu=5.0, alpha=2.0, dims="a")
334+
335+
with Model(coords=coords) as reference_model:
336+
regular_distributions.NegativeBinomial("x", mu=5.0, alpha=2.0, dims="a")
337+
338+
assert_equivalent_random_graph(model, reference_model)
339+
assert_equivalent_logp_graph(model, reference_model)

0 commit comments

Comments
 (0)