Skip to content

Commit 23036e0

Browse files
Merge pull request #31 from mathematicalmichael/feature/refactor
linting and cleanup
2 parents acbf065 + cf14997 commit 23036e0

15 files changed

Lines changed: 750 additions & 591 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ MANIFEST
5252
.venv*/
5353

5454
# save directories
55+
figures/
5556
scripts/pde_*D/
5657
scripts/ode/
5758
scripts/*-data/
@@ -60,3 +61,4 @@ ode/
6061
pde_*D/
6162
*-data/
6263
*.pkl
64+

src/mud_examples/experiments.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
import logging
5+
6+
import numpy as np
7+
from matplotlib import pyplot as plt
8+
9+
from mud_examples.utils import check_dir
10+
11+
plt.rcParams['figure.figsize'] = 10, 10
12+
plt.rcParams['font.size'] = 16
13+
14+
15+
__author__ = "Mathematical Michael"
16+
__copyright__ = "Mathematical Michael"
17+
__license__ = "mit"
18+
19+
_logger = logging.getLogger(__name__)
20+
_mpl_logger = logging.getLogger('matplotlib')
21+
_mpl_logger.setLevel(logging.WARNING)
22+
23+
24+
def experiment_measurements(fun, num_measurements,
25+
sd, num_trials, seed=21):
26+
"""
27+
Fixed sensors, varying how much data is incorporated into the solution.
28+
"""
29+
experiments = {}
30+
solutions = {}
31+
for ns in num_measurements:
32+
_logger.debug(f'Measurement experiment. Num measurements: {ns}')
33+
discretizations = []
34+
estimates = []
35+
for t in range(num_trials):
36+
np.random.seed(seed + t)
37+
_d = fun(sd=sd, num_obs=ns)
38+
estimate = _d.estimate()
39+
discretizations.append(_d)
40+
estimates.append(estimate)
41+
experiments[ns] = discretizations
42+
solutions[ns] = estimates
43+
44+
return experiments, solutions
45+
46+
47+
def experiment_equipment(fun, num_measure,
48+
sd_vals, num_trials,
49+
reference_value):
50+
"""
51+
Fixed number of sensors, varying the quality of equipment.
52+
"""
53+
sd_err = []
54+
sd_var = []
55+
for sd in sd_vals:
56+
_logger.debug(f'Equipment Experiment. Std Dev: {sd}')
57+
temp_err = []
58+
for t in range(num_trials):
59+
_d = fun(sd=sd, num_obs=num_measure)
60+
estimate = _d.estimate()
61+
temp_err.append(np.linalg.norm(estimate - reference_value))
62+
sd_err.append(np.mean(temp_err))
63+
sd_var.append(np.var(temp_err))
64+
65+
return sd_err, sd_var
66+
67+
68+
def plot_experiment_equipment(tolerances, res, prefix, fsize=32, linewidth=5,
69+
title="Variance of MUD Error", save=True):
70+
print("Plotting experiments involving equipment differences...")
71+
plt.figure(figsize=(10, 10))
72+
for _res in res:
73+
_prefix, _in, _rm, _re = _res
74+
regression_err_mean, slope_err_mean, \
75+
regression_err_vars, slope_err_vars, \
76+
sd_means, sd_vars, num_sensors = _re
77+
plt.plot(tolerances, regression_err_mean,
78+
label=f"{_prefix:10s} slope: {slope_err_mean:1.4f}",
79+
lw=linewidth)
80+
plt.scatter(tolerances, sd_means, marker='x', lw=20)
81+
82+
plt.yscale('log')
83+
plt.xscale('log')
84+
plt.Axes.set_aspect(plt.gca(), 1)
85+
plt.ylim(2E-3, 2E-2)
86+
# plt.ylabel("Absolute Error", fontsize=fsize)
87+
plt.xlabel('Tolerance', fontsize=fsize)
88+
plt.legend()
89+
plt.title(f"Mean of MUD Error for N={num_sensors}", fontsize=1.25 * fsize)
90+
if save:
91+
fdir = ''.join(prefix.split('/')[::-1])
92+
check_dir(fdir)
93+
_logger.info("Saving equipment experiments: mean convergence.")
94+
plt.savefig(f'{prefix}_convergence_mud_std_mean.png',
95+
bbox_inches='tight')
96+
else:
97+
plt.show()
98+
99+
plt.figure(figsize=(10, 10))
100+
for _res in res:
101+
_prefix, _in, _rm, _re = _res
102+
regression_err_mean, slope_err_mean, \
103+
regression_err_vars, slope_err_vars, \
104+
sd_means, sd_vars, num_sensors = _re
105+
plt.plot(tolerances, regression_err_vars,
106+
label=f"{_prefix:10s} slope: {slope_err_vars:1.4f}",
107+
lw=linewidth)
108+
plt.scatter(tolerances, sd_vars, marker='x', lw=20)
109+
plt.xscale('log')
110+
plt.yscale('log')
111+
plt.ylim(2E-5, 2E-4)
112+
plt.Axes.set_aspect(plt.gca(), 1)
113+
# plt.ylabel("Absolute Error", fontsize=fsize)
114+
plt.xlabel('Tolerance', fontsize=fsize)
115+
plt.legend()
116+
plt.title(title, fontsize=1.25 * fsize)
117+
if save:
118+
_logger.info("Saving equipment experiments: variance convergence.")
119+
plt.savefig(f'{prefix}_convergence_mud_std_var.png',
120+
bbox_inches='tight')
121+
else:
122+
plt.show()
123+
124+
125+
def plot_experiment_measurements(res, prefix,
126+
fsize=32, linewidth=5,
127+
xlabel='Number of Measurements',
128+
save=True, legend=False):
129+
print("Plotting experiments involving increasing # of measurements.")
130+
plt.figure(figsize=(10, 10))
131+
for _res in res:
132+
_prefix, _in, _rm, _re = _res
133+
solutions = _in[-1]
134+
measurements = list(solutions.keys())
135+
regression_mean, slope_mean, \
136+
regression_vars, slope_vars, \
137+
means, variances = _rm
138+
plt.plot(measurements[:len(regression_mean)], regression_mean,
139+
label=f"{_prefix:4s} slope: {slope_mean:1.4f}",
140+
lw=linewidth)
141+
plt.scatter(measurements[:len(means)], means, marker='x', lw=20)
142+
plt.xscale('log')
143+
plt.yscale('log')
144+
plt.Axes.set_aspect(plt.gca(), 1)
145+
plt.ylim(0.9 * min(means), 1.3 * max(means))
146+
plt.ylim(2E-3, 2E-1)
147+
plt.xlabel(xlabel, fontsize=fsize)
148+
if legend:
149+
plt.legend(fontsize=fsize * 0.8)
150+
# plt.ylabel('Absolute Error in MUD', fontsize=fsize)
151+
title = "$\\mathrm{\\mathbb{E}}(|\\lambda^* - \\lambda^\\dagger|)$" # noqa E501
152+
plt.title(title, fontsize=1.25 * fsize)
153+
if save:
154+
fdir = '/'.join(prefix.split('/')[:-1])
155+
check_dir(fdir)
156+
_logger.info("Saving measurement experiments: mean convergence.")
157+
plt.savefig(f'{prefix}_convergence_obs_mean.png', bbox_inches='tight')
158+
else:
159+
plt.show()
160+
161+
plt.figure(figsize=(10, 10))
162+
for _res in res:
163+
_prefix, _in, _rm, _re = _res
164+
regression_mean, slope_mean, \
165+
regression_vars, slope_vars, \
166+
means, variances = _rm
167+
plt.plot(measurements[:len(regression_vars)], regression_vars,
168+
label=f"{_prefix:4s} slope: {slope_vars:1.4f}",
169+
lw=linewidth)
170+
plt.scatter(measurements[:len(variances)], variances,
171+
marker='x', lw=20)
172+
plt.xscale('log')
173+
plt.yscale('log')
174+
plt.Axes.set_aspect(plt.gca(), 1)
175+
# if not len(np.unique(variances)) == 1:
176+
# plt.ylim(0.9 * min(variances), 1.3 * max(variances))
177+
plt.ylim(5E-6, 5E-4)
178+
plt.xlabel(xlabel, fontsize=fsize)
179+
if legend:
180+
plt.legend(fontsize=fsize * 0.8)
181+
# plt.ylabel('Absolute Error in MUD', fontsize=fsize)
182+
plt.title("$\\mathrm{Var}(|\\lambda^* - \\lambda^\\dagger|)$",
183+
fontsize=1.25 * fsize)
184+
if save:
185+
_logger.info("Saving measurement experiments: variance convergence.")
186+
plt.savefig(f'{prefix}_convergence_obs_var.png', bbox_inches='tight')
187+
else:
188+
plt.show()

0 commit comments

Comments
 (0)