diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb new file mode 100644 index 000000000..8aaf573c9 --- /dev/null +++ b/tutorials/sum_product_network.ipynb @@ -0,0 +1,309 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sum Product Network" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "import jax\n", + "import numpy as np\n", + "\n", + "import funsor\n", + "import funsor.jax.distributions as dist\n", + "import funsor.ops as ops\n", + "\n", + "funsor.set_backend(\"jax\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### network" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor([[[0.03408 0.03712 ]\n", + " [0.05712 0.07167999]]\n", + "\n", + " [[0.13632001 0.14848001]\n", + " [0.22848003 0.28672004]]], OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# sum_op = +, prod_op = *\n", + "# alternatively, we can use rewrite_ops as in\n", + "# https://github.com/pyro-ppl/funsor/pull/456\n", + "# and switch to sum_op = logsumexp, prod_op = +\n", + "# FIXME: what is the best way to set constraints to the weights\n", + "spn = 0.4 * (dist.Categorical(np.array([0.2, 0.8]), value=\"v0\").exp() *\n", + " (0.3 * (dist.Categorical(np.array([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(np.array([0.4, 0.6]), value=\"v2\").exp())\n", + " + 0.7 * (dist.Categorical(np.array([0.5, 0.5]), value=\"v1\").exp() *\n", + " dist.Categorical(np.array([0.6, 0.4]), value=\"v2\").exp()))) \\\n", + " + 0.6 * (dist.Categorical(np.array([0.2, 0.8]), value=\"v0\").exp() *\n", + " dist.Categorical(np.array([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(np.array([0.4, 0.6]), value=\"v2\").exp())\n", + "spn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### marginalize" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor([[0.17040001 0.18560001]\n", + " [0.28560004 0.35840005]], OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" + ] + } + ], + "source": [ + "spn_marg = spn.reduce(ops.add, \"v0\")\n", + "print(spn_marg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### likelihood" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "test_data = {\"v0\": 1, \"v1\": 0, \"v2\": 1}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-1.9073049 0.14848001\n" + ] + } + ], + "source": [ + "ll_exp = spn(**test_data)\n", + "print(ll_exp.log(), ll_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-1.6841614 0.18560001\n" + ] + } + ], + "source": [ + "llm_exp = spn_marg(**test_data)\n", + "print(llm_exp.log(), llm_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-1.6841614 0.18560001\n" + ] + } + ], + "source": [ + "test_data2 = {\"v1\": 0, \"v2\": 1}\n", + "llom_exp = spn(**test_data2).reduce(ops.add)\n", + "print(llom_exp.log(), llom_exp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### sample" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Delta((('v0', (Tensor([1 0 1 1 0], OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor([-0.8297847 -0.8297847 -0.8297847 -0.8297847 -0.8297847], OrderedDict([('particle', Bint[5, ])]), 'real').reduce(nullop, set())" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_inputs = OrderedDict(particle=funsor.Bint[5])\n", + "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs, jax.random.PRNGKey(0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "what is `-0.8297847`? a normalization factor? why the latter term is a constant in torch but it is an array in jax" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### train parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### parameter optimization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### most probable explanation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### multivariate leaf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### cutset networks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### expectations and moments" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Integrate(q, x, q_vars)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### pareto" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-0.523248\n" + ] + } + ], + "source": [ + "spn = 0.3 * dist.Pareto(1., 2., value=\"v0\").exp() + 0.7 * dist.Pareto(1., 3., value=\"v0\").exp()\n", + "print(spn(v0=1.5).log())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}