|
32 | 32 | }, |
33 | 33 | { |
34 | 34 | "cell_type": "code", |
35 | | - "execution_count": 1, |
| 35 | + "execution_count": 2, |
36 | 36 | "id": "6f5439b8", |
37 | 37 | "metadata": {}, |
38 | 38 | "outputs": [ |
39 | | - { |
40 | | - "name": "stderr", |
41 | | - "output_type": "stream", |
42 | | - "text": [ |
43 | | - "WARNING (pytensor.configdefaults): g++ not available, if using conda: `conda install gxx`\n", |
44 | | - "WARNING (pytensor.configdefaults): g++ not detected! PyTensor will be unable to compile C-implementations and will default to Python. Performance may be severely degraded. To remove this warning, set PyTensor flags cxx to an empty string.\n", |
45 | | - "c:\\GIT\\Numerics-Python-Examples\\.venv\\Lib\\site-packages\\arviz\\__init__.py:39: FutureWarning: \n", |
46 | | - "ArviZ is undergoing a major refactor to improve flexibility and extensibility while maintaining a user-friendly interface.\n", |
47 | | - "Some upcoming changes may be backward incompatible.\n", |
48 | | - "For details and migration guidance, visit: https://python.arviz.org/en/latest/user_guide/migration_guide.html\n", |
49 | | - " warn(\n" |
50 | | - ] |
51 | | - }, |
52 | 39 | { |
53 | 40 | "name": "stdout", |
54 | 41 | "output_type": "stream", |
|
86 | 73 | }, |
87 | 74 | { |
88 | 75 | "cell_type": "code", |
89 | | - "execution_count": 2, |
| 76 | + "execution_count": 3, |
90 | 77 | "id": "9a147b4e", |
91 | 78 | "metadata": {}, |
92 | 79 | "outputs": [ |
|
96 | 83 | "False" |
97 | 84 | ] |
98 | 85 | }, |
99 | | - "execution_count": 2, |
| 86 | + "execution_count": 3, |
100 | 87 | "metadata": {}, |
101 | 88 | "output_type": "execute_result" |
102 | 89 | } |
|
135 | 122 | }, |
136 | 123 | { |
137 | 124 | "cell_type": "code", |
138 | | - "execution_count": 3, |
| 125 | + "execution_count": 4, |
139 | 126 | "id": "ef3f0173", |
140 | 127 | "metadata": {}, |
141 | 128 | "outputs": [], |
142 | 129 | "source": [ |
143 | 130 | "def extract_chain_samples(results, param_indx):\n", |
144 | | - " \"\"\"Extract parameter samples from MCMC sampler.\"\"\"\n", |
| 131 | + " \"\"\"Extract parameter samples from MCMC sampler results.\"\"\"\n", |
145 | 132 | " chains = []\n", |
146 | 133 | " for c in range(len(results.MarkovChains)):\n", |
147 | | - " chain_c = [results.MarkovChains[c][i].Values[param_indx]\n", |
148 | | - " for i in range(len(results.MarkovChains[c]))]\n", |
149 | | - " chains.append(chain_c)\n", |
| 134 | + " chain = []\n", |
| 135 | + " for i in range(len(results.MarkovChains[c])):\n", |
| 136 | + " chain.append(float(results.MarkovChains[c][i].Values[param_indx]))\n", |
| 137 | + " chains.append(chain)\n", |
150 | 138 | " return chains\n", |
151 | 139 | "\n", |
152 | 140 | "def plot_trace(samples, param_names, title=\"Trace Plots\"):\n", |
|
238 | 226 | }, |
239 | 227 | { |
240 | 228 | "cell_type": "code", |
241 | | - "execution_count": 4, |
| 229 | + "execution_count": 5, |
242 | 230 | "id": "f230fcf0", |
243 | 231 | "metadata": {}, |
244 | 232 | "outputs": [ |
|
496 | 484 | }, |
497 | 485 | { |
498 | 486 | "cell_type": "code", |
499 | | - "execution_count": 5, |
| 487 | + "execution_count": 6, |
500 | 488 | "id": "031793f7", |
501 | 489 | "metadata": {}, |
502 | 490 | "outputs": [ |
|
779 | 767 | }, |
780 | 768 | { |
781 | 769 | "cell_type": "code", |
782 | | - "execution_count": 6, |
| 770 | + "execution_count": 7, |
783 | 771 | "id": "5840b6c1", |
784 | 772 | "metadata": {}, |
785 | 773 | "outputs": [ |
|
831 | 819 | " <tr>\n", |
832 | 820 | " <th>0</th>\n", |
833 | 821 | " <td>ARWMH</td>\n", |
834 | | - " <td>0.511</td>\n", |
| 822 | + " <td>0.528</td>\n", |
835 | 823 | " <td>108.998</td>\n", |
836 | 824 | " <td>2.674</td>\n", |
837 | 825 | " <td>8.998</td>\n", |
|
840 | 828 | " <td>202.499</td>\n", |
841 | 829 | " </tr>\n", |
842 | 830 | " <tr>\n", |
843 | | - " <th>2</th>\n", |
844 | | - " <td>DEMCzs</td>\n", |
845 | | - " <td>0.914</td>\n", |
846 | | - " <td>109.255</td>\n", |
847 | | - " <td>2.693</td>\n", |
848 | | - " <td>9.255</td>\n", |
849 | | - " <td>0.193</td>\n", |
850 | | - " <td>767.374</td>\n", |
851 | | - " <td>768.685</td>\n", |
852 | | - " </tr>\n", |
853 | | - " <tr>\n", |
854 | 831 | " <th>1</th>\n", |
855 | 832 | " <td>DEMCz</td>\n", |
856 | | - " <td>0.934</td>\n", |
| 833 | + " <td>0.885</td>\n", |
857 | 834 | " <td>109.654</td>\n", |
858 | 835 | " <td>2.694</td>\n", |
859 | 836 | " <td>9.654</td>\n", |
|
862 | 839 | " <td>581.106</td>\n", |
863 | 840 | " </tr>\n", |
864 | 841 | " <tr>\n", |
| 842 | + " <th>2</th>\n", |
| 843 | + " <td>DEMCzs</td>\n", |
| 844 | + " <td>0.910</td>\n", |
| 845 | + " <td>109.255</td>\n", |
| 846 | + " <td>2.693</td>\n", |
| 847 | + " <td>9.255</td>\n", |
| 848 | + " <td>0.193</td>\n", |
| 849 | + " <td>767.374</td>\n", |
| 850 | + " <td>768.685</td>\n", |
| 851 | + " </tr>\n", |
| 852 | + " <tr>\n", |
865 | 853 | " <th>3</th>\n", |
866 | 854 | " <td>HMC</td>\n", |
867 | | - " <td>8.857</td>\n", |
| 855 | + " <td>8.497</td>\n", |
868 | 856 | " <td>111.574</td>\n", |
869 | 857 | " <td>2.769</td>\n", |
870 | 858 | " <td>11.574</td>\n", |
|
875 | 863 | " <tr>\n", |
876 | 864 | " <th>4</th>\n", |
877 | 865 | " <td>NUTS</td>\n", |
878 | | - " <td>36.850</td>\n", |
| 866 | + " <td>37.273</td>\n", |
879 | 867 | " <td>108.194</td>\n", |
880 | 868 | " <td>2.695</td>\n", |
881 | 869 | " <td>8.194</td>\n", |
|
889 | 877 | ], |
890 | 878 | "text/plain": [ |
891 | 879 | " Sampler Runtime (s) λ Estimate κ Estimate λ Error κ Error ESS (λ) \\\n", |
892 | | - "0 ARWMH 0.511 108.998 2.674 8.998 0.174 182.686 \n", |
893 | | - "2 DEMCzs 0.914 109.255 2.693 9.255 0.193 767.374 \n", |
894 | | - "1 DEMCz 0.934 109.654 2.694 9.654 0.194 726.493 \n", |
895 | | - "3 HMC 8.857 111.574 2.769 11.574 0.269 10.068 \n", |
896 | | - "4 NUTS 36.850 108.194 2.695 8.194 0.195 73.592 \n", |
| 880 | + "0 ARWMH 0.528 108.998 2.674 8.998 0.174 182.686 \n", |
| 881 | + "1 DEMCz 0.885 109.654 2.694 9.654 0.194 726.493 \n", |
| 882 | + "2 DEMCzs 0.910 109.255 2.693 9.255 0.193 767.374 \n", |
| 883 | + "3 HMC 8.497 111.574 2.769 11.574 0.269 10.068 \n", |
| 884 | + "4 NUTS 37.273 108.194 2.695 8.194 0.195 73.592 \n", |
897 | 885 | "\n", |
898 | 886 | " ESS (κ) \n", |
899 | 887 | "0 202.499 \n", |
900 | | - "2 768.685 \n", |
901 | 888 | "1 581.106 \n", |
| 889 | + "2 768.685 \n", |
902 | 890 | "3 102.772 \n", |
903 | 891 | "4 709.799 " |
904 | 892 | ] |
|
1026 | 1014 | }, |
1027 | 1015 | { |
1028 | 1016 | "cell_type": "code", |
1029 | | - "execution_count": 7, |
| 1017 | + "execution_count": 8, |
1030 | 1018 | "id": "2a1a1ca1", |
1031 | 1019 | "metadata": {}, |
1032 | 1020 | "outputs": [ |
|
1037 | 1025 | "\n", |
1038 | 1026 | "Running DEMCzs...\n", |
1039 | 1027 | "Numerics DEMCzs Performance:\n", |
1040 | | - " Runtime: 0.243 seconds\n", |
| 1028 | + " Runtime: 0.241 seconds\n", |
1041 | 1029 | " Mean μ: 12660.51\n", |
1042 | 1030 | " Mean σ: 4831.53\n", |
1043 | 1031 | "\n", |
1044 | 1032 | "Running Numerics NUTS...\n", |
1045 | 1033 | "Numerics NUTS Performance:\n", |
1046 | | - " Runtime: 7.619 seconds\n", |
| 1034 | + " Runtime: 7.685 seconds\n", |
1047 | 1035 | " Mean μ: 12650.87\n", |
1048 | 1036 | " Mean σ: 4845.18\n", |
1049 | 1037 | "\n", |
|
1057 | 1045 | "Initializing NUTS using jitter+adapt_diag...\n", |
1058 | 1046 | "Sequential sampling (1 chains in 1 job)\n", |
1059 | 1047 | "NUTS: [mu, sigma]\n", |
1060 | | - "Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 23 seconds.\n", |
| 1048 | + "Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 25 seconds.\n", |
1061 | 1049 | "Only one chain was sampled, this makes it impossible to run some convergence checks\n" |
1062 | 1050 | ] |
1063 | 1051 | }, |
|
1066 | 1054 | "output_type": "stream", |
1067 | 1055 | "text": [ |
1068 | 1056 | "Numerics PyMC Performance:\n", |
1069 | | - " Runtime: 26.254 seconds\n", |
| 1057 | + " Runtime: 27.552 seconds\n", |
1070 | 1058 | " Mean μ: 12641.05\n", |
1071 | 1059 | " Mean σ: 4826.64\n", |
1072 | 1060 | "\n", |
1073 | 1061 | "PERFORMANCE COMPARISON: Numerics DEMCzs vs Numerics NUTS vs PyMC\n", |
1074 | | - "Numerics DEMCzs: 0.243 seconds\n", |
1075 | | - "Numerics NUTS: 7.619 seconds\n", |
1076 | | - "PyMC (NUTS): 26.254 seconds\n", |
| 1062 | + "Numerics DEMCzs: 0.241 seconds\n", |
| 1063 | + "Numerics NUTS: 7.685 seconds\n", |
| 1064 | + "PyMC (NUTS): 27.552 seconds\n", |
1077 | 1065 | "\n", |
1078 | 1066 | " Numerics Speedup vs PyMC:\n", |
1079 | | - " DEMCzs: 108.17x faster\n", |
1080 | | - " NUTS: 3.45x faster\n" |
| 1067 | + " DEMCzs: 114.41x faster\n", |
| 1068 | + " NUTS: 3.59x faster\n" |
1081 | 1069 | ] |
1082 | 1070 | }, |
1083 | 1071 | { |
|
1246 | 1234 | }, |
1247 | 1235 | { |
1248 | 1236 | "cell_type": "code", |
1249 | | - "execution_count": 8, |
| 1237 | + "execution_count": 9, |
1250 | 1238 | "id": "689c6a99", |
1251 | 1239 | "metadata": {}, |
1252 | 1240 | "outputs": [ |
|
1264 | 1252 | "Initializing NUTS using jitter+adapt_diag...\n", |
1265 | 1253 | "Multiprocess sampling (3 chains in 3 jobs)\n", |
1266 | 1254 | "NUTS: [xi, alpha]\n", |
1267 | | - "Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 69 seconds.\n", |
| 1255 | + "Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 68 seconds.\n", |
1268 | 1256 | "We recommend running at least 4 chains for robust computation of convergence diagnostics\n" |
1269 | 1257 | ] |
1270 | 1258 | }, |
|
1369 | 1357 | " <tr>\n", |
1370 | 1358 | " <th>10</th>\n", |
1371 | 1359 | " <td>Runtime (secs)</td>\n", |
1372 | | - " <td>8.7408</td>\n", |
1373 | | - " <td>69.8270</td>\n", |
1374 | | - " <td>-61.0863</td>\n", |
| 1360 | + " <td>8.5154</td>\n", |
| 1361 | + " <td>68.8552</td>\n", |
| 1362 | + " <td>-60.3398</td>\n", |
1375 | 1363 | " </tr>\n", |
1376 | 1364 | " </tbody>\n", |
1377 | 1365 | "</table>\n", |
|
1389 | 1377 | "7 Alpha Lower CI 2297.817553 2305.096694 -0.3158%\n", |
1390 | 1378 | "8 Alpha Median 2795.192763 2793.589666 0.0574%\n", |
1391 | 1379 | "9 Alpha Upper CI 3447.047966 3453.778878 -0.1949%\n", |
1392 | | - "10 Runtime (secs) 8.7408 69.8270 -61.0863" |
| 1380 | + "10 Runtime (secs) 8.5154 68.8552 -60.3398" |
1393 | 1381 | ] |
1394 | 1382 | }, |
1395 | 1383 | "metadata": {}, |
|
1399 | 1387 | "name": "stdout", |
1400 | 1388 | "output_type": "stream", |
1401 | 1389 | "text": [ |
1402 | | - "Speedup: 7.99x faster with Numerics\n", |
| 1390 | + "Speedup: 8.09x faster with Numerics\n", |
1403 | 1391 | "\n", |
1404 | 1392 | "Gumbel Distribution [PyMC Comparison]\n" |
1405 | 1393 | ] |
|
1411 | 1399 | "Initializing NUTS using jitter+adapt_diag...\n", |
1412 | 1400 | "Multiprocess sampling (3 chains in 3 jobs)\n", |
1413 | 1401 | "NUTS: [xi, alpha]\n", |
1414 | | - "Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 52 seconds.\n", |
| 1402 | + "Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 54 seconds.\n", |
1415 | 1403 | "We recommend running at least 4 chains for robust computation of convergence diagnostics\n" |
1416 | 1404 | ] |
1417 | 1405 | }, |
|
1516 | 1504 | " <tr>\n", |
1517 | 1505 | " <th>10</th>\n", |
1518 | 1506 | " <td>Runtime (secs)</td>\n", |
1519 | | - " <td>6.9611</td>\n", |
1520 | | - " <td>53.2323</td>\n", |
1521 | | - " <td>-46.2712</td>\n", |
| 1507 | + " <td>6.9768</td>\n", |
| 1508 | + " <td>55.1851</td>\n", |
| 1509 | + " <td>-48.2083</td>\n", |
1522 | 1510 | " </tr>\n", |
1523 | 1511 | " </tbody>\n", |
1524 | 1512 | "</table>\n", |
|
1536 | 1524 | "7 Alpha Lower CI 3752.989571 3752.668837 0.0085%\n", |
1537 | 1525 | "8 Alpha Median 4463.712922 4476.005183 -0.2746%\n", |
1538 | 1526 | "9 Alpha Upper CI 5392.367283 5419.541733 -0.5014%\n", |
1539 | | - "10 Runtime (secs) 6.9611 53.2323 -46.2712" |
| 1527 | + "10 Runtime (secs) 6.9768 55.1851 -48.2083" |
1540 | 1528 | ] |
1541 | 1529 | }, |
1542 | 1530 | "metadata": {}, |
|
1546 | 1534 | "name": "stdout", |
1547 | 1535 | "output_type": "stream", |
1548 | 1536 | "text": [ |
1549 | | - "Speedup: 7.65x faster with Numerics\n", |
| 1537 | + "Speedup: 7.91x faster with Numerics\n", |
1550 | 1538 | "\n", |
1551 | 1539 | "Weibull Distribution [PyMC Comparison]\n" |
1552 | 1540 | ] |
|
1558 | 1546 | "Initializing NUTS using jitter+adapt_diag...\n", |
1559 | 1547 | "Multiprocess sampling (3 chains in 3 jobs)\n", |
1560 | 1548 | "NUTS: [kappa, lambda]\n", |
1561 | | - "Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 58 seconds.\n", |
| 1549 | + "Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 61 seconds.\n", |
1562 | 1550 | "We recommend running at least 4 chains for robust computation of convergence diagnostics\n" |
1563 | 1551 | ] |
1564 | 1552 | }, |
|
1663 | 1651 | " <tr>\n", |
1664 | 1652 | " <th>10</th>\n", |
1665 | 1653 | " <td>Runtime (secs)</td>\n", |
1666 | | - " <td>7.6491</td>\n", |
1667 | | - " <td>59.4641</td>\n", |
1668 | | - " <td>-51.8150</td>\n", |
| 1654 | + " <td>7.8432</td>\n", |
| 1655 | + " <td>61.8657</td>\n", |
| 1656 | + " <td>-54.0226</td>\n", |
1669 | 1657 | " </tr>\n", |
1670 | 1658 | " </tbody>\n", |
1671 | 1659 | "</table>\n", |
|
1683 | 1671 | "7 Lambda Lower CI 2.445549 13122.431346 -99.9814%\n", |
1684 | 1672 | "8 Lambda Median 2.973241 14278.392251 -99.9792%\n", |
1685 | 1673 | "9 Lambda Upper CI 3.547739 15514.805284 -99.9771%\n", |
1686 | | - "10 Runtime (secs) 7.6491 59.4641 -51.8150" |
| 1674 | + "10 Runtime (secs) 7.8432 61.8657 -54.0226" |
1687 | 1675 | ] |
1688 | 1676 | }, |
1689 | 1677 | "metadata": {}, |
|
1693 | 1681 | "name": "stdout", |
1694 | 1682 | "output_type": "stream", |
1695 | 1683 | "text": [ |
1696 | | - "Speedup: 7.77x faster with Numerics\n", |
| 1684 | + "Speedup: 7.89x faster with Numerics\n", |
1697 | 1685 | "\n" |
1698 | 1686 | ] |
1699 | 1687 | } |
|
0 commit comments