Skip to content

Commit 0ae7e14

Browse files
committed
Final changes post Tiki review
1 parent fc90a62 commit 0ae7e14

7 files changed

Lines changed: 233 additions & 289 deletions

CITATION.cff

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ authors:
2929
given-names: "Tiki"
3030
email: "julian.t.gonzalez@usace.army.mil"
3131
affiliation: "U.S. Army Corps of Engineers, Risk Management Center"
32-
orcid: ""
3332
- family-names: "Smith"
3433
given-names: "C. Haden"
3534
email: "cole.h.smith@usace.army.mil"

notebooks/04_mcmc_bayesian_inference.ipynb

Lines changed: 32 additions & 32 deletions
Large diffs are not rendered by default.

notebooks/05_mcmc_adaptive.ipynb

Lines changed: 62 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,10 @@
3232
},
3333
{
3434
"cell_type": "code",
35-
"execution_count": 1,
35+
"execution_count": 2,
3636
"id": "6f5439b8",
3737
"metadata": {},
3838
"outputs": [
39-
{
40-
"name": "stderr",
41-
"output_type": "stream",
42-
"text": [
43-
"WARNING (pytensor.configdefaults): g++ not available, if using conda: `conda install gxx`\n",
44-
"WARNING (pytensor.configdefaults): g++ not detected! PyTensor will be unable to compile C-implementations and will default to Python. Performance may be severely degraded. To remove this warning, set PyTensor flags cxx to an empty string.\n",
45-
"c:\\GIT\\Numerics-Python-Examples\\.venv\\Lib\\site-packages\\arviz\\__init__.py:39: FutureWarning: \n",
46-
"ArviZ is undergoing a major refactor to improve flexibility and extensibility while maintaining a user-friendly interface.\n",
47-
"Some upcoming changes may be backward incompatible.\n",
48-
"For details and migration guidance, visit: https://python.arviz.org/en/latest/user_guide/migration_guide.html\n",
49-
" warn(\n"
50-
]
51-
},
5239
{
5340
"name": "stdout",
5441
"output_type": "stream",
@@ -86,7 +73,7 @@
8673
},
8774
{
8875
"cell_type": "code",
89-
"execution_count": 2,
76+
"execution_count": 3,
9077
"id": "9a147b4e",
9178
"metadata": {},
9279
"outputs": [
@@ -96,7 +83,7 @@
9683
"False"
9784
]
9885
},
99-
"execution_count": 2,
86+
"execution_count": 3,
10087
"metadata": {},
10188
"output_type": "execute_result"
10289
}
@@ -135,18 +122,19 @@
135122
},
136123
{
137124
"cell_type": "code",
138-
"execution_count": 3,
125+
"execution_count": 4,
139126
"id": "ef3f0173",
140127
"metadata": {},
141128
"outputs": [],
142129
"source": [
143130
"def extract_chain_samples(results, param_indx):\n",
144-
" \"\"\"Extract parameter samples from MCMC sampler.\"\"\"\n",
131+
" \"\"\"Extract parameter samples from MCMC sampler results.\"\"\"\n",
145132
" chains = []\n",
146133
" for c in range(len(results.MarkovChains)):\n",
147-
" chain_c = [results.MarkovChains[c][i].Values[param_indx]\n",
148-
" for i in range(len(results.MarkovChains[c]))]\n",
149-
" chains.append(chain_c)\n",
134+
" chain = []\n",
135+
" for i in range(len(results.MarkovChains[c])):\n",
136+
" chain.append(float(results.MarkovChains[c][i].Values[param_indx]))\n",
137+
" chains.append(chain)\n",
150138
" return chains\n",
151139
"\n",
152140
"def plot_trace(samples, param_names, title=\"Trace Plots\"):\n",
@@ -238,7 +226,7 @@
238226
},
239227
{
240228
"cell_type": "code",
241-
"execution_count": 4,
229+
"execution_count": 5,
242230
"id": "f230fcf0",
243231
"metadata": {},
244232
"outputs": [
@@ -496,7 +484,7 @@
496484
},
497485
{
498486
"cell_type": "code",
499-
"execution_count": 5,
487+
"execution_count": 6,
500488
"id": "031793f7",
501489
"metadata": {},
502490
"outputs": [
@@ -779,7 +767,7 @@
779767
},
780768
{
781769
"cell_type": "code",
782-
"execution_count": 6,
770+
"execution_count": 7,
783771
"id": "5840b6c1",
784772
"metadata": {},
785773
"outputs": [
@@ -831,7 +819,7 @@
831819
" <tr>\n",
832820
" <th>0</th>\n",
833821
" <td>ARWMH</td>\n",
834-
" <td>0.511</td>\n",
822+
" <td>0.528</td>\n",
835823
" <td>108.998</td>\n",
836824
" <td>2.674</td>\n",
837825
" <td>8.998</td>\n",
@@ -840,20 +828,9 @@
840828
" <td>202.499</td>\n",
841829
" </tr>\n",
842830
" <tr>\n",
843-
" <th>2</th>\n",
844-
" <td>DEMCzs</td>\n",
845-
" <td>0.914</td>\n",
846-
" <td>109.255</td>\n",
847-
" <td>2.693</td>\n",
848-
" <td>9.255</td>\n",
849-
" <td>0.193</td>\n",
850-
" <td>767.374</td>\n",
851-
" <td>768.685</td>\n",
852-
" </tr>\n",
853-
" <tr>\n",
854831
" <th>1</th>\n",
855832
" <td>DEMCz</td>\n",
856-
" <td>0.934</td>\n",
833+
" <td>0.885</td>\n",
857834
" <td>109.654</td>\n",
858835
" <td>2.694</td>\n",
859836
" <td>9.654</td>\n",
@@ -862,9 +839,20 @@
862839
" <td>581.106</td>\n",
863840
" </tr>\n",
864841
" <tr>\n",
842+
" <th>2</th>\n",
843+
" <td>DEMCzs</td>\n",
844+
" <td>0.910</td>\n",
845+
" <td>109.255</td>\n",
846+
" <td>2.693</td>\n",
847+
" <td>9.255</td>\n",
848+
" <td>0.193</td>\n",
849+
" <td>767.374</td>\n",
850+
" <td>768.685</td>\n",
851+
" </tr>\n",
852+
" <tr>\n",
865853
" <th>3</th>\n",
866854
" <td>HMC</td>\n",
867-
" <td>8.857</td>\n",
855+
" <td>8.497</td>\n",
868856
" <td>111.574</td>\n",
869857
" <td>2.769</td>\n",
870858
" <td>11.574</td>\n",
@@ -875,7 +863,7 @@
875863
" <tr>\n",
876864
" <th>4</th>\n",
877865
" <td>NUTS</td>\n",
878-
" <td>36.850</td>\n",
866+
" <td>37.273</td>\n",
879867
" <td>108.194</td>\n",
880868
" <td>2.695</td>\n",
881869
" <td>8.194</td>\n",
@@ -889,16 +877,16 @@
889877
],
890878
"text/plain": [
891879
" Sampler Runtime (s) λ Estimate κ Estimate λ Error κ Error ESS (λ) \\\n",
892-
"0 ARWMH 0.511 108.998 2.674 8.998 0.174 182.686 \n",
893-
"2 DEMCzs 0.914 109.255 2.693 9.255 0.193 767.374 \n",
894-
"1 DEMCz 0.934 109.654 2.694 9.654 0.194 726.493 \n",
895-
"3 HMC 8.857 111.574 2.769 11.574 0.269 10.068 \n",
896-
"4 NUTS 36.850 108.194 2.695 8.194 0.195 73.592 \n",
880+
"0 ARWMH 0.528 108.998 2.674 8.998 0.174 182.686 \n",
881+
"1 DEMCz 0.885 109.654 2.694 9.654 0.194 726.493 \n",
882+
"2 DEMCzs 0.910 109.255 2.693 9.255 0.193 767.374 \n",
883+
"3 HMC 8.497 111.574 2.769 11.574 0.269 10.068 \n",
884+
"4 NUTS 37.273 108.194 2.695 8.194 0.195 73.592 \n",
897885
"\n",
898886
" ESS (κ) \n",
899887
"0 202.499 \n",
900-
"2 768.685 \n",
901888
"1 581.106 \n",
889+
"2 768.685 \n",
902890
"3 102.772 \n",
903891
"4 709.799 "
904892
]
@@ -1026,7 +1014,7 @@
10261014
},
10271015
{
10281016
"cell_type": "code",
1029-
"execution_count": 7,
1017+
"execution_count": 8,
10301018
"id": "2a1a1ca1",
10311019
"metadata": {},
10321020
"outputs": [
@@ -1037,13 +1025,13 @@
10371025
"\n",
10381026
"Running DEMCzs...\n",
10391027
"Numerics DEMCzs Performance:\n",
1040-
" Runtime: 0.243 seconds\n",
1028+
" Runtime: 0.241 seconds\n",
10411029
" Mean μ: 12660.51\n",
10421030
" Mean σ: 4831.53\n",
10431031
"\n",
10441032
"Running Numerics NUTS...\n",
10451033
"Numerics NUTS Performance:\n",
1046-
" Runtime: 7.619 seconds\n",
1034+
" Runtime: 7.685 seconds\n",
10471035
" Mean μ: 12650.87\n",
10481036
" Mean σ: 4845.18\n",
10491037
"\n",
@@ -1057,7 +1045,7 @@
10571045
"Initializing NUTS using jitter+adapt_diag...\n",
10581046
"Sequential sampling (1 chains in 1 job)\n",
10591047
"NUTS: [mu, sigma]\n",
1060-
"Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 23 seconds.\n",
1048+
"Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 25 seconds.\n",
10611049
"Only one chain was sampled, this makes it impossible to run some convergence checks\n"
10621050
]
10631051
},
@@ -1066,18 +1054,18 @@
10661054
"output_type": "stream",
10671055
"text": [
10681056
"Numerics PyMC Performance:\n",
1069-
" Runtime: 26.254 seconds\n",
1057+
" Runtime: 27.552 seconds\n",
10701058
" Mean μ: 12641.05\n",
10711059
" Mean σ: 4826.64\n",
10721060
"\n",
10731061
"PERFORMANCE COMPARISON: Numerics DEMCzs vs Numerics NUTS vs PyMC\n",
1074-
"Numerics DEMCzs: 0.243 seconds\n",
1075-
"Numerics NUTS: 7.619 seconds\n",
1076-
"PyMC (NUTS): 26.254 seconds\n",
1062+
"Numerics DEMCzs: 0.241 seconds\n",
1063+
"Numerics NUTS: 7.685 seconds\n",
1064+
"PyMC (NUTS): 27.552 seconds\n",
10771065
"\n",
10781066
" Numerics Speedup vs PyMC:\n",
1079-
" DEMCzs: 108.17x faster\n",
1080-
" NUTS: 3.45x faster\n"
1067+
" DEMCzs: 114.41x faster\n",
1068+
" NUTS: 3.59x faster\n"
10811069
]
10821070
},
10831071
{
@@ -1246,7 +1234,7 @@
12461234
},
12471235
{
12481236
"cell_type": "code",
1249-
"execution_count": 8,
1237+
"execution_count": 9,
12501238
"id": "689c6a99",
12511239
"metadata": {},
12521240
"outputs": [
@@ -1264,7 +1252,7 @@
12641252
"Initializing NUTS using jitter+adapt_diag...\n",
12651253
"Multiprocess sampling (3 chains in 3 jobs)\n",
12661254
"NUTS: [xi, alpha]\n",
1267-
"Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 69 seconds.\n",
1255+
"Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 68 seconds.\n",
12681256
"We recommend running at least 4 chains for robust computation of convergence diagnostics\n"
12691257
]
12701258
},
@@ -1369,9 +1357,9 @@
13691357
" <tr>\n",
13701358
" <th>10</th>\n",
13711359
" <td>Runtime (secs)</td>\n",
1372-
" <td>8.7408</td>\n",
1373-
" <td>69.8270</td>\n",
1374-
" <td>-61.0863</td>\n",
1360+
" <td>8.5154</td>\n",
1361+
" <td>68.8552</td>\n",
1362+
" <td>-60.3398</td>\n",
13751363
" </tr>\n",
13761364
" </tbody>\n",
13771365
"</table>\n",
@@ -1389,7 +1377,7 @@
13891377
"7 Alpha Lower CI 2297.817553 2305.096694 -0.3158%\n",
13901378
"8 Alpha Median 2795.192763 2793.589666 0.0574%\n",
13911379
"9 Alpha Upper CI 3447.047966 3453.778878 -0.1949%\n",
1392-
"10 Runtime (secs) 8.7408 69.8270 -61.0863"
1380+
"10 Runtime (secs) 8.5154 68.8552 -60.3398"
13931381
]
13941382
},
13951383
"metadata": {},
@@ -1399,7 +1387,7 @@
13991387
"name": "stdout",
14001388
"output_type": "stream",
14011389
"text": [
1402-
"Speedup: 7.99x faster with Numerics\n",
1390+
"Speedup: 8.09x faster with Numerics\n",
14031391
"\n",
14041392
"Gumbel Distribution [PyMC Comparison]\n"
14051393
]
@@ -1411,7 +1399,7 @@
14111399
"Initializing NUTS using jitter+adapt_diag...\n",
14121400
"Multiprocess sampling (3 chains in 3 jobs)\n",
14131401
"NUTS: [xi, alpha]\n",
1414-
"Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 52 seconds.\n",
1402+
"Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 54 seconds.\n",
14151403
"We recommend running at least 4 chains for robust computation of convergence diagnostics\n"
14161404
]
14171405
},
@@ -1516,9 +1504,9 @@
15161504
" <tr>\n",
15171505
" <th>10</th>\n",
15181506
" <td>Runtime (secs)</td>\n",
1519-
" <td>6.9611</td>\n",
1520-
" <td>53.2323</td>\n",
1521-
" <td>-46.2712</td>\n",
1507+
" <td>6.9768</td>\n",
1508+
" <td>55.1851</td>\n",
1509+
" <td>-48.2083</td>\n",
15221510
" </tr>\n",
15231511
" </tbody>\n",
15241512
"</table>\n",
@@ -1536,7 +1524,7 @@
15361524
"7 Alpha Lower CI 3752.989571 3752.668837 0.0085%\n",
15371525
"8 Alpha Median 4463.712922 4476.005183 -0.2746%\n",
15381526
"9 Alpha Upper CI 5392.367283 5419.541733 -0.5014%\n",
1539-
"10 Runtime (secs) 6.9611 53.2323 -46.2712"
1527+
"10 Runtime (secs) 6.9768 55.1851 -48.2083"
15401528
]
15411529
},
15421530
"metadata": {},
@@ -1546,7 +1534,7 @@
15461534
"name": "stdout",
15471535
"output_type": "stream",
15481536
"text": [
1549-
"Speedup: 7.65x faster with Numerics\n",
1537+
"Speedup: 7.91x faster with Numerics\n",
15501538
"\n",
15511539
"Weibull Distribution [PyMC Comparison]\n"
15521540
]
@@ -1558,7 +1546,7 @@
15581546
"Initializing NUTS using jitter+adapt_diag...\n",
15591547
"Multiprocess sampling (3 chains in 3 jobs)\n",
15601548
"NUTS: [kappa, lambda]\n",
1561-
"Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 58 seconds.\n",
1549+
"Sampling 3 chains for 1_750 tune and 3_500 draw iterations (5_250 + 10_500 draws total) took 61 seconds.\n",
15621550
"We recommend running at least 4 chains for robust computation of convergence diagnostics\n"
15631551
]
15641552
},
@@ -1663,9 +1651,9 @@
16631651
" <tr>\n",
16641652
" <th>10</th>\n",
16651653
" <td>Runtime (secs)</td>\n",
1666-
" <td>7.6491</td>\n",
1667-
" <td>59.4641</td>\n",
1668-
" <td>-51.8150</td>\n",
1654+
" <td>7.8432</td>\n",
1655+
" <td>61.8657</td>\n",
1656+
" <td>-54.0226</td>\n",
16691657
" </tr>\n",
16701658
" </tbody>\n",
16711659
"</table>\n",
@@ -1683,7 +1671,7 @@
16831671
"7 Lambda Lower CI 2.445549 13122.431346 -99.9814%\n",
16841672
"8 Lambda Median 2.973241 14278.392251 -99.9792%\n",
16851673
"9 Lambda Upper CI 3.547739 15514.805284 -99.9771%\n",
1686-
"10 Runtime (secs) 7.6491 59.4641 -51.8150"
1674+
"10 Runtime (secs) 7.8432 61.8657 -54.0226"
16871675
]
16881676
},
16891677
"metadata": {},
@@ -1693,7 +1681,7 @@
16931681
"name": "stdout",
16941682
"output_type": "stream",
16951683
"text": [
1696-
"Speedup: 7.77x faster with Numerics\n",
1684+
"Speedup: 7.89x faster with Numerics\n",
16971685
"\n"
16981686
]
16991687
}

0 commit comments

Comments
 (0)