Skip to content

Commit 1d69cff

Browse files
committed
Add tests for pmd.CustomDist
Covers both symbolic (dist=) and black-box (logp=/random=) paths: graph comparison against regular distributions, dim propagation, observed data, custom support points, and model variables as params.
1 parent 8b86017 commit 1d69cff

1 file changed

Lines changed: 221 additions & 0 deletions

File tree

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Copyright 2026 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor.tensor as pt
16+
import pytest
17+
18+
from pytensor.xtensor import as_xtensor
19+
20+
import pymc.distributions as regular_distributions
21+
22+
from pymc.dims import CustomDist, Normal
23+
from pymc.model.core import Model
24+
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
25+
26+
pytestmark = pytest.mark.filterwarnings(
27+
"error",
28+
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
29+
)
30+
31+
32+
class TestCustomDistSymbolic:
33+
"""Tests for the symbolic (dist=) path of pmd.CustomDist."""
34+
35+
def test_basic(self):
36+
"""Symbolic path: dist function wrapping Normal.dist, compared against regular Normal."""
37+
38+
def normal_dist(mu, sigma):
39+
return Normal.dist(mu, sigma)
40+
41+
coords = {"city": range(5)}
42+
with Model(coords=coords) as model:
43+
CustomDist("x", 0, 1, dist=normal_dist, dims="city")
44+
45+
with Model(coords=coords) as reference_model:
46+
regular_distributions.Normal("x", 0, 1, dims="city")
47+
48+
assert_equivalent_random_graph(model, reference_model)
49+
assert_equivalent_logp_graph(model, reference_model)
50+
51+
def test_param_dims_propagate(self):
52+
"""Params with dims propagate to the output."""
53+
54+
def normal_dist(mu, sigma):
55+
return Normal.dist(mu, sigma)
56+
57+
coords = {"city": range(5)}
58+
mu = as_xtensor(np.array([0, 1, 2, 3, 4]), dims=("city",))
59+
sigma = as_xtensor(np.array([1, 2, 3, 4, 5]), dims=("city",))
60+
61+
with Model(coords=coords) as model:
62+
x = CustomDist("x", mu, sigma, dist=normal_dist)
63+
64+
assert set(x.dims) == {"city"}
65+
assert x.type.shape == (5,)
66+
67+
68+
class TestCustomDistBlackbox:
69+
"""Tests for the black-box (logp=/random=) path of pmd.CustomDist."""
70+
71+
def test_logp_basic(self):
72+
"""Black-box path with logp function and dims on output."""
73+
74+
def normal_logp(value, mu, sigma):
75+
v = value.values
76+
return pt.sum(
77+
-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi)))
78+
)
79+
80+
coords = {"city": range(5)}
81+
rng = np.random.default_rng(42)
82+
observed = as_xtensor(rng.normal(0, 1, size=5), dims=("city",))
83+
84+
with Model(coords=coords) as model:
85+
CustomDist(
86+
"x",
87+
0,
88+
1,
89+
logp=normal_logp,
90+
observed=observed,
91+
dims="city",
92+
)
93+
94+
# Test that logp evaluates without error and returns finite values
95+
ip = model.initial_point()
96+
logp_val = model.compile_logp()(ip)
97+
assert np.isfinite(logp_val)
98+
99+
def test_random_logp(self):
100+
"""Black-box path with both random and logp."""
101+
102+
def normal_logp(value, mu, sigma):
103+
v = value.values
104+
return pt.sum(
105+
-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi)))
106+
)
107+
108+
def normal_random(mu, sigma, rng=None, size=None):
109+
return rng.normal(loc=mu, scale=sigma, size=size)
110+
111+
coords = {"city": range(5)}
112+
with Model(coords=coords) as model:
113+
CustomDist(
114+
"x",
115+
0,
116+
1,
117+
logp=normal_logp,
118+
random=normal_random,
119+
dims="city",
120+
)
121+
122+
# Verify shape via draw
123+
from pymc import draw as pm_draw
124+
125+
draws = pm_draw(model["x"], draws=3)
126+
assert draws.shape == (3, 5)
127+
128+
# Verify logp
129+
ip = model.initial_point()
130+
logp_val = model.compile_logp()(ip)
131+
assert np.isfinite(logp_val)
132+
133+
def test_logcdf(self):
134+
"""Black-box path with logcdf function."""
135+
136+
def normal_logp(value, mu, sigma):
137+
v = value.values
138+
return pt.sum(
139+
-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi)))
140+
)
141+
142+
def normal_logcdf(value, mu, sigma):
143+
v = value.values
144+
return pt.sum(
145+
pt.log(pt.erf((v - mu) / (sigma * pt.sqrt(2.0))) + 1.0) - pt.log(pt.constant(2.0))
146+
)
147+
148+
coords = {"city": range(5)}
149+
rng = np.random.default_rng(42)
150+
observed = as_xtensor(rng.normal(0, 1, size=5), dims=("city",))
151+
152+
with Model(coords=coords) as model:
153+
CustomDist(
154+
"x",
155+
0,
156+
1,
157+
logp=normal_logp,
158+
logcdf=normal_logcdf,
159+
observed=observed,
160+
dims="city",
161+
)
162+
163+
ip = model.initial_point()
164+
logp_val = model.compile_logp()(ip)
165+
assert np.isfinite(logp_val)
166+
167+
def test_mu_as_model_var(self):
168+
"""Black-box path with mu as a model variable (no dims on mu)."""
169+
170+
def normal_logp(value, mu, sigma):
171+
v = value.values
172+
return pt.sum(
173+
-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi)))
174+
)
175+
176+
coords = {"city": range(5)}
177+
rng = np.random.default_rng(42)
178+
observed = as_xtensor(rng.normal(0, 1, size=5), dims=("city",))
179+
180+
with Model(coords=coords) as model:
181+
mu = Normal("mu", 0, 1)
182+
CustomDist(
183+
"x",
184+
mu,
185+
1,
186+
logp=normal_logp,
187+
observed=observed,
188+
dims="city",
189+
)
190+
191+
ip = model.initial_point()
192+
logp_val = model.compile_logp()(ip)
193+
assert np.isfinite(logp_val)
194+
195+
def test_support_point(self):
196+
"""Black-box path with custom support_point."""
197+
198+
def normal_logp(value, mu, sigma):
199+
v = value.values
200+
return pt.sum(
201+
-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi)))
202+
)
203+
204+
def custom_support_point(rv, size, mu, sigma):
205+
return pt.full_like(rv, mu)
206+
207+
coords = {"city": range(5)}
208+
with Model(coords=coords) as model:
209+
CustomDist(
210+
"x",
211+
0,
212+
1,
213+
logp=normal_logp,
214+
support_point=custom_support_point,
215+
dims="city",
216+
)
217+
218+
from pymc.distributions.distribution import support_point
219+
220+
sp = support_point(model["x"])
221+
np.testing.assert_allclose(sp.eval(), np.zeros(5))

0 commit comments

Comments
 (0)