|
3 | 3 | from aesara.graph.fg import FunctionGraph |
4 | 4 | from aesara.graph.kanren import KanrenRelationSub |
5 | 5 |
|
6 | | -from aemcmc.transforms import location_scale_transform |
| 6 | +from aemcmc.transforms import invgamma_exponential, location_scale_transform |
7 | 7 |
|
8 | 8 |
|
9 | 9 | def test_normal_scale_loc_transform_lift(): |
@@ -45,3 +45,41 @@ def test_normal_scale_loc_transform_sink(): |
45 | 45 | )[0] |
46 | 46 |
|
47 | 47 | assert isinstance(res.owner.op, type(at.random.normal)) |
| 48 | + |
| 49 | + |
| 50 | +def test_invgamma_to_exp(): |
| 51 | + |
| 52 | + srng = at.random.RandomStream(0) |
| 53 | + c_at = at.scalar() |
| 54 | + X_rv = srng.invgamma(1.0, c_at) |
| 55 | + |
| 56 | + fgraph = FunctionGraph(outputs=[X_rv], clone=False) |
| 57 | + res = KanrenRelationSub(invgamma_exponential).transform( |
| 58 | + fgraph, fgraph.outputs[0].owner |
| 59 | + )[0] |
| 60 | + |
| 61 | + Y_rv = 1. / srng.exponential(c_at) |
| 62 | + |
| 63 | + assert res.owner.op == Y_rv.owner.op |
| 64 | + assert isinstance(res.owner.inputs[1].owner.op, type(Y_rv.owner.inputs[1].owner.op)) |
| 65 | + assert res.owner.inputs[1].owner.inputs[-1] == c_at |
| 66 | + |
| 67 | + |
| 68 | +@pytest.mark.xfail( |
| 69 | + reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error" |
| 70 | +) |
| 71 | +def test_invgamma_from_exp(): |
| 72 | + |
| 73 | + srng = at.random.RandomStream(0) |
| 74 | + c_at = at.scalar() |
| 75 | + X_rv = 1.0 / srng.exponential(c_at) |
| 76 | + |
| 77 | + fgraph = FunctionGraph(outputs=[X_rv], clone=False) |
| 78 | + res = KanrenRelationSub(lambda x, y: invgamma_exponential(y, x)).transform( |
| 79 | + fgraph, fgraph.outputs[0].owner |
| 80 | + )[0] |
| 81 | + |
| 82 | + Y_rv = srng.invgamma(1.0, c_at) |
| 83 | + |
| 84 | + assert isinstance(res.owner.op , type(Y_rv.owner.op)) |
| 85 | + assert res.owner.inputs[-1] == c_at |
0 commit comments