Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 35477d4

Browse files
committed
Add the transformation between the inverse gamma and the exponential
1 parent 5b5ae2f commit 35477d4

2 files changed

Lines changed: 90 additions & 1 deletion

File tree

aemcmc/transforms.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,54 @@ def location_scale_transform(in_expr, out_expr):
7070
eq(out_expr, noncentered_et),
7171
location_scale_family(distribution_lv),
7272
)
73+
74+
75+
def invgamma_exponential(invgamma_expr, invexponential_expr):
76+
r"""Produce a goal that represents the relation between the inverse gamma distribution
77+
and the inverse of an exponential distribution.
78+
79+
.. math::
80+
81+
\begin{equation*}
82+
\frac{
83+
X \sim \operatorname{Gamma^{-1}}\left(1, c\right)
84+
}{
85+
Y = 1 / X, \quad
86+
Y \sim \operatorname{Exp}\left(c\right)
87+
}
88+
\end{equation*}
89+
90+
TODO: This is a particular case of a more general relation between the inverse gamma
91+
and the gamma distribution (of which the exponential distribution is a special case).
92+
We should implement this more general relation, and the special case separately in the
93+
future.
94+
95+
Parameters
96+
----------
97+
invgamma_expr
98+
An expression that represents a random variable with an inverse gamma
99+
distribution with a shape parameter equal to 1.
100+
invexponential_expr
101+
An expression that represents the inverse of a random variable with an
102+
exponential distribution.
103+
104+
"""
105+
c_lv = var()
106+
rng_lv, size_lv, dtype_lv = var(), var(), var()
107+
108+
invgamma_et = etuple(
109+
etuplize(at.random.invgamma), rng_lv, size_lv, dtype_lv, at.as_tensor(1.0), c_lv
110+
)
111+
112+
exponential_et = etuple(
113+
etuplize(at.random.exponential),
114+
c_lv,
115+
rng=rng_lv,
116+
size=size_lv,
117+
dtype=dtype_lv,
118+
)
119+
invexponential_et = etuple(at.true_div, at.as_tensor(1.0), exponential_et)
120+
121+
return lall(
122+
eq(invgamma_expr, invgamma_et), eq(invexponential_expr, invexponential_et)
123+
)

tests/test_transforms.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from aesara.graph.fg import FunctionGraph
44
from aesara.graph.kanren import KanrenRelationSub
55

6-
from aemcmc.transforms import location_scale_transform
6+
from aemcmc.transforms import invgamma_exponential, location_scale_transform
77

88

99
def test_normal_scale_loc_transform_lift():
@@ -45,3 +45,41 @@ def test_normal_scale_loc_transform_sink():
4545
)[0]
4646

4747
assert isinstance(res.owner.op, type(at.random.normal))
48+
49+
50+
def test_invgamma_to_exp():
51+
52+
srng = at.random.RandomStream(0)
53+
c_at = at.scalar()
54+
X_rv = srng.invgamma(1.0, c_at)
55+
56+
fgraph = FunctionGraph(outputs=[X_rv], clone=False)
57+
res = KanrenRelationSub(invgamma_exponential).transform(
58+
fgraph, fgraph.outputs[0].owner
59+
)[0]
60+
61+
Y_rv = 1. / srng.exponential(c_at)
62+
63+
assert res.owner.op == Y_rv.owner.op
64+
assert isinstance(res.owner.inputs[1].owner.op, type(Y_rv.owner.inputs[1].owner.op))
65+
assert res.owner.inputs[1].owner.inputs[-1] == c_at
66+
67+
68+
@pytest.mark.xfail(
69+
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
70+
)
71+
def test_invgamma_from_exp():
72+
73+
srng = at.random.RandomStream(0)
74+
c_at = at.scalar()
75+
X_rv = 1.0 / srng.exponential(c_at)
76+
77+
fgraph = FunctionGraph(outputs=[X_rv], clone=False)
78+
res = KanrenRelationSub(lambda x, y: invgamma_exponential(y, x)).transform(
79+
fgraph, fgraph.outputs[0].owner
80+
)[0]
81+
82+
Y_rv = srng.invgamma(1.0, c_at)
83+
84+
assert isinstance(res.owner.op , type(Y_rv.owner.op))
85+
assert res.owner.inputs[-1] == c_at

0 commit comments

Comments
 (0)