From 2fb4ea5b4732b91f8f06c26cea4bf9649b54717d Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 3 Mar 2021 23:52:21 -0600 Subject: [PATCH 1/3] initialize the notebook --- tutorials/sum_product_network.ipynb | 365 ++++++++++++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 tutorials/sum_product_network.ipynb diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb new file mode 100644 index 000000000..900e39d0a --- /dev/null +++ b/tutorials/sum_product_network.ipynb @@ -0,0 +1,365 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sum Product Network\n", + "\n", + "Some text" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([1, 3], dtype=int32)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.clip(jnp.array([-1, 3]), a_min=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([0., 1., 1.], dtype=float32)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax.numpy as jnp\n", + "import numpyro\n", + "\n", + "jnp.repeat(jnp.array([0., 1.]), [1, 2])" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "import torch\n", + "\n", + "import funsor\n", + "import funsor.torch.distributions as dist\n", + "import funsor.ops as ops\n", + "\n", + "funsor.set_backend(\"torch\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### network" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(tensor([[[0.0341, 0.0371],\n", + " [0.0571, 0.0717]],\n", + "\n", + " [[0.1363, 0.1485],\n", + " [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# sum_op = +, prod_op = *\n", + "# alternatively, we can use rewrite_ops as in\n", + "# https://github.com/pyro-ppl/funsor/pull/456\n", + "# and switch to sum_op = logsumexp, prod_op = +\n", + "spn = 0.4 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", + " (0.3 * (dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + " + 0.7 * (dist.Categorical(torch.tensor([0.5, 0.5]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.6, 0.4]), value=\"v2\").exp()))) \\\n", + " + 0.6 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", + " dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + "spn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### marginalize" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(tensor([[0.1704, 0.1856],\n", + " [0.2856, 0.3584]]), OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" + ] + } + ], + "source": [ + "spn_marg = spn.reduce(ops.add, \"v0\")\n", + "print(spn_marg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### likelihood" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "test_data = {\"v0\": 1, \"v1\": 0, \"v2\": 1}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.9073) tensor(0.1485)\n" + ] + } + ], + "source": [ + "ll_exp = spn(**test_data)\n", + "print(ll_exp.log(), ll_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.6842) tensor(0.1856)\n" + ] + } + ], + "source": [ + "llm_exp = spn_marg(**test_data)\n", + "print(llm_exp.log(), llm_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.6842) tensor(0.1856)\n" + ] + } + ], + "source": [ + "test_data2 = {\"v1\": 0, \"v2\": 1}\n", + "llom_exp = spn(**test_data2).reduce(ops.add)\n", + "print(llom_exp.log(), llom_exp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### sample" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Delta((('v0', (Tensor(tensor([1, 1, 1, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_inputs = OrderedDict(particle=funsor.Bint[5])\n", + "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "what is `-0.8297846913337708`? a normalization factor?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### train parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-2.0612e-09)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "-torch.nn.functional.softplus(-torch.tensor(20.))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### parameter optimization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### most probable explanation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### multivariate leaf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### cutset networks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### expectations and moments" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Integrate(q, x, q_vars)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### pareto" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-0.5232)\n" + ] + } + ], + "source": [ + "spn = 0.3 * dist.Pareto(1., 2., value=\"v0\").exp() + 0.7 * dist.Pareto(1., 3., value=\"v0\").exp()\n", + "print(spn(v0=1.5).log())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From dd2e009c2c7d1b20d4f644d6f177e66c184fb2d1 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 3 Mar 2021 23:52:21 -0600 Subject: [PATCH 2/3] initialize the notebook --- tutorials/sum_product_network.ipynb | 365 ++++++++++++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 tutorials/sum_product_network.ipynb diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb new file mode 100644 index 000000000..900e39d0a --- /dev/null +++ b/tutorials/sum_product_network.ipynb @@ -0,0 +1,365 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sum Product Network\n", + "\n", + "Some text" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([1, 3], dtype=int32)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.clip(jnp.array([-1, 3]), a_min=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([0., 1., 1.], dtype=float32)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax.numpy as jnp\n", + "import numpyro\n", + "\n", + "jnp.repeat(jnp.array([0., 1.]), [1, 2])" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "import torch\n", + "\n", + "import funsor\n", + "import funsor.torch.distributions as dist\n", + "import funsor.ops as ops\n", + "\n", + "funsor.set_backend(\"torch\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### network" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(tensor([[[0.0341, 0.0371],\n", + " [0.0571, 0.0717]],\n", + "\n", + " [[0.1363, 0.1485],\n", + " [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# sum_op = +, prod_op = *\n", + "# alternatively, we can use rewrite_ops as in\n", + "# https://github.com/pyro-ppl/funsor/pull/456\n", + "# and switch to sum_op = logsumexp, prod_op = +\n", + "spn = 0.4 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", + " (0.3 * (dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + " + 0.7 * (dist.Categorical(torch.tensor([0.5, 0.5]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.6, 0.4]), value=\"v2\").exp()))) \\\n", + " + 0.6 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", + " dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + "spn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### marginalize" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(tensor([[0.1704, 0.1856],\n", + " [0.2856, 0.3584]]), OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" + ] + } + ], + "source": [ + "spn_marg = spn.reduce(ops.add, \"v0\")\n", + "print(spn_marg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### likelihood" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "test_data = {\"v0\": 1, \"v1\": 0, \"v2\": 1}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.9073) tensor(0.1485)\n" + ] + } + ], + "source": [ + "ll_exp = spn(**test_data)\n", + "print(ll_exp.log(), ll_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.6842) tensor(0.1856)\n" + ] + } + ], + "source": [ + "llm_exp = spn_marg(**test_data)\n", + "print(llm_exp.log(), llm_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.6842) tensor(0.1856)\n" + ] + } + ], + "source": [ + "test_data2 = {\"v1\": 0, \"v2\": 1}\n", + "llom_exp = spn(**test_data2).reduce(ops.add)\n", + "print(llom_exp.log(), llom_exp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### sample" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Delta((('v0', (Tensor(tensor([1, 1, 1, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_inputs = OrderedDict(particle=funsor.Bint[5])\n", + "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "what is `-0.8297846913337708`? a normalization factor?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### train parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-2.0612e-09)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "-torch.nn.functional.softplus(-torch.tensor(20.))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### parameter optimization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### most probable explanation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### multivariate leaf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### cutset networks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### expectations and moments" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Integrate(q, x, q_vars)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### pareto" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-0.5232)\n" + ] + } + ], + "source": [ + "spn = 0.3 * dist.Pareto(1., 2., value=\"v0\").exp() + 0.7 * dist.Pareto(1., 3., value=\"v0\").exp()\n", + "print(spn(v0=1.5).log())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 8070ec4c7ac3b0305eea2b9a0362c513b391ed4e Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 12 Mar 2021 00:28:49 -0600 Subject: [PATCH 3/3] switch to jax backend to make it easier to optimize parameters --- tutorials/sum_product_network.ipynb | 118 ++++++++-------------------- 1 file changed, 31 insertions(+), 87 deletions(-) diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb index 900e39d0a..8aaf573c9 100644 --- a/tutorials/sum_product_network.ipynb +++ b/tutorials/sum_product_network.ipynb @@ -4,52 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Sum Product Network\n", - "\n", - "Some text" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DeviceArray([1, 3], dtype=int32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jnp.clip(jnp.array([-1, 3]), a_min=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DeviceArray([0., 1., 1.], dtype=float32)" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax.numpy as jnp\n", - "import numpyro\n", - "\n", - "jnp.repeat(jnp.array([0., 1.]), [1, 2])" + "# Sum Product Network" ] }, { @@ -60,13 +15,14 @@ "source": [ "from collections import OrderedDict\n", "\n", - "import torch\n", + "import jax\n", + "import numpy as np\n", "\n", "import funsor\n", - "import funsor.torch.distributions as dist\n", + "import funsor.jax.distributions as dist\n", "import funsor.ops as ops\n", "\n", - "funsor.set_backend(\"torch\")" + "funsor.set_backend(\"jax\")" ] }, { @@ -84,11 +40,11 @@ { "data": { "text/plain": [ - "Tensor(tensor([[[0.0341, 0.0371],\n", - " [0.0571, 0.0717]],\n", + "Tensor([[[0.03408 0.03712 ]\n", + " [0.05712 0.07167999]]\n", "\n", - " [[0.1363, 0.1485],\n", - " [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" + " [[0.13632001 0.14848001]\n", + " [0.22848003 0.28672004]]], OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" ] }, "execution_count": 2, @@ -101,14 +57,15 @@ "# alternatively, we can use rewrite_ops as in\n", "# https://github.com/pyro-ppl/funsor/pull/456\n", "# and switch to sum_op = logsumexp, prod_op = +\n", - "spn = 0.4 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", - " (0.3 * (dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", - " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", - " + 0.7 * (dist.Categorical(torch.tensor([0.5, 0.5]), value=\"v1\").exp() *\n", - " dist.Categorical(torch.tensor([0.6, 0.4]), value=\"v2\").exp()))) \\\n", - " + 0.6 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", - " dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", - " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + "# FIXME: what is the best way to set constraints to the weights\n", + "spn = 0.4 * (dist.Categorical(np.array([0.2, 0.8]), value=\"v0\").exp() *\n", + " (0.3 * (dist.Categorical(np.array([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(np.array([0.4, 0.6]), value=\"v2\").exp())\n", + " + 0.7 * (dist.Categorical(np.array([0.5, 0.5]), value=\"v1\").exp() *\n", + " dist.Categorical(np.array([0.6, 0.4]), value=\"v2\").exp()))) \\\n", + " + 0.6 * (dist.Categorical(np.array([0.2, 0.8]), value=\"v0\").exp() *\n", + " dist.Categorical(np.array([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(np.array([0.4, 0.6]), value=\"v2\").exp())\n", "spn" ] }, @@ -128,8 +85,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Tensor(tensor([[0.1704, 0.1856],\n", - " [0.2856, 0.3584]]), OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" + "Tensor([[0.17040001 0.18560001]\n", + " [0.28560004 0.35840005]], OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" ] } ], @@ -163,7 +120,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor(-1.9073) tensor(0.1485)\n" + "-1.9073049 0.14848001\n" ] } ], @@ -181,7 +138,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor(-1.6842) tensor(0.1856)\n" + "-1.6841614 0.18560001\n" ] } ], @@ -199,7 +156,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor(-1.6842) tensor(0.1856)\n" + "-1.6841614 0.18560001\n" ] } ], @@ -224,7 +181,7 @@ { "data": { "text/plain": [ - "Delta((('v0', (Tensor(tensor([1, 1, 1, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" + "Delta((('v0', (Tensor([1 0 1 1 0], OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor([-0.8297847 -0.8297847 -0.8297847 -0.8297847 -0.8297847], OrderedDict([('particle', Bint[5, ])]), 'real').reduce(nullop, set())" ] }, "execution_count": 8, @@ -234,14 +191,14 @@ ], "source": [ "sample_inputs = OrderedDict(particle=funsor.Bint[5])\n", - "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs)" + "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs, jax.random.PRNGKey(0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "what is `-0.8297846913337708`? a normalization factor?" + "what is `-0.8297847`? a normalization factor? why the latter term is a constant in torch but it is an array in jax" ] }, { @@ -253,23 +210,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(-2.0612e-09)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "-torch.nn.functional.softplus(-torch.tensor(20.))" - ] + "outputs": [], + "source": [] }, { "cell_type": "markdown", @@ -308,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -331,7 +275,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor(-0.5232)\n" + "-0.523248\n" ] } ],