diff --git a/notebooks/ADVI Guide API.ipynb b/notebooks/ADVI Guide API.ipynb new file mode 100644 index 000000000..1c9fefa45 --- /dev/null +++ b/notebooks/ADVI Guide API.ipynb @@ -0,0 +1,681 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "9f946eb4", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:16.738620Z", + "start_time": "2026-02-02T02:00:15.186007Z" + } + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from tqdm import trange\n", + "\n", + "import pymc as pm\n", + "\n", + "from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal\n", + "from pymc_extras.inference.advi.training import compile_svi_training_fn" + ], + "outputs": [], + "execution_count": 1 + }, + { + "cell_type": "code", + "id": "e746bc33", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:18.572262Z", + "start_time": "2026-02-02T02:00:16.743117Z" + } + }, + "source": [ + "with pm.Model() as m:\n", + " X = pm.Normal(\"X\", 0, 1, size=(100, 3))\n", + " alpha = pm.Normal(\"alpha\", 0, 10)\n", + " beta = pm.Normal(\"beta\", 0, 5, size=(3,))\n", + "\n", + " mu = alpha + X @ beta\n", + " sigma = pm.HalfNormal(\"sigma\", 1)\n", + " y = pm.Normal(\"y\", mu=mu, sigma=sigma)\n", + "\n", + " prior = pm.sample_prior_predictive(random_seed=38)" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [X, alpha, beta, sigma, y]\n" + ] + } + ], + "execution_count": 2 + }, + { + "cell_type": "code", + "id": "a8ca0161", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:18.594345Z", + "start_time": "2026-02-02T02:00:18.584554Z" + } + }, + "source": [ + "draw = 123\n", + "true_params = {}\n", + "true_params[\"alpha\"] = prior.prior.alpha.sel(chain=0, draw=draw).values\n", + "true_params[\"beta\"] = prior.prior.beta.sel(chain=0, draw=draw).values\n", + "true_params[\"sigma\"] = prior.prior.sigma.sel(chain=0, draw=draw).values\n", + "\n", + "X_data = prior.prior.X.sel(chain=0, draw=draw).values\n", + "y_data = prior.prior.y.sel(chain=0, draw=draw).values" + ], + "outputs": [], + "execution_count": 3 + }, + { + "cell_type": "code", + "id": "dea42437-a68f-4602-9239-d5f60faab70d", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:18.608592Z", + "start_time": "2026-02-02T02:00:18.595942Z" + } + }, + "source": [ + "true_params" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'alpha': array(11.82213038),\n", + " 'beta': array([-0.92518728, 0.27270752, -0.20081106]),\n", + " 'sigma': array(0.40007044)}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 4 + }, + { + "cell_type": "code", + "id": "b89f4031", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:18.616898Z", + "start_time": "2026-02-02T02:00:18.609941Z" + } + }, + "source": [ + "m_obs = pm.observe(pm.do(m, {X: X_data}), {\"y\": y_data})" + ], + "outputs": [], + "execution_count": 5 + }, + { + "cell_type": "code", + "id": "015bdc56-309d-40e4-8e89-4f75691e2301", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:30.433124Z", + "start_time": "2026-02-02T02:00:18.617182Z" + } + }, + "source": [ + "with m_obs:\n", + " idata = pm.sample(mp_ctx=\"spawn\")" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds.\n" + ] + } + ], + "execution_count": 6 + }, + { + "cell_type": "code", + "id": "4a4b3b07-0ef9-4387-93bc-622ae98aacb9", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:30.479596Z", + "start_time": "2026-02-02T02:00:30.439494Z" + } + }, + "source": [ + "guide = AutoDiagonalNormal(m_obs)" + ], + "outputs": [], + "execution_count": 7 + }, + { + "cell_type": "code", + "id": "69469eda-ae7d-4259-bb17-f69640d378dd", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:30.488867Z", + "start_time": "2026-02-02T02:00:30.480129Z" + } + }, + "source": [ + "guide.params_init_values" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{sigma_loc: array(0.26284188),\n", + " sigma_scale: array(0.1),\n", + " beta_loc: array([0.88503108, 0.78157369, 0.67367489]),\n", + " beta_scale: array([0.1, 0.1, 0.1]),\n", + " alpha_loc: array(-0.5681205),\n", + " alpha_scale: array(0.1)}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 8 + }, + { + "cell_type": "code", + "id": "e735cd44-49e0-452e-a000-b56b27e60bca", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:30.503292Z", + "start_time": "2026-02-02T02:00:30.489253Z" + } + }, + "source": [ + "true_loc_dict = {k.name: v for k, v in guide.params_init_values.items()}\n", + "for key, value in true_params.items():\n", + " true_loc_dict[f\"{key}_loc\"] = value\n", + "true_loc_dict" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma_loc': array(0.40007044),\n", + " 'sigma_scale': array(0.1),\n", + " 'beta_loc': array([-0.92518728, 0.27270752, -0.20081106]),\n", + " 'beta_scale': array([0.1, 0.1, 0.1]),\n", + " 'alpha_loc': array(11.82213038),\n", + " 'alpha_scale': array(0.1)}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 9 + }, + { + "cell_type": "code", + "id": "8ef58408-639a-460e-bcbf-c2d17b58ad09", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:35.331486Z", + "start_time": "2026-02-02T02:00:30.505233Z" + } + }, + "source": "f_loss_dloss = compile_svi_training_fn(m_obs, guide, stick_the_landing=True)", + "outputs": [], + "execution_count": 10 + }, + { + "cell_type": "code", + "id": "b9f1edb0-b70f-4130-9a89-16287738330f", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:35.355417Z", + "start_time": "2026-02-02T02:00:35.342539Z" + } + }, + "source": [ + "f_loss_dloss(np.array(500), **true_loc_dict)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "[array(264.54163996),\n", + " array(-141.54734769),\n", + " array(151.88467924),\n", + " array([-7.95708067, -4.35176788, -1.4180619 ]),\n", + " array([45.02632721, 44.46560167, 31.55529906]),\n", + " array(15.40696434),\n", + " array(43.45550916)]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 + }, + { + "cell_type": "code", + "id": "6086f2cc", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:36.926890Z", + "start_time": "2026-02-02T02:00:35.361133Z" + } + }, + "source": [ + "init_param_values = {k.name: v for k, v in guide.params_init_values.items()}\n", + "opt_param_values = list(init_param_values.values())\n", + "learning_rate = 1e-5\n", + "n_iter = 2_000\n", + "loss_history = np.empty(n_iter)\n", + "progress_bar = trange(n_iter)\n", + "draws = np.array(500, dtype=int)\n", + "for i in progress_bar:\n", + " loss, *grads = f_loss_dloss(draws, *opt_param_values)\n", + " loss_history[i] = loss\n", + " opt_param_values = [\n", + " np.asarray(value - learning_rate * grad) for value, grad in zip(opt_param_values, grads)\n", + " ]\n", + " if i % 50 == 0:\n", + " progress_bar.set_description(f\"Loss: {loss:.2f}\")\n", + " if i % 5_000 == 0 and i > 0:\n", + " learning_rate = max(learning_rate / 5, 1e-5)\n", + "\n", + "optimized_params = dict(zip(init_param_values, opt_param_values))" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "execution_count": 12 + }, + { + "cell_type": "code", + "id": "7f78e19d-1b1d-4b48-91b7-7711f35704e7", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:39.037610Z", + "start_time": "2026-02-02T02:00:38.969203Z" + } + }, + "source": [ + "window_size = 100\n", + "kernel = np.full(window_size, 1 / window_size)\n", + "plt.plot(np.convolve(loss_history, kernel, mode=\"valid\"))" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGdCAYAAADjWSL8AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAN/tJREFUeJzt3Qt8FOW9//FfbhsC2VwUAgEEYxUhUEERVIqNVYm2fb2K9JxDab1htVhb2n+lbVqwnnj+nH9pBYR/2np6FKTCX/Rgb1xaLJciVUGsFDWoCFUQTUhIgJArm2wy/9fz7O6wmwRMdmd3djaftz6v2Zl5spnZGdgvzzzPTJKIGAIAAOAwyXZvAAAAQDgIMQAAwJEIMQAAwJEIMQAAwJEIMQAAwJEIMQAAwJEIMQAAwJEIMQAAwJFSJYENHTpUGhoa7N4MAADQC263WyorK/tuiFEBpqKiwu7NAAAAYRg2bNgnBpmEDTGBFhj1IdAaAwCAc1phVCNET767EzbEBKgPgRADAEDioWMvAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwJEIMAABwpIR/AKTV8gpGynUzZ8jp6hp58TfP2L05AAD0WbTE9FJu/hD57B1fkau+UBydIwIAAHqEENNLhmH4XiT19icBAICVCDG95gsxSUmkGAAA7ESI6aVAQ4wQYgAAsBUhJuwUAwAA7ESICROXkwAAsBchJtyOvQAAwFaEmDBDDC0xAADYixDTW+YQa0YnAQBgJ0JML50dnESIAQDAToSY3uJyEgAAcYEQ00t07AUAID4QYnqL0UkAAMQFQkyY6BMDAIC9CDG9xGMHAACID4SY3qJjLwAAcYEQ00tGYJA1I6wBALAVIaaXuGMvAADxgRATfqcYyw8GAADoOUJM2E8dIMQAAGAnQkzYKcb6gwEAAHqOENNL3LEXAID4QIgJE5eTAACwFyGm13yXkwgxAADYixAT7uUkOvYCAOCsEHP99dfLhg0bpKKiQn+hT58+PWT9jBkz5IUXXpCamhq9fvz48V3ew+VySVlZma7T2Ngo69evl2HDhoXUycnJkdWrV0tdXZ0u6nV2drbYzuzXS89eAAAcFWIGDBggb775psydO/ec61955RX58Y9/fM73WL58uQ47s2bNkqlTp0pmZqZs2rRJkpPPbs7atWtlwoQJcuutt+qiXq9Zs0bipyXG7i0BAABGuEWZPn16t+tGjhyp148fPz5keVZWluHxeIyZM2eay/Lz8w2v12sUFxfr+dGjR+ufnTx5slnnmmuu0ctGjRrVo21zu926vpqGu3/dlfxRlxpLy3cbpX/daOn7UvgMOAc4BzgHOAc4B6RX398x7xMzceJEfTlpy5Yt5rJjx47J/v37ZcqUKXr+uuuu05eQXnvtNbPOnj179LJAnc7Ue7rd7pASHfSJAQAgHsQ8xAwZMkQ8Ho8OJMGqq6v1ukCd48ePd/lZtSxQp7P58+dLfX29WVSfnag+dQAAANgqbkYnqSHLwTeS6+6mcp3rBFu0aJFkZWWZpXNH4WhsLwAA6EMhpqqqStLT0/Xoo2B5eXm6NSZQZ/DgwV1+dtCgQWadzlpbW6WhoSGkRAVNMQAA9M0Qs3fvXh04pk2bZi5Tl4jGjRsnu3bt0vO7d+/WIWfSpElmncmTJ+tlgTp2CbQE0RIDAIC9Unv7A2oI9aWXXmrOFxQU6HvBnDx5Uj766CPJzc2VESNGyNChQ/X6yy+/3GxdUa0oqr/KypUrZenSpXLixAn9c0uWLJHy8nLZtm2brnvgwAHZvHmzPPnkk3L//ffrZU888YRs3LhRDh48KLYixAAAEDd6NaSrqKjI6M6qVav0+rvvvrvb9aWlpeZ7pKenG2VlZUZtba3R1NRkbNiwwRg+fHjI78nNzTXWrFljnD59Whf1Ojs7u8fbGa0h1nkFI/UQ6//90gsMhWM4JOcA5wDnAOcA54BY+xn05vtb9U5NyPE2aoi1avVRnXyt7B+TVzBSfrThOWk+XS8PT73FsvcFAADSq+/vuBmd5BTcsRcAgPhAiAm3Yy/PHQAAwFaEGAAA4EiEmN4yHwDJze4AALATISbsDEOIAQDAToSYsFOM9QcDAAD0HCGmlwz/iHRaYgAAsBchJuxnJ9EUAwCAnQgxvcSzkwAAiA+EmN6iYy8AAHGBEBPuHXsBAICtCDHhoksMAAC2IsT0VuCxA9wnBgAAWxFiwhxizR17AQCwFyEm7HvdcT0JAAA7EWJ6izv2AgAQFwgxvcR9YgAAiA+EmN7iKdYAAMQFQky4HXsBAICtCDHhfnDJfHQAANiJb+LeoiEGAIC4QIjpJR47AABAfCDE9FbQs5O4ay8AAPYhxETSEsOjBwAAsA0hppdCMwx37QUAwC6EmF4LTjGWHgsAANALhJgILifx/CQAAOxDiInkehIAALANISYS9IkBAMA2hJhILicRYgAAsA0hpreCryYRYgAAsA0hJqKWGKsPBwAA6ClCTK9xOQkAgHhAiIno2Uk0xQAAYBdCTC8ZHUEtMcmEGAAA7EKI6aWOjnbzdVIyHx8AAHbhWziClphkQgwAAM4JMddff71s2LBBKioqdP+Q6dOnd6lTWlqq1zc3N8uOHTuksLAwZL3L5ZKysjKpqamRxsZGWb9+vQwbNiykTk5OjqxevVrq6up0Ua+zs7PFbkZHh/malhgAABwUYgYMGCBvvvmmzJ07t9v1JSUlMm/ePL1+0qRJUlVVJVu3bpXMzEyzzvLly2XGjBkya9YsmTp1ql63adOmkJaNtWvXyoQJE+TWW2/VRb1es2aNxIMOf5ChJQYAAHsZ4RZl+vTpIcsqKyuNkpISc97lchmnTp0y5syZo+ezsrIMj8djzJw506yTn59veL1eo7i4WM+PHj1av/fkyZPNOtdcc41eNmrUqB5tm9vt1vXVNJJ97K48+o+XjKXlu42svEGWvzeFz4BzgHOAc4BzoC+fA+5efH9b2iemoKBA8vPzZcuWLeay1tZW2blzp0yZMkXPT5w4UV9OCq5z7Ngx2b9/v1nnuuuu05eQXnvtNbPOnj179LJAnc7Ue7rd7pAS7c69ydztDgAA21gaYoYMGaKn1dXVIcvVfGCdmno8Hh1Izlfn+PHjXd5fLQvU6Wz+/PlSX19vFtUnJ9qde5NS6BcNAIBdkqN/QzjfgxI7L+usc53u6p/vfRYtWiRZWVlm6dxROBotMXTsBQAgQUKM6sSrdG4tycvLM1tnVJ309HQ9+uh8dQYPHtzl/QcNGtSllSf4slVDQ0NIiXZLDB17AQBIkBBz+PBh3b9l2rRp5rK0tDQpKiqSXbt26fm9e/fqwBFcR4WecePGmXV2796tQ44a3RQwefJkvSxQJx6GWSenpNi9KQAA9Fmp4QyxvvTSS0M6844fP15OnjwpH330kR4+vWDBAjl06JAu6rW6X4waMq2o/iorV66UpUuXyokTJ/TPLVmyRMrLy2Xbtm26zoEDB2Tz5s3y5JNPyv3336+XPfHEE7Jx40Y5ePCg2K2j3X85iY69AADYqldDn4qKiozurFq1yqxTWlqqh1q3tLQYL774ojF27NiQ90hPTzfKysqM2tpao6mpydiwYYMxfPjwkDq5ubnGmjVrjNOnT+uiXmdnZ0dliFZvyyMv/kkPsR5y2adsH4pG4TPgHOAc4BzgHJAE+gx68/2d5H+RcNQQa9Xqozr5Wt0/5t+3b5DsvEGy5F/ulGMH/2npewMA0Je5e/H9zRjhMARGSCUzxBoAANsQYiLoE5OcTMdeAADsQoiJYHRSUrK6GgcAAOxAiInoPjG0xAAAYBdCTCRDrGmJAQDANoSYCDr2JnGzOwAAbEOICQM3uwMAwH6EmIiGWNMnBgAAuxBiwmC0+5+dlMzHBwCAXfgWDkNHBx17AQCwGyEmgiHWSQyxBgDANoSYCFpikhliDQCAbQgxYaAlBgAA+xFiImmJ4QGQAADYhhATweikJEYnAQBgG0JMRM9O4uMDAMAufAuHgSHWAADYjxATBjr2AgBgP0JMGOjYCwCA/QgxkbTEJPHxAQBgF76Fw2AwxBoAANsRYsLQERhiTUsMAAC2IcSEwTD8Q6y52R0AALYhxISho52nWAMAYDdCTAQtMTzFGgAA+xBiImiJ4Y69AADYhxATBqODZycBAGA3Qkwkz06iYy8AALYhxETSsZch1gAA2IYQE0nHXlpiAACwDSEmDHTsBQDAfoSYMNCxFwAA+xFiwtDhH53EEGsAAOxDiAlDh9fXsTclNdXq4wEAAHqIEBOGdq/X9+GlEWIAALALISaCjr0pKSlWHw8AANBDhJhIWmK4nAQAQGKFmMzMTFm2bJkcOXJEmpub5ZVXXpGrr746pE5paalUVFTo9Tt27JDCwsKQ9S6XS8rKyqSmpkYaGxtl/fr1MmzYMIkH7W2+EEOfGAAAEizErFixQqZNmyZ33nmnfPrTn5YtW7bItm3bZOjQoXp9SUmJzJs3T+bOnSuTJk2Sqqoq2bp1qw4/AcuXL5cZM2bIrFmzZOrUqXrdpk2b4mJEUIe/JYYQAwCAvQwrS79+/Yy2tjbjC1/4Qsjyffv2GQsXLtSvKysrjZKSEnOdy+UyTp06ZcyZM0fPZ2VlGR6Px5g5c6ZZJz8/3/B6vUZxcXGPtsPtdhuKmlq9j9ffPtNYWr7buOPn/2H5e1P4DDgHOAc4BzgH+vI54O7F97flzRqpqam6nDlzJmR5S0uLblEpKCiQ/Px83ToT0NraKjt37pQpU6bo+YkTJ+rLScF1jh07Jvv37zfrxMUde+kTAwCAbSwPMar/yq5du+Thhx/WYUVd/rn99tvlmmuu0fNDhgzR9aqrq0N+Ts0H1qmpx+ORurq6c9bpTIUet9sdUqLdsTeFIdYAANgmKh1MVF+YpKQkqays1GHku9/9rqxdu1ba/S0YwQ9RDFD1Oy/r7Hx15s+fL/X19WZRnYajPjqJIdYAACRWiPnggw/khhtukAEDBshFF12kW2HS0tLk8OHDuhOv0rlFJS8vz2ydUXXS09MlJyfnnHU6W7RokWRlZZklmiOZzJYYLicBAGCbqA71UcOnVSBRYeSWW27Rw6RVkFH9W9TopQAVcIqKivRlKGXv3r26n0xwHRV6xo0bZ9bpTNVvaGgIKdHSwRBrAABsF5X75hcXF+tLP++9955ceumlsnjxYv161apV5vDpBQsWyKFDh3RRr1XgUZecFHU5aOXKlbJ06VI5ceKEnDx5UpYsWSLl5eV6qLbdaIkBACBBQ0x2dra+vDN8+HAdQH73u9/JQw89JF7/ZZhHH31UMjIy5PHHH5fc3FzZs2ePDj6qU3DAgw8+qOuvW7dO192+fbvMnj3bfIK0nRidBACA/ZL8Y60TjhqdpFp0VP8Yqy8tjZ56rXzjv5bJx++8J8u+MtvS9wYAoC9z9+L72/7b3zpQu9f/AEiGWAMAYBtCTBgYYg0AgP0IMWFgdBIAAPYjxETSEpOaYvXxAAAAPUSIiWB0Eje7AwDAPoSYMHCfGAAA7EeICQMhBgAA+xFiwkCfGAAA7EeICQOjkwAAsB8hJgzt/o69yTzFGgAA2xBiwtARGGKdnCxJyXyEAADYgW/gCPrEKAyzBgDAHoSYMBBiAACwHyEmwhBDvxgAAOxBiAlDh/8p1koKjx4AAMAWhJiI7xWTauXxAAAAPUSIibA1hpYYAADsQYgJE48eAADAXoSYMBFiAACwFyEmTPSJAQDAXoSYCO/ay83uAACwByEmTO3+jr3JDLEGAMAWhJgwtbe16WlqWpqVxwMAAPQQISZM3kCIcRFiAACwAyEmTN7WVj1NSXNZeTwAAEAPEWLC1N5KSwwAAHYixITJa4YYWmIAALADISZM3jbf5ST6xAAAYA9CTKQtMfSJAQDAFoSYCDv20hIDAIA9CDER94lhiDUAAHYgxETYJyaFjr0AANiCEBMmhlgDAGAvQkyY6NgLAIC9CDFhomMvAAD2IsRE/OwkbnYHAIAdCDFhoiUGAAB7EWIi7hPDEGsAABIixKSkpMjChQvlgw8+kObmZnn//ffl4YcflqSkpJB6paWlUlFRoevs2LFDCgsLQ9a7XC4pKyuTmpoaaWxslPXr18uwYcMkXng9DLEGACChQsyPfvQj+eY3vylz586VMWPGSElJifzwhz+U73znO2YdtWzevHm6zqRJk6Sqqkq2bt0qmZmZZp3ly5fLjBkzZNasWTJ16lS9btOmTZKcHB+NRzw7CQAA+xlWlo0bNxorVqwIWfbb3/7WWL16tTlfWVlplJSUmPMul8s4deqUMWfOHD2flZVleDweY+bMmWad/Px8w+v1GsXFxT3aDrfbbShqavU+qnLlF4qNpeW7jfufLIvK+1P4DDgHOAc4BzgH+uI54O7F97flzRovv/yy3HTTTXLZZZfp+SuuuEK3pPz5z3/W8wUFBZKfny9btmwxf6a1tVV27twpU6ZM0fMTJ07Ul5OC6xw7dkz2799v1ulM1Xe73SElJh176RMDAIAtUq1+w5///OeSnZ0tBw4ckPb2dt1H5qGHHpLnnntOrx8yZIieVldXh/ycmh85cqRZx+PxSF1dXZc6gZ/vbP78+fLII49I7J+dxBBrAADsYHlLzFe+8hW544475Gtf+5pcddVVcvfdd8sPfvADueuuu0LqGYZqCTpLdfztvKyz89VZtGiRZGVlmSXanYDb/c9O4gGQAAAkSEvM4sWL5Wc/+5n8z//8j55Xl4BUC4tqKVm9erXuxKuoFpXAayUvL89snVHL09PTJScnJ6Q1RtXZtWtXt79XXZJSJVZoiQEAIMFaYvr37y8dHR0hy9RlpcCoosOHD+v+LdOmTTPXp6WlSVFRkRlQ9u7dqwNJcB0VesaNG3fOEGPfze64nAQAQEK0xGzcuFH3gTl69Ki8/fbbcuWVV+rh1E899VTI8OkFCxbIoUOHdFGv1f1i1q5dq9fX19fLypUrZenSpXLixAk5efKkLFmyRMrLy2Xbtm0SD862xHCzOwAA7GLp0KjMzExj2bJlxpEjR4zm5mbjn//8p7Fw4UIjLS0tpF5paakeat3S0mK8+OKLxtixY0PWp6enG2VlZUZtba3R1NRkbNiwwRg+fHhUhmiFU/IKRuoh1gtf/ovtw9EofAacA5wDnAOcA5Ign0Fvvr+T/C8SjhpirVp0VCffhoYGy98/d+gQ+clf/iBtZzzy40k3WP7+AAD0Re5efH/Hx+1vHai1uUVP0/qlS3JKit2bAwBAn0OICdOZpmbztSujn1XHAwAA9BAhJkztbW3S3ubVr9MH9A/3bQAAQJgIMRHwNPtaY9L7E2IAAIg1QkwECDEAANiHEBMBj79fjKt/hlXHAwAA9BAhJgIe/wilfvSJAQAg5ggxFlxOctEnBgCAmCPERKA10LGXlhgAAGKOEGPB5aT0DPrEAAAQa4QYCzr2ptOxFwCAmCPEWNESM2CAVccDAAD0ECHGko69XE4CACDWCDEWhBiGWAMAEHuEmAhwszsAAOxDiIlAa6BPDPeJAQAg5ggxEThjjk7iAZAAAMQaISYC3OwOAAD7EGKsGGLN6CQAAGKOEGPB6CQuJwEAEHuEGAtaYrhPDAAAsUeIiYCnqUlPU1JTJTU93apjAgAAeoAQE4HWljPma/rFAAAQW4SYCBgdHXTuBQDAJoSYCLW2BB4Cyb1iAACIJUKMRY8eSM8gxAAAEEuEGKuGWdMSAwBATBFiLHp+EsOsAQCILUJMhM74W2L60RIDAEBMEWIsa4mhTwwAALFEiLGqYy8hBgCAmCLEWPb8pAwrjgcAAOghQoxVLTH0iQEAIKYIMRa1xPTLHGDF8QAAAD1EiIlQc32DnmZkuSN9KwAA0AuEmAi1EGIAALAFISZCzafr9bR/VpYVxwMAANgVYg4fPiyGYXQpv/zlL806paWlUlFRIc3NzbJjxw4pLCwMeQ+XyyVlZWVSU1MjjY2Nsn79ehk2bJjEo5Z6QgwAAAkRYiZNmiRDhgwxy80336yXP//883paUlIi8+bNk7lz5+q6VVVVsnXrVsnMzDTfY/ny5TJjxgyZNWuWTJ06Va/btGmTJCfHX8MRfWIAALCPEc2ybNky49ChQ+Z8ZWWlUVJSYs67XC7j1KlTxpw5c/R8VlaW4fF4jJkzZ5p18vPzDa/XaxQXF/f497rdbkNR02juXz93prG0fLcuqS5XVH8Xhc+Ac4BzgHOAcyDRzwF3L76/o9q0kZaWJnfccYc89dRTer6goEDy8/Nly5YtZp3W1lbZuXOnTJkyRc9PnDhRX04KrnPs2DHZv3+/Wac76mfcbndIiQVPY5N0dHTo14xQAgAgdqIaYm677TbJycmR3/zmN3peXV5SqqurQ+qp+cA6NfV4PFJXV3fOOt2ZP3++1NfXm0X1uYkF1d8nMEKpP8OsAQBIjBBz7733yubNm3VLSucv/mBJSUldlnX2SXUWLVokWVlZZollR+Czw6wZoQQAgONDzIgRI3Sn3hUrVpjLVCdepXOLSl5entk6o+qkp6frFpxz1emOuizV0NAQUmKlOTBCKZsQAwCA40PMPffcI8ePH5c//elPIcOvVavMtGnTQvrNFBUVya5du/T83r17dSAJrqNCz7hx48w68aa5jhADAECspUbjTdWlHxVinn76aWlvbw9Zp4ZPL1iwQA4dOqSLeq3uF7N27Vq9XvVnWblypSxdulROnDghJ0+elCVLlkh5ebls27ZN4lGTv//OgJxsuzcFAIA+IyohRl1GGjlypDkqKdijjz4qGRkZ8vjjj0tubq7s2bNHiouL9U3tAh588EHxer2ybt06XXf79u0ye/ZscxRQvGk6dVpPB+SGXgIDAADRk+Qfa51w1BBr1aqjOvlGu3/MzXNmy+e/c7+8+tv18vx//CyqvwsAgETm7sX3d/zdAteBaIkBACD2CDEWoE8MAACxR4ixQNMpf8de+sQAABAzhBgLNNX5O/YyOgkAgJghxFjYEqNudqeGlwMAgOgjxFig+bTvZnfJKSnSL0YPngQAoK8jxFig3euVlgbffW4G5HLDOwAAYoEQY/EIpczcXKveEgAAnAchxiKNJ07pqXvgBVa9JQAAOA9CjEXqqnxP2M4ZMtiqtwQAAOdBiLFIXfVxPc0ZnGfVWwIAgPMgxFikrsofYoYQYgAAiAVCjEVO+1tismmJAQAgJggxFveJyR48yKq3BAAA50GIsUhddY2eZucNkqRkPlYAAKKNb1uLNNSe0De9S0lNFfeFDLMGACDaCDEWMTo6pL6mVr+mcy8AANFHiLHQ6cAlJTr3AgAQdYQYC3HDOwAAYocQYyFueAcAQOwQYizEDe8AAIgdQoyFuOEdAACxQ4iJSp8YHj0AAEC0EWIsdLLymDk6KSUtzcq3BgAAnRBiLNR44pScaWyS5ORkGXjRMCvfGgAAdEKIsVjNh0f1dODIi6x+awAAEIQQY7HaDz/S00EjR1j91gAAIAghxmI1Rz/W04Ejh1v91gAAIAghJkqXk2iJAQAguggx0bqcNII+MQAARBMhxmI1H/ouJ2UPHiSujAyr3x4AAPgRYizWUl8vTafq9OuBI+gXAwBAtBBioqDmqO+SEsOsAQCIHkJMFNQcCQyzpl8MAADRQoiJglp/SwwhBgCA6CHEREENN7wDACDqCDFRHGZNx14AABwWYoYOHSpr1qyR2tpaaWpqkn379slVV10VUqe0tFQqKiqkublZduzYIYWFhSHrXS6XlJWVSU1NjTQ2Nsr69etl2LBhjmqJybwgVwbkZNu9OQAAJCTLQ0xOTo688sor0tbWJp///Od1OPn+978vdXW+YcdKSUmJzJs3T+bOnSuTJk2Sqqoq2bp1q2RmZpp1li9fLjNmzJBZs2bJ1KlT9bpNmzbpJ0THu9aWFjPIDC8cbffmAACQsAwry6JFi4y//e1v561TWVlplJSUmPMul8s4deqUMWfOHD2flZVleDweY+bMmWad/Px8w+v1GsXFxT3aDrfbbShqavU+9qTc/vP/MJaW7zZunjPblt9P4TPgHOAc4BzgHBAHfga9+f62vFnjS1/6krz++uuybt06qa6uln/84x9y3333mesLCgokPz9ftmzZYi5rbW2VnTt3ypQpU/T8xIkT9eWk4DrHjh2T/fv3m3U6U/XdbndIsdNH+9/V04vGjbF1OwAASFSWh5hLLrlEHnjgATl06JDccsst8utf/1r3bbnzzjv1+iFDhuipCjjB1HxgnZp6PJ6QS1Cd63Q2f/58qa+vN4vqb2Onj/a/o6cXjSXEAADgiBCj+qyo1peHHnpI3njjDXniiSfkySef1MEmmGGolqCzkpKSuizr7Hx1Fi1aJFlZWWaxuxNwxYGD0tHeLtl5gyRr0EBbtwUAgERkeYhRl33eecfXChHw7rvvyogRI/Rr1YlX6dyikpeXZ7bOqDrp6em6k/C56nSmLkk1NDSEFDu1tpyRqvcP69dcUgIAwAEhRo1Muvzyy0OWjRo1Sj788EP9+vDhwzroTJs2zVyflpYmRUVFsmvXLj2/d+9eHUqC66jQM27cOLOOE3z89gE9JcQAABAdlvYqvvrqq43W1lZj/vz5xqc+9Snjq1/9qtHY2Gh87WtfM+uokUlqNNJtt91mjB071njmmWeMiooKIzMz06zz+OOPG0ePHjVuvPFGY8KECca2bduMffv2GcnJyZb3bo5WuW7mDD1Cac6vl9ne25vCZ8A5wDnAOcA5IA74DHr5/W39Bnzxi1803nrrLaOlpcV45513jPvuu69LndLSUj3UWtV58cUXdZgJXp+enm6UlZUZtbW1RlNTk7FhwwZj+PDh0foQolKGF47WIWbhy38xkpKSbD8xKHwGnAOcA5wDnAMS559Bb76/k/wvEo4aYq1GKalOvnb1j0lOTZGFL/1F+mUOkGVfmS0fv/OeLdsBAEAifn/H/+1vHazD2y7//Pte/XrUdZPt3hwAABIKISbKDu56TU8v/8y10f5VAAD0KYSYKDvw8qt6WjDhCn1ZCQAAWIMQE2UnPq6Q44c/lJS0VLns2knR/nUAAPQZhJgYePfl3Xo6Zup1sfh1AAD0CYSYGDjwki/EjL6eEAMAgFUIMTHwwd43xNPcop+jNPTyy2LxKwEASHiEmBjwtrbKP1/zDbUezSUlAAAsQYiJkXdf8j3zaQyXlAAAsAQhJkYO+Dv3jhw/TjKy3LH6tQAAJCxCTIycqqySqvcPS0pqqhR+9jOx+rUAACQsQkwMvfmX7Xp65RemxfLXAgCQkAgxMbRv81bzOUoDcnNi+asBAEg4hJgYqjlyVD56+119SWl88Y2x/NUAACQcQkyM/eNPW/T02n+ZHutfDQBAQiHExNjrG/6sb3w3bMwonqUEAEAECDEx1ny6Xv7+x0369edmfy3Wvx4AgIRBiLHBztXPSkd7u1z+mWt5DAEAAGEixNjgZMUxecM/3Pqm++6yYxMAAHA8QoxN/rpytZ5eUXyj5I+61K7NAADAsQgxNjl28H1544VtkpycLNNL/pddmwEAgGMRYmy0admvpM3jkcuuuVrG33KTnZsCAIDjEGJsfp7SX1f4Livd+u1vSFIyhwMAgJ7iW9NmO9c8J011pyWvYKRMuPVmuzcHAADHIMTYzNPULDuffla/vuWBe/UjCQAAwCcjxMSBl599XuprT8igi0fIrd+ZY/fmAADgCISYOGmN+f1/Ltavb/z6nTLm+il2bxIAAHGPEBMnyrfvlJeeWadff2XhQ5I7dIjdmwQAQFwjxMSRTY/9SioP/lPcF14gc369XAbkZNu9SQAAxC1CTBzxtrbKim/Nk1PHqvRopXt/tVRcGf3s3iwAAOISISbOnK6ukSfu/54edj3yirFy12M/leTUFLs3CwCAuEOIiUPHD38oK779fWltOSNjpl4ndz/2U0nrl273ZgEAEFcIMXHq6Ftvy9PfX6AfSzDuc5+Vb616XLIHD7J7swAAiBuEmDh24KXd8uv7vitNp+pkxLhC+d6zT8nFE66we7MAAIgLhJg4d+SNt2T5V78uxw69L1mDBsq3n/4vmV7yPTr8AgD6PEKMA5ysOCa/uGOO/H39nyU5OVk+e+dX5Ae//39y+ZRr7N40AABskyQihiQgt9st9fX1kpWVJQ0NDZIoRk+9Vv71338kufm+m+G9/eLLsnHpL6TmyFG7Nw0AgJh+fxNiHCi9f3+5Ze43ZOqsf5WUtFRpb/PKrnW/lxdXPSN11cft3jwAAGISYiy/nFRaWiqGYYSUY8eOdalTUVEhzc3NsmPHDiksLAxZ73K5pKysTGpqaqSxsVHWr18vw4YNs3pTHcvT3CwbHv2/svjLt+uWGBVkrr99pizY/DuZ9Z8/kcGXXGz3JgIA4Mw+Mfv375chQ4aY5dOf/rS5rqSkRObNmydz586VSZMmSVVVlWzdulUyMzPNOsuXL5cZM2bIrFmzZOrUqXrdpk2bdH8QnKUuIT31nR/Kr+/7jhza87oOM5Omf1FK1j8r31zxC7ny89MkJTWVjwwAkLAMK0tpaamxb9++c66vrKw0SkpKzHmXy2WcOnXKmDNnjp7PysoyPB6PMXPmTLNOfn6+4fV6jeLi4h5vh9vtNhQ1tXof47VcNK7QuGvp/zEWv/mKsbR8ty6lOzYZt3z7G3pdUlKS7dtI4TPgHOAc4BzgHBCLvr+j0rRx2WWX6ctFH3zwgTz77LNSUFCgl6tpfn6+bNmyxazb2toqO3fulClTpuj5iRMn6stJwXXU5SjVuhOo0x31M+o6WnDpaz7a/46s/v5D8tNb/0X+8qsn5fTxGskaeKEUf/Pr8r1nV8qCF34nX/rhd6XgqvGSRKsWAMDhLL/WsGfPHrnrrrvk4MGDMnjwYPnJT34iu3btkrFjx+pLS0p1dXXIz6j5kSNH6teqjsfjkbq6ui51Aj/fnfnz58sjjzxi9e44knqA5JZfPyXbVjwtn77pBrnqC9Pk0slXywVD86Xorq/q0ny6Xo68US4f7N0n77++Tz5+9z3p8LbbvekAANgXYl544QXztWo92b17t7z//vty9913y6uvvqqXq86+wZKSkros6+yT6ixatEgee+wxc161xKjWoL5MhZI3/7Jdl9T0dBkz9VoZd2ORFN7wGemfnSWFRZ/RJdBZ2Bdq3tCBRr0+09Bo9y4AAHBOUe/1qUYglZeX60tMf/zjH/Uy1aKiOvQG5OXlma0zanl6errk5OSEtMaoOqpF51zUZSlV0D2vxyPl23fqop6KPWz05VJw1RXyqYkTpOCqCTIgJ1vfPC9wA72Ojg6pPHBIPnr7XTl28J9SV1WtL0+dPl4rjSdPidHRwUcNAEjsEKP6qowZM0ZeeuklOXz4sO7fMm3aNHnjjTf0+rS0NCkqKpIf/ehHen7v3r06jKg6zz//vBl6xo0bp0c2wZoWGtV/RpW/rX5Ot3INvvQSHWguvvIKuahwtAy6eIQML7xcly4/39GhW2laGhqkpb5Rmuvr5XT1cX2PmtPVNb5y/LgOPU0n6z6xlQ0AgLgIMYsXL5aNGzfK0aNHdeuJ6hOjbljz9NNPm8OnFyxYIIcOHdJFvVatNWvXrtXr1Q1uVq5cKUuXLpUTJ07IyZMnZcmSJbo1Z9u2bVZvLvyX96oOva/LK8/9Tn8m6jlNKtAMH3O5DL5kpGTlDZLsQYPEPfACSU5J0ZejVPkk6incpyqr5ERFpTSeOCkNtSekvtY3VUUFnfqaE9La0sKxAADYG2KGDx+uRyQNHDhQ36xO9YO59tprdahRHn30UcnIyJDHH39ccnNzdUfg4uJifVO7gAcffFC8Xq+sW7dO192+fbvMnj1btwAgNuprauWtLX/VJVhwgOnnzpQMt1sG5GZL7pAhkj14kK+owJM3SDIvvEDS0tMlr2CkLudzpqlJGmpPSn1trW9a45s21NbqkFPvDz3qid607AAAFB47gKhRfW9yBufJhcOH6Wc9ZV6Yq4NN1oUXiHvQQD38W4Ue9RiFnmr3eqXxxCmpP3FCGmp8wUYFHHVJ60xDk77EdaaxyXe5q7FRzqjS0CRe+ksBQMI9doDbuSKqfW/UE7hVOR8VYtyDLtShxj0waBq0TJUBuTn6DsSBFp/eUJe1QsNNk7TU+wOPDjpque+16uejl6nXDWfXqwAFAIgfhBjYTg3v9nzYLLUffvSJLTuZF1wgWQMvEPdA1ZJztkUnI8st/TIHSEZmpr7MpV73y1SXu3yPs1CXtVRxX3hB2NvZ2nLGH25Ui08g6JwNQ2o/1PK2M2ek9cwZaWs5I831Db7WoppaXQcAYB1CDBzVslOvOgIfrxGR93r0M+rOxOkD+vvDjS/Y+MJNIOS4zeUZKvjoAOQLP2YYGjBAv5cro58uqtNzONTTxlUHZk9Li27ZUf2AjPYOaW9v1y1FKvTo8HPGY4ag7ufVNLh+IDT56jH8HUBfQYhBQlNf6DowqBv3nf+q1nmDkK9lxx92sty+wBPU6qNCUrqeDpC0fumS1k8Fngx9/x11KUyFIvWAzow0t/55GZwn0aL6/wSHms4hJzDvbW2T9rY28Xpapa21VYclHaZUCbw+49H3GFLv2eZp1SGstdkXxNSUS2wA7ESIAXoQhNQlI1XCpVpw+rndvumA/joAqRYi9WT25NRU3+Wufuni6tdP0lSLj5r659XP6FDkD0ed16tl6f0zzv6hdrl0kU8eAR8xb1vb2WDT3OJraWoOzDebyz0t6rUvQHW0t+vWp442r3i9bTpMqVYlFZpUoFLv2e5VAcvrK15f6fBPz87zmAygryPEADGg+tOoEk3q0RKuLkFIhSB/y1DQvGolSklLk1RVXC4diALTQP+htIyzr1U9HZoyMsTVP0PP69/pf4/+WTFITN3Qgaet+4ATeC3+my0GhuYbHYa/lUkFp1bdyqRbpFSY8nh8x0r10wq0OLW06MCl3kuvP+MRo71d11PPIAtc8lPvowIagNghxAAJQl/28Xj0F2u0qVFiKsyoViI1uiwQbtIzMnSrkDmv1/U7+7pfum55UvcbUu+Rkpoiqf7gpEJWarpLhyJVR69PSzWDVncCISpeqEDlbVOhKChUnSNkqQAWstysdzaAdfk5/b7tOnSFvl/gtVp3rpars7+HFi4kCkIMgF5TX4aRXmILp2+SL/ik+kJOWkrofKeSnOab6p9NCn6fFF9o6t9P0vyX3lJcvjCkl6sAFhy8Mvr5A5VvvWqRUiFMLVctUGo+QIUuVZzunCGnrbtQpYKbaslSrVpt0q6mOii1+VvK/H2vWgPzQevUMv8lxeB6qkVLXcZVNzhVU9V61tHR7p92mGHvbDDzt6S1erjM2Mc4/08bgD5BfZn5vijj60Gv6tljqgVJX3ZTxR+IfMGqm6CVlma2MgXWdw5hvpAWmE/R8yEtVEHLz9br7n1Szv7OkHpB25WS0mWfnB7GAuFHhSHdBys4gAWClNkq1dZ1vTltF8PwhahAqAq8nz4XPa2+MNeu6qjfFQheQa/VNuifC12uWuPM+UAd/V6B1/5tV/OB36da69rVtqkWN19o7OuXMJ17lgJAHFB9bfRorjMecSIVwoIv350NQ0EBqJsWruAA5utMrsKby3cJULVumZcCVYDyldTAVC3zv3dgvTlVYS0lRXd6TzJLkjmv15mXI88GsmCqrvhb7hJdR3u7v5XLH2r8oSoQ4ALBS93OoSMkNPmX6XodvrCm+o2p/w2jSxBTYcoIhKuggFj9/hHZ/fwfbNv/xD/CAIBzUl9YgUs54tDnsKpwo0OUPzwlqf/8QSikNUq99oeyzn2vuraS+eqoABUIU4FwpaaB1rdAa5ZvfbIvbJn11M+ldFmepOfPvk5JSfVN1Xv5t1lfPvUvD4S/QDAMlqwvbfoub9rhwMuvEmIAAAiXaiFwcmtYb+kQpVuuXP7w5gtwgRYuMzgFhS4dtALhKaVTsNL1knTnMfWfmvoCWOf3UFN/4PK/34mjH9v6WdASAwCAg6jLOK0t7SJRvm2DEyTbvQEAAADhIMQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHIsQAAABHSvinWLvdbrs3AQAAROF7OzXRP4SKigq7NwUAAITxPd7Q0HDeOkkiYkiCGjp06Cd+AOF+sCocDRs2LCrvH0/Y18TFsU1Mfem49rX97Wv7WllZ2XdbYpSefACRUCdRop9IAexr4uLYJqa+dFz72v72hX1t6OH+0bEXAAA4EiEGAAA4EiEmDB6PRx555BE9TXTsa+Li2CamvnRc+9r+9qV97amE7tgLAAASFy0xAADAkQgxAADAkQgxAADAkQgxAADAkQgxvfTAAw/IBx98IC0tLfL666/L1KlTxWl+/OMfy2uvvSb19fVSXV0tf/jDH2TUqFEhdVatWiWGYYSU3bt3h9RxuVxSVlYmNTU10tjYKOvXr9d3kownpaWlXfbj2LFjXeqou2A2NzfLjh07pLCw0HH7GXD48OEu+6vKL3/5S8cf1+uvv142bNigj5Xa7unTp3epY8WxzMnJkdWrV0tdXZ0u6nV2drbEy76mpqbKz372M3nrrbf0Pqg6Tz/9tOTn54e8h9r/zsf62Wefjbt97cmxteq8jYf9/aR97e7Pryo/+MEPHHlsY0GNTqL04DOYOXOm4fF4jHvvvdcYPXq0sWzZMqOhocG46KKLHPX5bd682bj77ruNwsJC44orrjA2btxoHDlyxOjfv79ZZ9WqVcaf//xnY/DgwWbJzc0NeZ/HH3/c+Oijj4ybbrrJmDBhgrF9+3Zj3759RnJysu37GCilpaVGeXl5yH4MHDjQXF9SUmKcPn3amDFjhjF27Fjj2WefNSoqKozMzExH7WegqH0L3le1zUpRUZHjj+utt95qLFy4UB8rZfr06SHrrTqW6vN56623jGuvvVYX9XrDhg1xs69ZWVnGli1bjH/7t38zRo0aZVxzzTXG7t27jb///e8h77Fjxw7jv//7v0OOtfrZ4DrxsK89ObZWnbfxsL+ftK/B+6jK7Nmzjfb2dqOgoMCRx1aiX2zfAMeUV199Vf9BCV72zjvvGD/96U9t37ZIv/iU66+/PuQvjT/84Q/n/Bn1B0YFOhXsAsvy8/MNr9drFBcX275PwSFG/UV2rvWVlZX6yy8w73K5jFOnThlz5sxx1H6eq6igfejQoYQ7rt395W/FsVT/OFEmT55s1lEhQVGBIV72tXO5+uqrdb3gf1CpLzp1/M/1M/G4r+faXyvOW6ceW7Xf27ZtC1nm1GMrUShcTuqhtLQ0mThxomzZsiVkuZqfMmWKOFmgifHkyZMhy2+44QZ9uem9996TJ554QgYNGmSuU5+Far4N/jzUZZr9+/fH3edx2WWX6aZbdRlQNbkWFBTo5WqqmuCD96G1tVV27txp7oOT9rO7c/aOO+6Qp556KiGPazCrjuV1112nm97V5daAPXv26GXxvP/qz3BHR4fezmC33367vryi9nHx4sWSmZlprnPavkZ63jptf5W8vDz54he/KCtXruyyLpGObSQS+gGQVho4cKC+Fq3+EAVT80OGDBEne+yxx+Sll16St99+21y2efNmef755+XDDz/UXxALFy6Uv/71r/ovC/XloPZZ3TWy81+a8fZ5qD+4d911lxw8eFAGDx4sP/nJT2TXrl0yduxYczu7O6YjR47Ur52yn9257bbb9HXx3/zmNwl3XDuz6liq6fHjx7u8v1oWr/ufnp6u+8isXbs25KF5zzzzjO4jVVVVJePGjZNFixbJ+PHjpbi42HH7asV566T9Dbj77rv1Mf39738fsjyRjm2kCDG9pDpQBUtKSuqyzElUh88rrriiSwfldevWma9VuFGdmNVfIOpfBaoj8LnE2+fxwgsvmK/Vv1hUZ8D3339f/+Xw6quvhn1M420/u3Pvvffqv/yDOzInynE9FyuOZXf143X/1T+snnvuOUlOTpZvfetbIetWrFgRcqwPHToke/fulSuvvFL27dvnqH216rx1yv4GfP3rX9eBpfNjBhLp2EaKy0k9VFtbK16vt0uKVc19nf/15xSqJ/+XvvQl+dznPqcvt5yPSvzqLw11aSYwr/4FqP6l76TPQ41aKS8v1/uh9kE53zF16n6OGDFCbr755pC/7BL5uFp1LFUd1WLXmbp0EW/7rwKM+nJXLRPTpk0LaYXpzj/+8Q/dahF8rJ2yr1act07bX/UPy9GjR3/in+FEO7a9RYjpoba2Np101V8WwdS8ujzhNL/4xS/ky1/+stx4441y5MiRT6x/wQUXyEUXXWT+q159FuoPTfDnob5AVNNmPH8e6rr5mDFj9H6o5lg1Dd4H1Y+kqKjI3Aen7uc999yjm47/9Kc/9YnjatWxVC116otw0qRJZp3JkyfrZfG0/4EAo760VFjt3J+tO+oSqjr/A8faKftq1XnrtP1VLamqxUkNpe9LxzYctvcudtoQ63vuuUf3/n7sscf0EOsRI0bYvm29Kb/61a/0qI3PfvazIUP0+vXrp9cPGDDAWLx4sR6WN3LkSD0895VXXtHDFzsPVz169Khx44036iGNqgd9PAzFDS5qP9R+XnzxxbqnvhpiqIbhBo6ZGs2iPovbbrtND8t95plnuh2WG+/7GVySkpL0kPlFixaFLHf6cVXbP378eF2U733ve/p1YESOVcdSDU1944039GgOVd58882YD009376mpKQYf/zjH/V+qFskBP8ZTktL0z9/ySWXGA8//LAxceJEfaw///nP65GUe/fujbt9/aT9tfK8jYf9/aTzWBW32200NjYa999/f5efd9qxlegX2zfAUeWBBx4wDh8+bJw5c8Z4/fXXQ4YlO6Wci7p3jFqvwswLL7xgVFdX69CmvhDVEMfhw4eHvE96erpRVlZm1NbWGk1NTfoPSOc6dpfAvULUfnz88cfGb3/7W2PMmDFdhmGr4bktLS3Giy++qL8AnbafwWXatGn6eF522WUhy51+XNWXV3fUPlh5LNX9R9asWaPDrirqdXZ2dtzsq/riOpfA/YDUPqn9V/up/q5Sw+yXL1/e5d4q8bCvn7S/Vp638bC/PTmPv/GNb+h96HzvFyceW4lySfK/AAAAcBT6xAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAEcixAAAAHGi/w9Am1M6hWTI+QAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 13 + }, + { + "cell_type": "code", + "id": "38e1a57f-4644-412d-b046-999db6a59205", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:41.587269Z", + "start_time": "2026-02-02T02:00:39.988830Z" + } + }, + "source": [ + "n_iter = 2_000\n", + "loss_history = np.empty(n_iter)\n", + "progress_bar = trange(n_iter)\n", + "draws = np.array(500, dtype=int)\n", + "learning_rate = 1e-3\n", + "for i in progress_bar:\n", + " loss, *grads = f_loss_dloss(draws, *opt_param_values)\n", + " loss_history[i] = loss\n", + " if any(np.isnan(d_loss).any() for d_loss in grads):\n", + " print(\"Got nan, getting out\")\n", + " break\n", + " opt_param_values = [\n", + " np.asarray(value - learning_rate * grad) for value, grad in zip(opt_param_values, grads)\n", + " ]\n", + "\n", + " if i % 50 == 0:\n", + " progress_bar.set_description(f\"Loss: {loss:.3f}\")\n", + "\n", + "optimized_params = dict(zip(init_param_values, opt_param_values))\n", + "optimized_params" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma_loc': array(-1.08338307),\n", + " 'sigma_scale': array(-3.40979122),\n", + " 'beta_loc': array([-0.90914275, 0.27361567, -0.15888119]),\n", + " 'beta_scale': array([-3.82559357, -3.84190526, -3.72892214]),\n", + " 'alpha_loc': array(11.78181088),\n", + " 'alpha_scale': array(-3.84125139)}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 14 + }, + { + "cell_type": "code", + "id": "650c5e39", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:43.019891Z", + "start_time": "2026-02-02T02:00:42.959050Z" + } + }, + "source": [ + "window_size = 100\n", + "kernel = np.full(window_size, 1 / window_size)\n", + "plt.plot(np.convolve(loss_history, kernel, mode=\"valid\"))" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQDZJREFUeJzt3QlYVlX+wPEfKIuyqLiBYEqKuaBSSJmDYSRUNqVUY5alpZPVjJOlM472r9GZSitzGUorp8aaxmpscTRLwy1rktwyFSs1d9kCFVnEl+3+n3OEN16WZHnh3b6fx/Pc+773vC/nnnvl/jj3nHPdRMQQAAAAO+Ju6wIAAABURYACAADsDgEKAACwOwQoAADA7hCgAAAAu0OAAgAA7A4BCgAAsDsEKAAAwO60FAfVpUsXycvLs3UxAABAPfj5+UlaWppzBigqOElNTbV1MQAAQAMEBwdfMkhxyAClouVE7SCtKAAAOE7riWpgqMu12yEDlApqBwlQAABwPnSSBQAAdocABQAA2B0CFAAAYHcIUAAAgHMFKDNmzBDDMGThwoXm95YtW6bfq5ySk5MtPufp6SmJiYmSlZUl+fn5smrVKj0iBwAAoFEByqBBg2TSpEmyZ8+eatvWrl0rgYGB5jRixAiL7YsWLZKEhAQZM2aMREdHi6+vr6xZs0bc3WnQAQAADQxQfHx8ZPny5fLggw/K2bNnq203mUySmZlpTpXz+Pv7y8SJE2XatGmyceNG+fbbb+Xee++V/v37y/DhwzkmAACgYQHK4sWL5ZNPPtEBRk2GDRumA5MDBw7I0qVLpWPHjuZtkZGR+hZPUlKS+b309HRJSUmRIUOGcEgAAED9J2q766675KqrrpKoqKgat6vbO++//74cP35cQkND5emnn5ZNmzbpwKSoqEjf8lEtLDk5ORafUwGN2lYTFdB4eXlZzEQHAACcV70ClJCQEPn73/8u8fHxOsioyYoVK8zr+/fvl507d+pg5ZZbbpGVK1fW+t1ubm66Q21NZs6cKbNnz65PUQEAgKvc4lGtIJ07d5Zdu3ZJcXGxTup2zqOPPqrXa+rkmpGRoQOUsLAw82vVGtK2bVuLfJ06ddKtKDWZO3eu7rtSkRjxAwCAc6tXgKL6nISHh0tERIQ57dixQ3eYVetlZWXVPhMQECBdu3bV/UwUFdyoWz1xcXHmPOrWjvrerVu31vhzVf6K5+7w/B0AAJxfvW7xqDlL1G2bygoKCuT06dP6fTW6R92K+fDDD3VA0r17d5kzZ45kZ2ebb+/k5ubKG2+8IfPnz9efO3PmjLz44ouyb98+2bBhg9iSm7u73PanRyX7xCn56t0PbFoWAABcmVWfZlxaWqqHC48bN07fwlFByubNm3XHWhXcVHj88celpKRE91dp1aqVbpm5//77a2yBaU79hw+T6+69S5fjXOZPkrLpC5uWBwAAV+UmIjX3TLVjahSPaolR/VHULR9runPWn+XaO0dJ8QWTLJn4ezmx17LFCAAANP31m6lbq/jomRfluy++Eg9vL/ntyy9K58u7N/AwAACAhiJAqaKstFTe/uNTcmLfd+LTrq089I9Ead81pMEVDAAA6o8ApQZFhYXyj0cel/RDh6VNp47y8OuJ0jawcwOqFwAANAQBSi3On8uVVx/8g/x09LgEdAmSh19/Sfw6tG9QJQMAgPohQPkF+afP6iDl9Kk06ditqzz8j0TxadumnlUMAADqiwDlEs5lZukgJSfzJwnseblMeu3v4u3nW++KBgAAdUeAUgdnTqXJaw8+Knmnz0hI3yvkwSULxLNVq3pUMwAAqA8ClDpSfVFem/So7pvSPaK/THjpBWlZ6QnLAADAeghQ6iH94GFZ+tBjciG/QMKuGST3L5wjLTw8rHg4AACAQoBSTyf3fy+v/36aFBVekD5Dh8i9z/9V3Fu04GwCAMCKCFAa4Og3e2TZlOlSUlQkA+Kul7uffUo/aBAAAFgHV9UGOpi8Q96a+n9SWlwiV91yo4x5+kmCFAAArIQApRG+2/I/eftPT0ppSYkMuu1muetvTxCkAABgBQQojbRv4xb59/S/6CAlauQtMnr2THFzUw+JBgAADUWAYgV712+W5TNm6yDl6oRfy29mzSBIAQCgEQhQrGTPZxvlnZl/1U9DvuaO2+SOv0wnSAEAoIEIUKzo23Ub5J0n/qaDlGvvHCW3/98frfn1AAC4DAIUK9v9aZK8++TTUlZWJkPuup0gBQCABiBAaQLfrPlM/vPUszpI+dWYOyRh5tSm+DEAADgtApQmsnP1p7Ji1hwdpETf8xsZOf2xpvpRAAA4HQKUJrTjv5/I+7Of0+vX3XeX3PanR5vyxwEA4DQIUJrY9pUfy4rZc/V6zLi75ZbHHmnqHwkAgMMjQGkG2z5cLR8+M0+vx04cJzf+7rfN8WMBAHBYBCjNZOt/PpL/Pr9Ir8c/MlFu+O345vrRAAA4HAKUZvTlv/8jH89/Wa+PmPKwHuEDAACqI0BpZp+/uVySXv2nXlcTuV316xubuwgAANg9AhQb+GzxP+TL5Sv0+pinn5R+1w+1RTEAALBbBCg2sur5RbJj1afSomVLGffiMxJ2zSBbFQUAALtDgGIjhmHoidz2bdwiLT095YHE5+Wy/n1tVRwAAOwKAYoNqYcKvv2np+Rg8nbxat1aHnxloQSG9bBlkQAAsAsEKDZWWlwsy6bMkGN79knrNv7y0GuLpH3XEFsXCwAAmyJAsQNFhYXy+u+mSdrBH8W/Ywd5+B+J0qZzR1sXCwAAmyFAsROFuXmydNIUyTp+UgKCg+ShpYni066trYsFAIDjBSgzZszQnT0XLlxo8f6sWbMkNTVVzp8/L5s3b5a+fS07f3p6ekpiYqJkZWVJfn6+rFq1SoKDg8XV5Z0+I689+KjkZGRK58u7y4OvLhRvXx9bFwsAAMcJUAYNGiSTJk2SPXv2WLw/ffp0mTp1qkyePFmioqIkIyND1q9fL76+vuY8ixYtkoSEBBkzZoxER0frbWvWrBF3dxp0zqZnyGuTpuhgpWvf3jLh5Xni4e3VuKMMAIADMuqbfHx8jAMHDhg33HCDsXnzZmPhwoXmbWlpacb06dPNrz09PY2zZ88akyZN0q/9/f0Nk8lkjB492pwnKCjIKCkpMeLj4+v08/38/AxFLRtSfkdIwb17Gc9sXW/M35ds/HbJfKNFy5Y2LxOJOuAc4BzgHOAckEbUQX2u3w1qsli8eLF88sknsnHjRov3Q0NDJSgoSJKSkszvFRUVyZYtW2TIkCH6dWRkpL7FUzlPenq6pKSkmPNUpfL7+flZJGeX+sNBeeP306So8IL0GTpE7p7zF3GjhQkA4CLqHaDcddddctVVV8nMmTOrbQsMDNTLzMxMi/fV64ptamkymSQnJ6fWPFWpn5Wbm2tOqn+LKzi6e6+8+fhMKSkulitvjpM7nvqTrYsEAID9BSghISHy97//Xe69914dZNRGdZytzM3Nrdp7Vf1Snrlz54q/v785uVKH2gNffS3LZ8zWk7pde+co+fXUybYuEgAA9hWgqNsznTt3ll27dklxcbFOw4YNk0cffVSvV7ScVG0J6dSpk3mb6jTr5eUlbdu2rTVPVeo2UV5enkVyJXuTNsn7s5/T69c/MFaG3T/W1kUCAMB+AhTV5yQ8PFwiIiLMaceOHbJ8+XK9fuTIEd2fJC4uzvwZDw8PiYmJka1bt+rXKrhRAUflPCqgUd9bkQfVbf/vGlk9L1Gv3zptslw5Ip5qAgA4tUb1yK06ikeN4FGjdkaNGmX069fPWL58uZGammr4+vqa8yxZssQ4ceKEERsba0RERBgbNmwwdu/ebbi7u1u9F7CzpdumT9Eje57/5guj59WRNi8PiTrgHOAc4BzgHJA61kF9rt8trR3tvPDCC9KqVStZsmSJtGvXTrZt2ybx8fF6QrYKjz/+uJSUlMiKFSt0XtUyc//990tZWZm1i+N0Pp6XKG06dZSIG2+Q+xc9J4vvf1jSDx62dbEAALAqt/JIxaGoYcZqNI/qMOtq/VGUlp6eMum1RdJj0JVyLjNLEsf+VnIyf7J1sQAAsNr1m6lbHVBJUZEsm/JnST90WD9UcMJL88SzVStbFwsAAKshQHHghwu+8fs/6inxg/v0knvmztJDtQEAcAYEKA7+3B7VklJsMkn/G2JkxJSHbV0kAACsggDFwR3fkyIrZs3R67ETx0nUyBG2LhIAAI1GgOIEvvkkSda/tkyv3zlrhoReNdDWRQIAoFEIUJzEZ4v/IXuSNklLDw95YNFzEhDSxdZFAgCgwQhQnIR6jtG7//c3OZHynfi0aysTX35RvH19bF0sAAAahADFiRRfMMmyR/+s50QJ7BEq9817RtxbtLB1sQAAqDcCFCeTm5Ut//zDn8R0vlB6Rw+W2/70qK2LBABAvRGgOKHU7w/Ku0/8Va8PHTtarh2dYOsiAQBQLwQoTmrfxi3yyaJX9HrCzKkSNjjK1kUCAKDOCFCc2KY3/iU7Vn0qLVq2lPte+Ju0Cwq0dZEAAKgTAhQn98HfnjeP7Bm34Fn9oEEAAOwdAYoLPFjwX1P/TwrO5shl4X317R4AAOwdAYqLPLPn33/+i5SVlcngO0fKNbffausiAQDwiwhQXMTB5B2y7qWlej3hiWkS0re3rYsEAECtCFBcrNNsyuYvxMPLS8YvnCM+bdvYukgAANSIAMXlpsN/WrKOn5SALkEy9vm/ips7pwAAwP5wdXIxF/Ly5c3HZ+qZZq8Yco3c8Ntxti4SAADVEKC4oIxDh+WjZ1/U6zf+7rdy+aArbV0kAAAsEKC4qJ2rP5Udqz7RDxO897m/6nlSAACwFwQoLky1omQcPiptOneUe+bMEjc3N1sXCQAAjQDFhRUVXpC3//ikXqonH18/4V5bFwkAAI0AxcVl/HhEVs5doNdvmjxJQq8cYOsiAQBAgAKR7Ss/lm8++Uw/VFANPfb286VaAAA2RQsKtA/+9oJknziln3h8+xPTqBUAgE0RoEAznT8vy2fOltKSEon89U1y5c1x1AwAwGYIUGB2Yu9+2bD0Tb1+x5N/kraBnakdAIBNEKDAwoZ/vCnH96RIK38/ufvZpxh6DACwCQIUWCgrKZV3nvirvuXT8+pIiRl3NzUEAGh2BCioRnWWXfX8Ir1+85SHJahXT2oJANCsCFBQo20ffSwpm7ZISw8PGfvcbGnh4UFNAQDsM0B5+OGHZc+ePXLu3Dmdtm7dKjfddJN5+7Jly8QwDIuUnJxs8R2enp6SmJgoWVlZkp+fL6tWrZLg4GDr7RGsZsXs5yQ3+7QEhfWQuIcfoGYBAPYZoJw6dUpmzJghgwYN0mnTpk06wOjbt685z9q1ayUwMNCcRowYYfEdixYtkoSEBBkzZoxER0eLr6+vrFmzRtzdacyxNwVnc8xPPY6dcJ90uSLM1kUCALgQozHp9OnTxoQJE/T6smXLjJUrV9aa19/f3zCZTMbo0aPN7wUFBRklJSVGfHx8nX+mn5+foahlY8tPunQdjJv/rDF/X7Lx+H/eNNxbtqDOOG84BzgHOAc4B4yG1EF9rt8NbrZQLR533XWX+Pj4WNzGGTZsmGRmZsqBAwdk6dKl0rFjR/O2yMhIfYsnKSnJ/F56erqkpKTIkCFDav1Z6jN+fn4WCc1n5Zz5cv5croT0vUKGjR9L1QMAmly9A5Tw8HDJy8sTk8kkr776qr5d8/3335tv74wdO1ZiY2Nl2rRpEhUVpW8DqQBDUbd81OdycnIsvlMFNGpbbWbOnCm5ubnmlJqaWv89RYPlnT4j/y0f1RP/yATpFNqN2gQANLl6Nc94eHgYPXr0MCIjI405c+YYP/30k9GnT58a8wYGBupbOgkJCfr13XffbVy4cKFavqSkJOOVV16p9Wd6enrq5qCK1KVLF27x2KB58bevLNC3eib/6zXDzd2dJl6aeDkHOAc4BzgHDLu5xVNcXCyHDx+WXbt2yRNPPKFH9UyZMqXGvBkZGXL8+HEJCwszv/by8pK2bdta5OvUqZNuRalNUVGRbrWpnND8Pvjr83KhoEBCrxwgvxpzB4cAANBkGj10xs3NTQcdNQkICJCuXbvqfiaKCmpUsBEX9/OD6NStHXXbSA1Zhn3LyciUNQsW6/URUx6RgOAgWxcJAODE6tw08+yzzxrR0dFGt27djPDwcOOZZ57RI3CGDx9u+Pj4GPPmzTMGDx6st8fExBhfffWVcfLkScPX19f8HUuWLDFOnDhhxMbGGhEREcaGDRuM3bt3G+71uGXAKB7bNau6ubkZv1u2RN/qmfTqQpp3ad7lHOAc4BzgHDCa6Ppd9y9+/fXXjaNHj+p+JJmZmcb69et1cKK2eXt7G+vWrdPvq34nx44d08OOQ0JCLL7Dy8vLSExMNLKzs42CggJj9erV1fIQoNj3f/gOl4UYz+/aooOU/jfE2Lw8JOqAc4BzgHNAnC5AcStfcShqmLEazePv709/FBu5afIkiXvoATmTli7zRt0jRYUXbFUUAIATXr+ZvhUNsvH1t+RMaroEdAmSG347nloEAFgVAQoapPiCSVa9cHFulGH33yMdLguhJgEAVkOAggZL2fSFfP/lVmnp6SkJM6dRkwAAqyFAQaOsnLtQSoqKpHf0YAmPjaE2AQBWQYCCRjl98pRsfnO5Xh/55yni4V3znDgAANQHAQoabeM/3tKjeVSH2Zhxd1OjAIBGI0CBVTrMfrJwiV6PnXif+LUPoFYBAI1CgAKr+HbdBjm+J0W8WreWGyc/SK0CABqFAAVWs3peol5ek3CrBPa8nJoFADQYAQqs5tiefbInaZO4t2ght077AzULAGgwAhRYleqLUlJcrIcdXzHkGmoXANAgBCiwqtOnUuWrdz/Q6zc/+hC1CwBoEAIUWN3G1/8lpvPnpWu/PtJvWDQ1DACoNwIUWF3B2Rz5cvn7ev3G3z8obm7qodkAANQdAQqaxJa33pEL+QUS3LuXhMdeRy0DAOqFAAVN4vy5XPni3//R67SiAADqiwAFTeaLt9+Twtw8CQrrIQPiY6lpAECdEaCgyajgZMvb7+n1+Ecmips7pxsAoG64YqBJffnv/+jbPYE9QiXipuHUNgCgTghQ0KRUR9nP33xHr8c/PEHPMgsAwKUQoKDJ/e+d9/XQ406h3eTKEfHUOADgkghQ0OTUpG2b31yu12Mn3se8KACASyJAQbNIXrFSCvPydV+UvjG/otYBAL+IAAXN1hclecVHev36CfdR6wCAX0SAgmbzxb9XSLHJJKFXDtAJAIDaEKCg2eRln5adq9fqdVpRAAC/hAAFzerzN5dLWVmZfspx58u7U/sAgBoRoKBZZZ84JSmbvtDr0WNHU/sAgBoRoKDZfbl8hV4OuvVmaeXvzxEAAFRDgIJmd2Tnbkn94aB4tvKWwXfcyhEAAFRDgAKbtqL86u47mf4eAFANAQpsYven6yXv9BlpFxQo4bHXcRQAABYIUGATJUVF8vWHq/T6kNG3cxQAAA0PUB5++GHZs2ePnDt3TqetW7fKTTfdZJFn1qxZkpqaKufV81c2b5a+fftabPf09JTExETJysqS/Px8WbVqlQQHB9enGHAS2z5YrYcchw0eJB0uC7F1cQAAjhqgnDp1SmbMmCGDBg3SadOmTTrAqAhCpk+fLlOnTpXJkydLVFSUZGRkyPr168XX19f8HYsWLZKEhAQZM2aMREdH621r1qwRd3cac1zN2fQM+eF/yXp98B0jbV0cAICdMRqTTp8+bUyYMEGvp6WlGdOnTzdv8/T0NM6ePWtMmjRJv/b39zdMJpMxevRoc56goCCjpKTEiI+Pr/PP9PPzMxS1bGz5Sbatg37Doo35+5KNv2751Gjh4cHx4JzkHOAc4Bxw4nPArx7X7wY3W6gWj7vuukt8fHwkOTlZQkNDJSgoSJKSksx5ioqKZMuWLTJkyBD9OjIyUt/iqZwnPT1dUlJSzHlqoj7j5+dnkeAcvv8yWXIyfxLfgHbSn86yAIBy9Q5QwsPDJS8vT0wmk7z66qv6ds33338vgYGBentmZqZFfvW6Yptaqs/l5OTUmqcmM2fOlNzcXHNSfVzgHMpKS2X7Rx/r9cG/GWXr4gAAHDVAOXDggERERMjgwYPllVdekbfeekv69Olj3m4YqmXmZ25ubtXeq+pSeebOnSv+/v7mRKda57Lto491oBJ2zSBp35XOsgCABgQoxcXFcvjwYdm1a5c88cQTelTPlClTdIdYpWpLSKdOncytKiqPl5eXtG3bttY8NVG3ilSrTeUE55GTkSkHk3fo9UG3Wo4KAwC4pkYPnVGtHyroOHr0qO5PEhcXZ97m4eEhMTExejiyooIaFWxUzqMCGnXbqCIPXNPOj9fqZSQBCgCgXJ173z777LNGdHS00a1bNyM8PNx45pln9Aic4cOH6+1qBI8atTNq1CijX79+xvLly43U1FTD19fX/B1LliwxTpw4YcTGxhoRERHGhg0bjN27dxvu7u5N0guY5Bh14OHtZTz79QY9oif0ygE2Lw+JOuAc4BzgHBCr10F9rt8t6xOmde7cWd5++209WkdN1LZ37149UduGDRv09hdeeEFatWolS5YskXbt2sm2bdskPj5eT8hW4fHHH5eSkhJZsWKFzrtx40a5//779YRdcF3FF0yyN2mzXJ3wa4m87WY5unuvrYsEALAht/JIxaGoYcZqNI/qMEt/FOfRI+oq+d0/F0thbp7Mvv7Xejp8AIDzqM/1m+lbYTeO7NytZ5dt5e8n/a4fauviAABsiAAFdkMNNd+15jO9HnHTcFsXBwBgQwQosCvfrrvYn6nP0GvFq3VrWxcHAGAjBCiwK+kHf5Sfjh4XDy8v6Tss2tbFAQDYCAEK7M63n23Uy4gbY21dFACAjRCgwO7sKQ9QekdfK14+3OYBAFdEgAK7k/HjEck4fFRaenoymgcAXBQBCuzS3qRNejkw7npbFwUAYAMEKLBL+zZu0cte114jLb28bF0cAEAzI0CBXUo7cEhP2ubZylvCro60dXEAAM2MAAV2a//n/9NLZpUFANdDgAK79V15gNI35lfi5qYeGwUAcBUEKLBbP+74Ri4UFEibTh0luM8Vti4OAKAZEaDAbpUWF8uBr7bpdW7zAIBrIUCBY/RDYdp7AHApBCiwaz98uVXKysokuHcv8e/U0dbFAQA0EwIU2LWCnHNyav8Per3X4ChbFwcA0EwIUGD3DiZv18te1xKgAICrIECB3Tv49Q69DKMFBQBcBgEK7N6xb/eJ6Xyh+HdoL4FhPWxdHABAMyBAgUMMNz7yzbd6nds8AOAaCFDgWP1QuM0DAC6BAAUO4VB5P5TLI6+UFh4eti4OAKCJEaDAIaQfPCy52afFq3Ur6T4w3NbFAQA0MQIUOFwrSq9rr7Z1UQAATYwABQ7j0Ladetkj6ipbFwUA0MQIUOAwjuzao5ddw/tISy8vWxcHANCECFDgME6fPCXnfsqSlh4ecln/vrYuDgCgCRGgwKEc/eZiK8rlVw20dVEAAE2IAAUO5QgBCgC4BAIUOJQjuy7OKNstor+4t2hh6+IAAJoIAQocSsaPR6QwN0+8fXykyxU9bV0cAIA9BCgzZsyQ7du3S25urmRmZsrKlSulV69eFnmWLVsmhmFYpOTkZIs8np6ekpiYKFlZWZKfny+rVq2S4OBg6+wRnJpRViZHv92r10OvirB1cQAA9hCgxMTEyOLFi2Xw4MESFxcnLVu2lKSkJGndurVFvrVr10pgYKA5jRgxwmL7okWLJCEhQcaMGSPR0dHi6+sra9asEXd3GnRwaUe/uRigMKMsADg3o6GpQ4cOhjJ06FDze8uWLTNWrlxZ62f8/f0Nk8lkjB492vxeUFCQUVJSYsTHx9fp5/r5+emfq5aNKT/JMeug59WRxvx9ycb/rfvI5mUhUQecA5wDnANS5zqoz/W7UU0Wbdq00cszZ85YvD9s2DB9C+jAgQOydOlS6dixo3lbZGSkvsWjWl4qpKenS0pKigwZMqTGn6Py+/n5WSS4rpP7v5eysjIJCA4S3/btbF0cAEATaFSAsmDBAvnyyy9l//79Frd3xo4dK7GxsTJt2jSJioqSTZs26SBDUbd8TCaT5OTkWHyXCmjUtprMnDlT93upSKmpqY0pNhycqeC8/HTkmF6/LLyfrYsDALCnAOXll1+WAQMGyN13323x/ooVK+TTTz/VQYvqV3LzzTfrjrS33HLLL36fm5ub7lBbk7lz54q/v7850aEWJ/Z9pyvhsgHMKAsAzqhBAYoagXPbbbfJ9ddff8nWjIyMDDl+/LiEhYWZX3t5eUnbtm0t8nXq1Em3otSkqKhI8vLyLBJc2/F9F1vtuvWnBQUAnFG9A5SXXnpJbr/9dn0L59ixi83svyQgIEC6du2q+5kou3bt0gGHGgVUQd3aCQ8Pl61bt9a3OHBRJ/ZeDFC69uujW98AAM6nzr1vFy9ebJw9e9a47rrrjM6dO5uTt7e33u7j42PMmzfPGDx4sNGtWzcjJibG+Oqrr4yTJ08avr6+5u9ZsmSJceLECSM2NtaIiIgwNmzYYOzevdtwd3e3ei9gknPWgXuLFsbc7Zv1aJ5Ood1sXh4SdcA5wDnAOSDWvn7X/aSqzfjx4/V2FaisW7fOyMzM1EOJjx07pocdh4SEWHyPl5eXkZiYaGRnZxsFBQXG6tWrq+Wx4g6SnLQOfv/mKzpAGXTbCJuXhUQdcA5wDnAOiFWv327lKw5FDTNWo3lUh1n6o7iuW6f9QYbdf4989d6H8tGzL9q6OAAAK16/mboVDutEysWRPCF9e9u6KAAAKyNAgcNK/f6AXnbp1ZMnGwOAkyFAgcM6fTJVLuQXiIe3l3QK7Wbr4gAArIgABQ5LTeyXeuCgXg/ubflUbQCAYyNAgUNL/b48QOl7ha2LAgCwIgIUOEeAQgsKADgVAhQ4tNQfDpgDFGaUBQDnQYACh5Z55JgUm0zSys9XAoK72Lo4AAArIUCBQysrKZX0Q4f1enAfOsoCgLMgQIHDS/2hvB9KHzrKAoCzIECBw0v9riJAoQUFAJwFAQqcpqNsCC0oAOA0CFDg8FQflLLSUvFrHyD+HTvYujgAACsgQIHDK75g0qN5FOZDAQDnQIAC5+ooy4yyAOAUCFDgFNJ+OGR+sjEAwPERoMAppB38US8JUADAORCgwCmkHbjYgtL+shDxbOVt6+IAABqJAAVOoeBsjuRmZYu7u7sEhvWwdXEAAI1EgAKnkXaA2zwA4CwIUOA00g6Wd5S9IszWRQEANBIBCpxGOh1lAcBpEKDA6W7xBDHUGAAcHgEKnMZPx45LSVGRePv6SEBwkK2LAwBoBAIUOI2yklLJPHxxynv6oQCAYyNAgVN2lOU2DwA4NgIUOBWGGgOAcyBAgXOO5GGoMQA4NAIUOOWU9x30lPetbF0cAEADEaDAqRTknJNzmVl6PagXU94DgKMiQIHTSTtUMeU9M8oCgKMiQIHTSS+/zUMLCgC4SIAyY8YM2b59u+Tm5kpmZqasXLlSevXqVS3frFmzJDU1Vc6fPy+bN2+Wvn37Wmz39PSUxMREycrKkvz8fFm1apUEBwc3fm+AyiN56CgLAK4RoMTExMjixYtl8ODBEhcXJy1btpSkpCRp3bq1Oc/06dNl6tSpMnnyZImKipKMjAxZv369+Pr6mvMsWrRIEhISZMyYMRIdHa23rVmzRtzdadBB46WVj+RRLShubm5UKQA4KKOhqUOHDoYydOhQ83tpaWnG9OnTza89PT2Ns2fPGpMmTdKv/f39DZPJZIwePdqcJygoyCgpKTHi4+Pr9HP9/Pz0z1XLxpSf5Jx14N6ihfHczs+N+fuSjYCQLjYvD4k64BzgHOAckHpfvxvVZNGmTRu9PHPmjF6GhoZKUFCQblWpUFRUJFu2bJEhQ4bo15GRkfoWT+U86enpkpKSYs5Tlcrv5+dnkYDalJWWSsbho3qdjrIA4JgaFaAsWLBAvvzyS9m/f79+HRgYqJeqf0pl6nXFNrU0mUySk5NTa56qZs6cqfu9VCTVvwWo24RtPakoAHClAOXll1+WAQMGyN13311tm2Go1pmfqX4AVd+r6pfyzJ07V/z9/c2JDrWoa0dZnskDAC4UoKgROLfddptcf/31Fq0ZqkOsUrUlpFOnTuZWFZXHy8tL2rZtW2ueqtRtory8PIsE1GVGWVpQAMBFApSXXnpJbr/9domNjZVjxy4+2r7C0aNHdX8SNcKngoeHhx79s3XrVv16165dOuConEcFNOHh4eY8gLVu8XToGiJePj+PMgMAOI469y5evHixHpFz3XXXGZ07dzYnb29vcx41gkflGTVqlNGvXz9j+fLlRmpqquHr62vOs2TJEuPEiRNGbGysERERYWzYsMHYvXu34e7uXqdyMIqHHvF1OU+e2rBKj+TpPrA/Peg5ZzgHOAc4B8T2dVDP63fdv7g248ePt8g3a9YsPdy4sLDQ+Pzzz3WgUnm7l5eXkZiYaGRnZxsFBQXG6tWrjZCQkKbaQZKL1sHExS/qAOXa0Qk2LwuJOuAc4BzgHJB6Xb/dylccihpmrEbzqA6z9EdBbUZMeURu+O042bpipXz49AtUFAA40PWbqVvh/B1lezHUGAAcDQEKnD5AYcp7AHA8BChwWtknTkmxySRerVtLQAgPowQAR0KAAuee8v7HI3q9S68eti4OAKAeCFDgEjPKdrkizNZFAQDUAwEKnBrP5AEAx0SAAhfpKMtIHgBwJAQocGppBw/rZfuQYKa8BwAHQoACp1aYmys5GRcfQsl8KADgOAhQ4DIdZbnNAwCOgwAFrjOjLCN5AMBhEKDAdUby0FEWABwGAQqcXlp5gBIY1kPc3NTzMQEA9o4ABU4v6/hJKb6gprxvJe27MuU9ADgCAhQ4PaOsTNJ/vDjcmH4oAOAYCFDgEtIZyQMADoUABS4h7WDFSB5mlAUAR0CAAtd6aGAvHhoIAI6AAAUuIf3QxT4oAcFB4u3rY+viAAAugQAFLqEwN0/OpKXrdWaUBQD7R4ACl+soy0geALB/BChwGWmHKp7J08PWRQEAXAIBClwGHWUBwHEQoMDlHhoYpKa8d+fUBwB7xm9puIzTJ1OlqPCCeLbyZsp7ALBzBChwrSnvy4cb01EWAOwbAQpcCjPKAoBjIECBS0k/WN6CEsaU9wBgzwhQ4JodZXkmDwDYNQIUuJT0gxfnQgnoEiSt/P1sXRwAQC0IUOBSLuQXyJnU8invw5iwDQDsFQEKXA4dZQHA/hGgwOUwoywAOGGAMnToUFm9erWkpqaKYRgycuRIi+3Lli3T71dOycnJFnk8PT0lMTFRsrKyJD8/X1atWiXBwcGN3xugHv1QeKoxADhRgOLj4yN79uyRyZMn15pn7dq1EhgYaE4jRoyw2L5o0SJJSEiQMWPGSHR0tPj6+sqaNWvEnenH0YwjeQJ7Xs6U9wBgp1rW9wPr1q3T6ZeYTCbJzMyscZu/v79MnDhR7rvvPtm4caN+795775WTJ0/K8OHDJSkpqb5FAuo95b3pfKF4tW4lHbt1lZ+OHqcGAcAV+qAMGzZMBygHDhyQpUuXSseOHc3bIiMj9S2eyoFIenq6pKSkyJAhQ2r8PpXfz8/PIgENpW47ZpRPec9tHgBwkQBF3d4ZO3asxMbGyrRp0yQqKko2bdqkgwxF3fJRLSw5OTkWn1MBjdpWk5kzZ0pubq45qf4vQGOklfdD6dKLGWUBwClu8VzKihUrzOv79++XnTt3yvHjx+WWW26RlStX1vo5Nzc3/ZdtTebOnSsLFiwwv1YtKAQpaIzUHw7qZXCfXlQkALjiMOOMjAwdoISFhZlfe3l5Sdu2bS3yderUqdZ+K0VFRZKXl2eRgMY4tf8HvQzp25uKBABXDFACAgKka9euup+JsmvXLh1wxMXFmfOoWzvh4eGydevWpi4OYL7FU1JcLH7tA6RtYGdqBQAc/RaPGmbcs+fP9+1DQ0Nl4MCBcubMGZ1mz54tH374oQ5IunfvLnPmzJHs7Gzz7R3Vh+SNN96Q+fPny+nTp/VnXnzxRdm3b59s2LDBunsH1KK0uFjSDx2Wrn17S9d+vSUno+bWOwCA7Rj1STExMUZNli1bZnh7exvr1q0zMjMzDZPJZBw7dky/HxISYvEdXl5eRmJiopGdnW0UFBQYq1evrpbnl5Kfn5/+mWpZ3/KTqIOKc+DOv/zZmL8v2bj50Yc5LzgvOAc4BzgHpOnroD7Xb7fyFYeiOsmqlhg1pwr9UdBQ19xxm4yePVMOJm+X1yZNoSIBwI6u3zyLBy7L3FG2Hx1lAcDeEKDAZWX8eESKTSZp7e8v7UN4FhQA2BMCFLis0pIS85ONVUdZAID9IECBSzv1XcVtnj62LgoAoBICFLi0k/u/10taUADAvhCgwKWdrDSjrHrcAgDAPhCgwKX9dOSYFBVeEG9fH+nY/TJbFwcAUI4ABS6trLTU3A/lsv79bF0cAEA5AhS4vGN79uk66B7R3+XrAgDsBQEKXN7xPSm6DghQAMB+EKDA5VW0oHTuESrefr4uXx8AYA8IUODy8k+flewTp8Td3V260Q8FAOwCAQogIse+pR8KANgTAhRABSh0lAUAu0KAAlRqQVFDjd3c+W8BALbGb2Kg/MnGFwoK9IRtgT1DqRMAsDECFEBEjLIyObF3v66L7gMHUCcAYGMEKECV2zyXRw6kTgDAxghQgHJHdn2rlz0GXUWdAICNEaAA5Y5+u09KioqkTeeO0qFbV+oFAGyIAAUoV2IyyfHyfig9r46kXgDAhghQgEp+3L5LL3tGcZsHAGyJAAWo5Mcd3+hlDwIUALApAhSgypONiy+YxL9De+l8eXfqBgBshAAFqKS0uNg83Jh+KABgOwQoQBUHv96ul72jr6VuAMBGCFCAKr7b8pVehl0zSDy8vagfALABAhSghufynElN18FJ2DVR1A8A2AABClCD77b8Ty/7xvyK+gEAGyBAAX7hNk/f6whQAMAWCFCAWuZDMZ0/r6e9D+7TizoCgGZGgALUMtz4wFfb9PqAuFjqCADsPUAZOnSorF69WlJTU8UwDBk5cmS1PLNmzdLbz58/L5s3b5a+fftabPf09JTExETJysqS/Px8WbVqlQQHBzduTwAr2/PZRr288ubh1C0A2HuA4uPjI3v27JHJkyfXuH369OkydepUvT0qKkoyMjJk/fr14uvra86zaNEiSUhIkDFjxkh0dLTetmbNGnF3p0EH9mP/lv/p2zztQ4LlsgH9bF0cAHA5RkOTMnLkSIv30tLSjOnTp5tfe3p6GmfPnjUmTZqkX/v7+xsmk8kYPXq0OU9QUJBRUlJixMfH1+nn+vn56Z+tlo0pP4k6uNQ5cPecvxjz9yUbo2Y8zvnC+cI5wDnAOSCNq4P6XL+t2mQRGhoqQUFBkpSUZH6vqKhItmzZIkOGDNGvIyMj9S2eynnS09MlJSXFnKcqld/Pz88iAc1h99r1ehlx03Bxb9GCSgeAZmLVACUwMFAvMzMzLd5Xryu2qaXJZJKcnJxa81Q1c+ZMyc3NNSfVvwVoDgeTt0vB2Rzxax8gva5l0jYAaC5N0ulDdZ6tzM3Nrdp7Vf1Snrlz54q/v7850aEWzaWspFR2rflMr187OoGKBwBHDFBUh1ilaktIp06dzK0qKo+Xl5e0bdu21jxVqdtEeXl5FgloLltXfGSetK1dUM2tfAAAOw5Qjh49qvuTxMXFmd/z8PCQmJgY2bp1q369a9cuHXBUzqMCmvDwcHMewJ5kHTshB7/eofugDP7NKFsXBwBcQoOGGQ8cOFCnio6xar1r167mIcRPPPGEjBo1Svr16ydvvvmmng/lnXfe0dtVH5I33nhD5s+fL7GxsRIRESH//ve/Zd++fbJhwwZr7x9gFVvf+1Avr7n9Vmnh4UGtAkAzqNcQoZiYGKMmy5YtM+eZNWuWHm5cWFhofP7550a/fv0svsPLy8tITEw0srOzjYKCAmP16tVGSEhIkwxTIlEH1jgH3Fu0MJ7asEoPOb464VbOK84rzgHOAc4Badphxm7lKw5FDTNWLTGqwyz9UdBcrhs3Rkb+aYpkHT8pL4y8W8pKS6l8AGii6zdTtwJ19PX7q/SQ447dusrAeJ7PAwBNiQAFqKOiwkL5YvkKvX7Dg+P10HgAQNMgQAHq4X/vvC+FefkSFNZDBt54A3UHAE2EAAWohwt5+fL5m8v1+s2PPiQtWrak/gCgCRCgAPX0xdvvSW72aenQNYR5UQCgiRCgAPVUVHhBkpa8odfjH54grfx5eCUAWBsBCtAA21aulozDR8U3oJ3c/IeHqEMAsDICFKCBDxH86Jl55ocIdu3Xh3oEACsiQAEa6PDO3bLz47Xi7u4uv5k1gynwAcCKCFCARvh4/ktSkHNOgvv0klsee4S6BAArIUABGiH/9Fl578ln9HrMuLulz9Ah1CcAWAEBCtBI3235n3zx9n/0+phnnhT/Th2pUwBoJAIUwArWLFwsp747oEf1jJ07S9xbtKBeAaARCFAAKygtLpa3pz8lpvPnpefVkXLnX/5MvQJAIxCgAFaSffykLJ8xW8pKS+Wa22+VXz/+e+oWABqIAAWwov2bv5QVs+fq9esn3KufegwAqD8CFMDKdvz3E/n4xZf0+ohHH5Zb//gHcXNzo54BoB4IUIAm8Plb78jqeYl6fdj4e2Tsc7PFw9uLugaAOiJAAZrIln+9K8tnzpbS4hK5ckS8PLr8delwWQj1DQB1QIACNKFv1nwmrz74B8nNPi1devWUx95bJoPvHMktHwC4BAIUoIkd2fWtLBx9v1628vPVz+35/VuvSmDPy6l7AKiF6rlniIPx8/OT3Nxc8ff3l7y8PFsXB6gTNXnbr8bcITf9YZJ4+/joWz/bV62Rjf94S86mZVCLAJyeXz2u3wQoQDNr07mjjJoxVQYMH6Zfq0BFPRX5q/c+kNTvD3I8ADgtAhTAAXSPGCDxj0yQK4ZcY37v5Hc/yI6Va2Tv+s2Sd/qMTcsHANZGgAI4kO4D+0v0PXdK/+HDpKWnp36vrKxMju7eIz98mSwHk7frlhXDcLi7sQBggQAFcEA+bdvIVb++Sa68abh0Gxhusa3gbI4c/XavHN+TIse+3acDFvXcHwBwJAQogINrG9hZ+l0/VHpdGyU9oyLF29enWp7sk6ck/eBhyTxyTM6kpsnZtHQ5fSpdctIzpLSkxCblBoBfQoACOBH3li2ka78+ulVF3Q7qNqCfDmBqo24P5f6UJadV0JKaIbnZ2bo/S172Gck/fUbPyVKQkyOFufn6KcwA0FwIUAAn17qNvwT16ildrgjTs9MGBAdJQHAXCegSJJ6tvOv8PcUXTFKYlycX8gukMDdPCvPy9Wu1NBWcl2KTSecpKSrS6yXqtano4usLJikuKip/r/x9U0W+IikuMklZSWmT1gNgT/Qzt9zcypfqX+XXalmeRyrlUe+7qynJKm2/1HdIlTwV28zv68Jc/O7y54BZfm95nvJy/Pw97hbb1f/p9IM/WrWOCFAAF+bbvp05WFGBi1/79uLXvp34tg8Qv/LUqo2/uOtfik1L3WpSwUxJUbGUlZaKUVYmRpkhZWVq3RDDKH+tthmG3q5agCryqe1lpWU/56v4nNqu8peWSln554zyfGUV2yuSYVh+R6Wf1Vwdj+v9sEiLi8rFC1LVC9fP31vlYleeT28v/y6Lz1Yuj/kiVr5e5cL3cxkubqt84fv5Z1S5SFa+sF7yol31+8vfcy+/UFb5fM0X29ov+jWWrfzzdSnXzxfsqkHAz6+b4/+Rrfx09Lg8f9sYmwUoLa36kwHYXP7pszqd2Lu/1jzql6uXT2tp5ecn3n6+0srfT1r5+oi3n5+e7Va9VpPJtfTyFA8vL/Hw8tQjjNQDD/VSv+elt198ffG9ivwVWrRsqZNX62baecAJlZWViaigWgXU+p9h+bpiXSzf03kufuDnoL48KDdvr+E7K947l5ll0/0mQAFckPoFpG7rqCTp1v1uFfy0KA9aKoIaD0/Pi3/Ruqu/OFvopf7rtPJrvd1d/0V68fXP+dwrvzbncf95ewv3mr9P/YXbooX5L+rK+Sq+p0GtKA34TH0/YS5XlQuR+SJk3vZznoqLkfnzFheun8tt1Haxq/qd5m1VvrPy+1UveNV+dqXvNF9AL76u/H2qdaum8v5c1ovf/XN9XPyMUeU7Lb+jhgu4bpmzLJ/Fz6i275XKVqVcFgFApe+svq/Vj19dvgcXT2erpVmzZhlVpaenV8uTmppqnD9/3ti8ebPRt2/fev0MPz8//b1qae3yk6gDzgHOAc4BzgHOAWmSOqjP9btJbp6lpKRIYGCgOfXv39+8bfr06TJ16lSZPHmyREVFSUZGhqxfv158fX0JFgEAgNYkAUpJSYlkZmaaU3Z2tnnbY489Js8++6ysXLlS9u/fL+PHj5fWrVvLPffc0xRFAQAADqhJApSwsDBJTU2VI0eOyLvvviuhoaH6fbUMCgqSpKQkc96ioiLZsmWLDBkypNbv8/T01D1/KycAAOC8rB6gbNu2TcaNGyc33nijPPjgg/oWz9atWyUgIECvK6pVpTL1umJbTWbOnKmHJVUkFfwAAADnZfUAZd26dfLRRx/pfigbN26UW265Rb+vbuVUqNo7WfWw/6Uey3PnztVjpitScHCwtYsNAADsSJPPMHP+/HnZt2+fvu2jOsQqVVtLOnXqVK1VpTJ1G0hN6FI5AQAA59XkAYrqP9KnTx9JT0+Xo0eP6mVcXJx5u4eHh8TExOjbQAAAAE0yUdu8efPk448/lhMnTuiWkSeffFLflnnrrbf09kWLFskTTzwhhw4d0kmtq1aWd955hyMCAACaJkAJCQnRI3c6dOggWVlZ8vXXX8vgwYN1wKK88MIL0qpVK1myZIm0a9dOd6qNj4+X/Px8axcFAAA4KPXkJIebT7c+DxsCAACOd/123scwAgAAh0WAAgAA7A4BCgAAcP5Oss2JKe8BAHDO63ZLR95BprwHAMAxr+OX6iTrkKN4lC5dujTJCB5VaSrwUdPpO/sIIVfaV1fbX/bVeXFsnZOrHde0tDTnbEFR6rJzjeFKU+q70r662v6yr86LY+ucXOG45tVx/+gkCwAA7A4BCgAAsDsEKFWYTCaZPXu2Xjo7V9pXV9tf9tV5cWydkysd17py2E6yAADAedGCAgAA7A4BCgAAsDsEKAAAwO4QoAAAALtDgFLJI488IkeOHJHCwkLZuXOnREdHi6OZMWOGbN++XXJzcyUzM1NWrlwpvXr1ssizbNkyMQzDIiUnJ1vk8fT0lMTERMnKypL8/HxZtWqVnuHQnsyaNavafqSnp1fLo2ZnPH/+vGzevFn69u3rcPtZ4ejRo9X2V6WXX37Z4Y/r0KFDZfXq1fpYqXKPHDmyWh5rHMu2bdvKv/71L8nJydFJrbdp00bsZV9btmwpzz33nOzdu1fvg8rz1ltvSVBQkMV3qP2veqzfffddu9vXuhxba5239rC/l9rXmv7/qvTHP/7RIY9tc1CjeFw+jR492jCZTMbEiRON3r17GwsXLjTy8vKMrl27OlTdrF271hg/frzRt29fY8CAAcbHH39sHDt2zGjdurU5z7Jly4xPP/3U6Ny5szm1a9fO4nuWLFlinDx50rjhhhuMiIgIY+PGjcbu3bsNd3d3m+9jRZo1a5axb98+i/3o0KGDefv06dONc+fOGQkJCUa/fv2Md99910hNTTV8fX0daj8rktq3yvuqyqzExMQ4/HG96aabjKefflofK2XkyJEW2611LFX97N271xg8eLBOan316tV2s6/+/v5GUlKS8Zvf/Mbo1auXcc011xjJycnGjh07LL5j8+bNxmuvvWZxrNVnK+exh32ty7G11nlrD/t7qX2tvI8q3X///UZpaakRGhrqkMdWmj7ZvAB2kb7++mv9n6Dye999950xZ84cm5etsRc1ZejQoRa/EFauXFnrZ9R/BhWsqaCt4r2goCCjpKTEiI+Pt/k+VQ5Q1C+p2ranpaXpC1vFa09PT+Ps2bPGpEmTHGo/a0sqiD506JDTHdeafrFb41iqPzyUq6++2pxHBQCKCgbsZV+rpkGDBul8lf9YUhcxdfxr+4w97mtt+2uN89ZRj63a7w0bNli856jHVpogcYtHRDw8PCQyMlKSkpIsmpbU6yFDhogjq2j2O3PmjMX7w4YN07eADhw4IEuXLpWOHTuat6m6UE2qletD3TpJSUmxu/oICwvTzanq1pxqBg0NDdXvq6VqFq+8D0VFRbJlyxbzPjjSftZ0zt57773yz3/+0ymPa2XWOpbXXnutbg5Xt0ArbNu2Tb9nz/uv/g+XlZXpclY2duxYfctD7eO8efPE19fXvM3R9rWx562j7a/SqVMnueWWW+SNN96ots2Zjm1jOOzDAq2pQ4cO+t6v+g9SmXodGBgojmzBggXy5Zdfyv79+83vrV27Vt5//305fvy4/uX/9NNPy6ZNm/QvAvWLX+2zms2w6i9Ee6sP9Z9y3LhxcvDgQencubM8+eSTsnXrVunXr5+5nDUd027duul1R9nPmowaNUrfh37zzTed7rhWZa1jqZY//fRTte9X79nr/nt5eek+Ke+8847FA9aWL1+u+yRlZGRIeHi4zJ07VwYOHCjx8fEOt6/WOG8daX8rjB8/Xh/Tjz76yOJ9Zzq2jUWAUonqjFSZm5tbtfccieo8OWDAgGqdfVesWGFeV4GL6hCsfjmoaF51qq2NvdXHunXrzOvqLw3Vse7w4cP6P/7XX3/d4GNqb/tZk4kTJ+pf7JU7BTvLca2NNY5lTfntdf/VH03vvfeeuLu7y+9+9zuLba+//rrFsT506JDs2rVLrrzyStm9e7dD7au1zltH2d8KEyZM0MFI1antnenYNha3eEQkOztbSkpKqkWfqgmu6l9tjkL1eL/tttvk+uuv17dAfomK1NUvBHW7pOK1+stN/YXuSPWhRnfs27dP74faB+WXjqmj7udll10mw4cPt/hF5szH1VrHUuVRLW1VqdsJ9rb/KjhRF27VohAXF3fJx9N/8803urWh8rF2lH21xnnraPur/mjs3bv3Jf8PO9uxrS8CFBEpLi7WEar6RVCZeq1uGTial156SW6//XaJjY2VY8eOXTJ/QECAdO3a1fzXuKoL9R+icn2oi4NqbrTn+lD3qfv06aP3QzWRqmXlfVD9NmJiYsz74Kj7+cADD+jm3E8++cQljqu1jqVqYVMXuaioKHOeq6++Wr9nT/tfEZyoC5IKRKv2H6uJuq2pzv+KY+0o+2qt89bR9le1gKqWIjWc3JWObUPYvKeuPQ0zfuCBB3Qv6QULFuhhxpdddpnNy1aftHjxYj264brrrrMYpubt7a23+/j4GPPmzdND07p166aHqH711Vd6CF/VIZsnTpwwYmNj9bA+1dPcHoajVk5qP9R+du/eXfdoV8Ps1FDUimOmRn2ouhg1apQemrp8+fIah6ba+35WTm5ubnrY+Ny5cy3ed/Tjqso/cOBAnZTHHntMr1eMXLHWsVTDM7/99ls96kGlPXv2NPvwzF/a1xYtWhj//e9/9X6oaQIq/x/28PDQn7/88suNp556yoiMjNTH+uabb9YjDnft2mV3+3qp/bXmeWsP+3up81glPz8/Iz8/33jooYeqfd7Rjq00fbJ5AewmPfLII8bRo0eNCxcuGDt37rQYmusoqTZqbhS1XQUq69atMzIzM3VApi52aphfSEiIxfd4eXkZiYmJRnZ2tlFQUKBP/qp5bJ0q5sJQ+3Hq1Cnjgw8+MPr06VNtKLIaolpYWGh8/vnn+uLmaPtZOcXFxenjGRYWZvG+ox9XdWGqidoHax5LNb/G22+/rQNZldR6mzZt7GZf1UWpNhXz3ah9Uvuv9lP9rlJDzRctWlRt7hB72NdL7a81z1t72N+6nMcPPvig3oeqc5s44rGVJk5u5SsAAAB2gz4oAADA7hCgAAAAu0OAAgAA7A4BCgAAsDsEKAAAwO4QoAAAALtDgAIAAOwOAQoAALA7BCgAAMDuEKAAAAC7Q4ACAADsDgEKAAAQe/P/AO4rzlthcv8AAAAASUVORK5CYII=" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 15 + }, + { + "cell_type": "code", + "id": "859eb2c5-bdc0-4376-a250-05f1ff68e491", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:44.162964Z", + "start_time": "2026-02-02T02:00:44.142886Z" + } + }, + "source": [ + "true_params[\"beta\"]" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.92518728, 0.27270752, -0.20081106])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 16 + }, + { + "cell_type": "code", + "id": "6ab67d23-5bb2-4f1e-b69c-2b955823f035", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:44.505502Z", + "start_time": "2026-02-02T02:00:44.489441Z" + } + }, + "source": [ + "f_loss_dloss(np.array(500), **true_loc_dict)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "[array(315.55632031),\n", + " array(-230.91497151),\n", + " array(302.6189981),\n", + " array([-3.98004556, 6.85857563, -9.27823392]),\n", + " array([43.40893142, 64.50616261, 59.40050354]),\n", + " array(18.76786789),\n", + " array(59.22616729)]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 17 + }, + { + "cell_type": "code", + "id": "ab2a24a2-d4d2-4791-b1e8-549f03ecc2af", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:45.124363Z", + "start_time": "2026-02-02T02:00:45.107974Z" + } + }, + "source": [ + "f_loss_dloss(np.array(500, dtype=int), *opt_param_values)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "[array(57.63612282),\n", + " array(-0.29440051),\n", + " array(0.21724758),\n", + " array([ 1.01318409, -0.01027799, 1.15720638]),\n", + " array([0.39221349, 0.37674832, 0.41073372]),\n", + " array(0.27247629),\n", + " array(0.38035044)]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 18 + }, + { + "cell_type": "code", + "id": "f9e889a5-b509-4208-9648-ded17cf1e5d7", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:47.298869Z", + "start_time": "2026-02-02T02:00:47.284083Z" + } + }, + "source": [ + "def inverse_softplus(x):\n", + " return np.log(np.expm1(x))\n", + "\n", + "\n", + "mcmc_param_values = {}\n", + "mcmc_param_values[\"sigma_loc\"] = np.log(idata.posterior[\"sigma\"]).mean((\"chain\", \"draw\")).values\n", + "mcmc_param_values[\"sigma_scale\"] = inverse_softplus(\n", + " np.log(idata.posterior[\"sigma\"]).std((\"chain\", \"draw\"))\n", + ").values\n", + "\n", + "for param in (\"beta\", \"alpha\"):\n", + " mcmc_param_values[f\"{param}_loc\"] = idata.posterior[param].mean((\"chain\", \"draw\")).values\n", + " mcmc_param_values[f\"{param}_scale\"] = inverse_softplus(\n", + " idata.posterior[param].std((\"chain\", \"draw\"))\n", + " ).values" + ], + "outputs": [], + "execution_count": 19 + }, + { + "cell_type": "code", + "id": "9eb47148-f517-4a00-924d-4a8269b9587c", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:48.036340Z", + "start_time": "2026-02-02T02:00:48.014115Z" + } + }, + "source": [ + "optimized_params" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma_loc': array(-1.08338307),\n", + " 'sigma_scale': array(-3.40979122),\n", + " 'beta_loc': array([-0.90914275, 0.27361567, -0.15888119]),\n", + " 'beta_scale': array([-3.82559357, -3.84190526, -3.72892214]),\n", + " 'alpha_loc': array(11.78181088),\n", + " 'alpha_scale': array(-3.84125139)}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 20 + }, + { + "cell_type": "code", + "id": "8986aa65-5837-4d30-abf7-5aa70776d438", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:50.540383Z", + "start_time": "2026-02-02T02:00:50.522264Z" + } + }, + "source": [ + "mcmc_param_values" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma_loc': array(-1.06794148),\n", + " 'sigma_scale': array(-2.57880694),\n", + " 'beta_loc': array([-0.90863733, 0.27245212, -0.15930082]),\n", + " 'beta_scale': array([-3.29214417, -3.33548593, -3.19476947]),\n", + " 'alpha_loc': array(11.78181598),\n", + " 'alpha_scale': array(-3.31635547)}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 21 + }, + { + "cell_type": "code", + "id": "653c599e-8db3-4aa7-939c-1ef08bf7cf74", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-02T02:00:51.912735Z", + "start_time": "2026-02-02T02:00:51.896347Z" + } + }, + "source": [ + "f_loss_dloss(np.array(500, dtype=int), **mcmc_param_values)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "[array(56.62881272),\n", + " array(-1.36239183),\n", + " array(1.0238659),\n", + " array([ 1.15112757, -1.83769092, -1.2189314 ]),\n", + " array([1.16681203, 1.04059241, 1.2233972 ]),\n", + " array(2.14756573),\n", + " array(1.07150661)]" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 22 + }, + { + "cell_type": "markdown", + "id": "0cf321e6", + "metadata": {}, + "source": [ + "## Todo:\n", + "\n", + "- Does this \"two models\" frameworks fits into what we already have?\n", + "- rsample --> stochastic gradients? Or automatic reparameterization?\n", + "- figure out guide param initalization\n", + "- More flexible optimizers..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77786d86", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-dev", + "language": "python", + "name": "pymc-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/inference/advi/__init__.py b/pymc_extras/inference/advi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_extras/inference/advi/autoguide.py b/pymc_extras/inference/advi/autoguide.py new file mode 100644 index 000000000..38e98a373 --- /dev/null +++ b/pymc_extras/inference/advi/autoguide.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass, field + +import numpy as np +import pytensor.tensor as pt + +from pymc.distributions import Normal +from pymc.logprob.basic import conditional_logp +from pymc.model.core import Deterministic, Model +from pytensor import graph_replace +from pytensor.gradient import disconnected_grad +from pytensor.graph.basic import Variable + +from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes + + +@dataclass(frozen=True) +class AutoGuideModel: + model: Model + params_init_values: dict[Variable, np.ndarray] + name_to_param: dict[str, Variable] = field(init=False) + + def __post_init__(self): + object.__setattr__( + self, + "name_to_param", + {x.name: x for x in self.params_init_values.keys()}, + ) + + @property + def params(self) -> tuple[Variable, ...]: + return tuple(self.params_init_values.keys()) + + def __getitem__(self, name: str) -> Variable: + return self.name_to_param[name] + + def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable: + """Returns a graph representing the logp of the guide model, evaluated under draws from its random variables.""" + logp_terms = conditional_logp( + {rv: rv for rv in self.model.deterministics}, + warn_rvs=False, + ) + logq = pt.sum([logp_term.sum() for logp_term in logp_terms.values()]) + + if stick_the_landing: + # Detach variational parameters from the gradient computation of logq + repl = {p: disconnected_grad(p) for p in self.params} + logq = graph_replace(logq, repl) + + return logq + + +def AutoDiagonalNormal(model: Model) -> AutoGuideModel: + """ + Create a guide model for ADVI with a mean-field normal distribution. + + A guide model is a variational distribution that approximates the posterior distribution of the model's free + random variables. In this case, we use a mean-field normal distribution, which assumes that the free random + variables are independent and normally distributed. For details, see _[1]. + + For each free random variable in the model, we create a corresponding random variable in the guide model with a + normal distribution. The mean and standard deviation of each normal distribution are parameterized by learnable + parameters (loc and scale), which are initialized to small random values. + + Parameters + ---------- + model : Model + The probabilistic model for which to create the guide. + + Returns + ------- + guide_model : AutoGuideModel + An AutoGuideModel containing the guide model and the initial values for its parameters. + + References + ---------- + .. [1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic Differentiation + Variational Inference. Journal of Machine Learning Research, 18(14):1–45, 2017. + """ + coords = model.coords + free_rvs = model.free_RVs + + free_rv_shapes = dict(zip(free_rvs, get_symbolic_rv_shapes(free_rvs))) + params_init_values = {} + + with Model(coords=coords) as guide_model: + for rv in free_rvs: + loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) + scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape) + # TODO: Make these customizable + params_init_values[loc] = pt.random.uniform(-1, 1, size=free_rv_shapes[rv]).eval() + params_init_values[scale] = pt.full(free_rv_shapes[rv], 0.1).eval() + + z = Normal( + f"{rv.name}_z", + mu=0, + sigma=1, + shape=free_rv_shapes[rv], + ) + Deterministic( + rv.name, + loc + pt.softplus(scale) * z, + dims=model.named_vars_to_dims.get(rv.name, None), + ) + + return AutoGuideModel(guide_model, params_init_values) diff --git a/pymc_extras/inference/advi/objective.py b/pymc_extras/inference/advi/objective.py new file mode 100644 index 000000000..9076108f4 --- /dev/null +++ b/pymc_extras/inference/advi/objective.py @@ -0,0 +1,59 @@ +from pymc import Model +from pytensor import graph_replace +from pytensor.tensor import TensorVariable + +from pymc_extras.inference.advi.autoguide import AutoGuideModel + + +def get_logp_logq(model: Model, guide: AutoGuideModel, stick_the_landing: bool = True): + """ + Compute the log probability of the model and the guide. + + Parameters + ---------- + model : Model + The probabilistic model. + guide : AutoGuideModel + The variational guide. + stick_the_landing : bool, optional + Whether to use the stick-the-landing (STL) gradient estimator, by default True. + The STL estimator has lower gradient variance by removing the score function term + from the gradient. When True, gradients are stopped from flowing through logq. + + Returns + ------- + logp : TensorVariable + Log probability of the model. + logq : TensorVariable + Log probability of the guide. + """ + + inputs_to_guide_rvs = { + model_value_var: guide.model[rv.name] + for rv, model_value_var in model.rvs_to_values.items() + if rv not in model.observed_RVs + } + + logp = graph_replace(model.logp(), inputs_to_guide_rvs) + logq = guide.stochastic_logq(stick_the_landing=stick_the_landing) + + return logp, logq + + +def advi_objective(logp: TensorVariable, logq: TensorVariable): + """Compute the negative ELBO objective for ADVI. + + Parameters + ---------- + logp : TensorVariable + Log probability of the model. + logq : TensorVariable + Log probability of the guide. + + Returns + ------- + TensorVariable + The negative ELBO. + """ + negative_elbo = logq - logp + return negative_elbo diff --git a/pymc_extras/inference/advi/pytensorf.py b/pymc_extras/inference/advi/pytensorf.py new file mode 100644 index 000000000..3e204a16b --- /dev/null +++ b/pymc_extras/inference/advi/pytensorf.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import cast + +from pymc import SymbolicRandomVariable +from pymc.distributions.shape_utils import change_dist_size +from pytensor import config +from pytensor import tensor as pt +from pytensor.graph import FunctionGraph, ancestors, vectorize_graph +from pytensor.tensor import TensorLike, TensorVariable +from pytensor.tensor.basic import infer_shape_db +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.rewriting.shape import ShapeFeature + + +def vectorize_random_graph( + graph: Sequence[TensorVariable], batch_draws: TensorLike +) -> list[TensorVariable]: + # Find the root random nodes + rvs = tuple( + var + for var in ancestors(graph) + if ( + var.owner is not None + and isinstance(var.owner.op, RandomVariable | SymbolicRandomVariable) + ) + ) + rvs_set = set(rvs) + root_rvs = tuple(rv for rv in rvs if not (set(rv.owner.inputs) & rvs_set)) + + # Vectorize graph by vectorizing root RVs + batch_draws = pt.as_tensor(batch_draws, dtype=int) + vectorized_replacements = { + root_rv: change_dist_size(root_rv, new_size=batch_draws, expand=True) + for root_rv in root_rvs + } + return cast(list[TensorVariable], vectorize_graph(graph, replace=vectorized_replacements)) + + +def get_symbolic_rv_shapes( + rvs: Sequence[TensorVariable], raise_if_rvs_in_graph: bool = True +) -> tuple[TensorVariable, ...]: + # TODO: Move me to pymc.pytensorf, this is needed often + + rv_shapes = [rv.shape for rv in rvs] + shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True) + with config.change_flags(optdb__max_use_ratio=10, cxx=""): + infer_shape_db.default_query.rewrite(shape_fg) + rv_shapes = shape_fg.outputs + + if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))): + raise ValueError(f"rv_shapes still depend the following rvs {overlap}") + + return cast(tuple[TensorVariable, ...], tuple(rv_shapes)) diff --git a/pymc_extras/inference/advi/training.py b/pymc_extras/inference/advi/training.py new file mode 100644 index 000000000..e3770726e --- /dev/null +++ b/pymc_extras/inference/advi/training.py @@ -0,0 +1,39 @@ +from typing import Protocol + +import numpy as np + +from pymc import Model, compile +from pymc.pytensorf import rewrite_pregrad +from pytensor import tensor as pt + +from pymc_extras.inference.advi.autoguide import AutoGuideModel +from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq +from pymc_extras.inference.advi.pytensorf import vectorize_random_graph + + +class TrainingFn(Protocol): + def __call__(self, draws: int, *params: np.ndarray) -> tuple[np.ndarray, ...]: ... + + +def compile_svi_training_fn( + model: Model, guide: AutoGuideModel, stick_the_landing: bool = True, **compile_kwargs +) -> TrainingFn: + draws = pt.scalar("draws", dtype=int) + params = guide.params + + logp, logq = get_logp_logq(model, guide, stick_the_landing=stick_the_landing) + + scalar_negative_elbo = advi_objective(logp, logq) + [negative_elbo_draws] = vectorize_random_graph([scalar_negative_elbo], batch_draws=draws) + negative_elbo = negative_elbo_draws.mean(axis=0) + + negative_elbo_grads = pt.grad(rewrite_pregrad(negative_elbo), wrt=params) + + if "trust_input" not in compile_kwargs: + compile_kwargs["trust_input"] = True + + f_loss_dloss = compile( + inputs=[draws, *params], outputs=[negative_elbo, *negative_elbo_grads], **compile_kwargs + ) + + return f_loss_dloss diff --git a/tests/inference/advi/__init__.py b/tests/inference/advi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/advi/test_autoguide.py b/tests/inference/advi/test_autoguide.py new file mode 100644 index 000000000..3318613c6 --- /dev/null +++ b/tests/inference/advi/test_autoguide.py @@ -0,0 +1,156 @@ +import numpy as np +import pymc as pm +import pytest + +from pytensor import function as pytensor_function +from scipy import special + +from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal, AutoGuideModel + +# TODO: This is a magic number from AutoDiagonalNormal's scale initialization +SCALE_INIT = 0.1 + + +@pytest.fixture +def simple_model(): + with pm.Model() as model: + pm.Normal("x", 0, 1) + return model + + +@pytest.fixture +def multi_rv_model(): + with pm.Model() as model: + pm.Normal("x", 0, 1) + pm.Normal("y", 0, 1, shape=(3,)) + return model + + +class TestAutoDiagonalNormal: + def test_creates_guide_variables(self): + with pm.Model() as model: + pm.Normal("mu", 0, 1) + pm.Exponential("sigma", 1) + + guide = AutoDiagonalNormal(model) + + assert isinstance(guide, AutoGuideModel) + expected_vars = {"mu", "sigma", "mu_z", "sigma_z"} + assert expected_vars <= set(guide.model.named_vars.keys()) + + @pytest.mark.parametrize( + "rv_shapes, expected_param_shapes", + [ + ( + [(), (3,), (2, 4)], + { + "x_loc": (), + "x_scale": (), + "y_loc": (3,), + "y_scale": (3,), + "z_loc": (2, 4), + "z_scale": (2, 4), + }, + ), + ], + ) + def test_params_have_correct_shapes(self, rv_shapes, expected_param_shapes): + with pm.Model() as model: + for i, (name, shape) in enumerate(zip(["x", "y", "z"], rv_shapes)): + pm.Normal(name, 0, 1, shape=shape if shape else None) + + guide = AutoDiagonalNormal(model) + param_shapes = {p.name: v.shape for p, v in guide.params_init_values.items()} + + for param_name, expected_shape in expected_param_shapes.items(): + assert param_shapes[param_name] == expected_shape + + def test_preserves_coords_and_dims(self): + coords = {"city": ["A", "B", "C"]} + with pm.Model(coords=coords) as model: + pm.Normal("mu", 0, 1, dims=["city"]) + + guide = AutoDiagonalNormal(model) + + assert tuple(guide.model.coords["city"]) == tuple(coords["city"]) + assert guide.model.named_vars_to_dims["mu"] == ("city",) + + +class TestAutoGuideModel: + def test_params_returns_all_loc_and_scale(self, multi_rv_model): + guide = AutoDiagonalNormal(multi_rv_model) + + param_names = {p.name for p in guide.params} + assert param_names == {"x_loc", "x_scale", "y_loc", "y_scale"} + + def test_getitem_returns_param_by_name(self, simple_model): + guide = AutoDiagonalNormal(simple_model) + + loc = guide["x_loc"] + scale = guide["x_scale"] + + assert loc.name == "x_loc" + assert scale.name == "x_scale" + + def test_stochastic_logq_returns_scalar(self, multi_rv_model): + guide = AutoDiagonalNormal(multi_rv_model) + logq = guide.stochastic_logq() + + f = pytensor_function(list(guide.params), logq) + result = f(*[guide.params_init_values[p] for p in guide.params]) + + assert result.shape == () + assert np.isfinite(result) + + +class TestAutoDiagonalNormalSampling: + def test_samples_have_expected_variance(self, simple_model): + """Samples from guide should have std ≈ softplus(scale_init).""" + guide = AutoDiagonalNormal(simple_model) + x_det = guide.model["x"] + + z_rv = guide.model["x_z"] + rng = z_rv.owner.inputs[0] + updates = {rng: z_rv.owner.outputs[0]} + + f = pytensor_function(list(guide.params), x_det, updates=updates) + samples = np.array( + [f(*[guide.params_init_values[p] for p in guide.params]) for _ in range(1000)] + ) + + EXPECTED_STD = special.softplus(SCALE_INIT) + + np.testing.assert_allclose(np.std(samples), EXPECTED_STD, rtol=0.1) + + def test_loc_shifts_output_mean(self, simple_model): + guide = AutoDiagonalNormal(simple_model) + x_det = guide.model["x"] + + loc_var, scale_var = guide["x_loc"], guide["x_scale"] + f = pytensor_function([loc_var, scale_var], x_det) + + init_scale = guide.params_init_values[scale_var] + val_at_0 = f(np.array(0.0), init_scale) + val_at_5 = f(np.array(5.0), init_scale) + + np.testing.assert_allclose(val_at_5 - val_at_0, 5.0) + + def test_scale_affects_output_variance(self, simple_model): + guide = AutoDiagonalNormal(simple_model) + x_det = guide.model["x"] + + z_rv = guide.model["x_z"] + rng = z_rv.owner.inputs[0] + updates = {rng: z_rv.owner.outputs[0]} + + loc_var, scale_var = guide["x_loc"], guide["x_scale"] + f = pytensor_function([loc_var, scale_var], x_det, updates=updates) + + def sample_std(scale_val, n=500): + samples = [f(np.array(0.0), np.array(scale_val)) for _ in range(n)] + return np.std(samples) + + std_small = sample_std(0.1) + std_large = sample_std(2.0) + + assert std_large > std_small * 2