Skip to content

Commit 1580c30

Browse files
linting
1 parent bdd4ee0 commit 1580c30

14 files changed

Lines changed: 1627 additions & 1011 deletions

File tree

src/mud_examples/experiments.py

Lines changed: 107 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,27 @@
88

99
from mud_examples.utils import check_dir
1010

11-
plt.rcParams['figure.figsize'] = 10, 10
12-
plt.rcParams['font.size'] = 16
11+
plt.rcParams["figure.figsize"] = 10, 10
12+
plt.rcParams["font.size"] = 16
1313

1414

1515
__author__ = "Mathematical Michael"
1616
__copyright__ = "Mathematical Michael"
1717
__license__ = "mit"
1818

1919
_logger = logging.getLogger(__name__)
20-
_mpl_logger = logging.getLogger('matplotlib')
20+
_mpl_logger = logging.getLogger("matplotlib")
2121
_mpl_logger.setLevel(logging.WARNING)
2222

2323

24-
def experiment_measurements(fun,
25-
num_measurements,
26-
sd,
27-
num_trials,
28-
seed=21):
24+
def experiment_measurements(fun, num_measurements, sd, num_trials, seed=21):
2925
"""
3026
Fixed sensors, varying how much data is incorporated into the solution.
3127
"""
3228
experiments = {}
3329
solutions = {}
3430
for ns in num_measurements:
35-
_logger.debug(f'Measurement experiment. Num measurements: {ns}')
31+
_logger.debug(f"Measurement experiment. Num measurements: {ns}")
3632
discretizations = []
3733
estimates = []
3834
for t in range(num_trials):
@@ -47,18 +43,14 @@ def experiment_measurements(fun,
4743
return experiments, solutions
4844

4945

50-
def experiment_equipment(fun,
51-
num_measure,
52-
sd_vals,
53-
num_trials,
54-
seed=21):
46+
def experiment_equipment(fun, num_measure, sd_vals, num_trials, seed=21):
5547
"""
5648
Fixed number of sensors, varying the quality of equipment.
5749
"""
5850
experiments = {}
5951
solutions = {}
6052
for sd in sd_vals:
61-
_logger.debug(f'Equipment Experiment. Std Dev: {sd}')
53+
_logger.debug(f"Equipment Experiment. Std Dev: {sd}")
6254
discretizations = []
6355
estimates = []
6456
for t in range(num_trials):
@@ -73,82 +65,117 @@ def experiment_equipment(fun,
7365
return experiments, solutions
7466

7567

76-
def plot_experiment_equipment(tolerances, res, prefix, fsize=32, linewidth=5,
77-
title="Variance of MUD Error", save=True):
68+
def plot_experiment_equipment(
69+
tolerances,
70+
res,
71+
prefix,
72+
fsize=32,
73+
linewidth=5,
74+
title="Variance of MUD Error",
75+
save=True,
76+
):
7877
print("Plotting experiments involving equipment differences...")
7978
plt.figure(figsize=(10, 10))
8079
for _res in res:
8180
_example, _in, _rm, _re, _fname = _res
82-
regression_err_mean, slope_err_mean, \
83-
regression_err_vars, slope_err_vars, \
84-
sd_means, sd_vars, num_sensors = _re
85-
plt.plot(tolerances, regression_err_mean,
86-
label=f"{_example.upper()} slope: {slope_err_mean:1.4f}",
87-
lw=linewidth)
88-
plt.scatter(tolerances, sd_means, marker='x', lw=20)
89-
90-
plt.yscale('log')
91-
plt.xscale('log')
81+
(
82+
regression_err_mean,
83+
slope_err_mean,
84+
regression_err_vars,
85+
slope_err_vars,
86+
sd_means,
87+
sd_vars,
88+
num_sensors,
89+
) = _re
90+
plt.plot(
91+
tolerances,
92+
regression_err_mean,
93+
label=f"{_example.upper()} slope: {slope_err_mean:1.4f}",
94+
lw=linewidth,
95+
)
96+
plt.scatter(tolerances, sd_means, marker="x", lw=20)
97+
98+
plt.yscale("log")
99+
plt.xscale("log")
92100
plt.Axes.set_aspect(plt.gca(), 1)
93101
# plt.ylim(2E-3, 2E-2)
94102
# plt.ylabel("Absolute Error", fontsize=fsize)
95-
plt.xlabel('Tolerance', fontsize=fsize)
103+
plt.xlabel("Tolerance", fontsize=fsize)
96104
plt.legend()
97105
plt.title(f"Mean of MUD Error for N={num_sensors}", fontsize=1.25 * fsize)
98106
if save:
99-
fdir = ''.join(prefix.split('/')[::-1])
100-
check_dir(f'figures/{_fname}/{fdir}')
107+
fdir = "".join(prefix.split("/")[::-1])
108+
check_dir(f"figures/{_fname}/{fdir}")
101109
_logger.info("Saving equipment experiments: mean convergence.")
102-
plt.savefig(f'figures/{_fname}/{prefix}_convergence_mud_std_mean.png',
103-
bbox_inches='tight')
110+
plt.savefig(
111+
f"figures/{_fname}/{prefix}_convergence_mud_std_mean.png",
112+
bbox_inches="tight",
113+
)
104114
else:
105115
plt.show()
106116

107117
plt.figure(figsize=(10, 10))
108118
for _res in res:
109119
_example, _in, _rm, _re, _fname = _res
110-
regression_err_mean, slope_err_mean, \
111-
regression_err_vars, slope_err_vars, \
112-
sd_means, sd_vars, num_sensors = _re
113-
plt.plot(tolerances, regression_err_vars,
114-
label=f"{_example.upper()} slope: {slope_err_vars:1.4f}",
115-
lw=linewidth)
116-
plt.scatter(tolerances, sd_vars, marker='x', lw=20)
117-
plt.xscale('log')
118-
plt.yscale('log')
120+
(
121+
regression_err_mean,
122+
slope_err_mean,
123+
regression_err_vars,
124+
slope_err_vars,
125+
sd_means,
126+
sd_vars,
127+
num_sensors,
128+
) = _re
129+
plt.plot(
130+
tolerances,
131+
regression_err_vars,
132+
label=f"{_example.upper()} slope: {slope_err_vars:1.4f}",
133+
lw=linewidth,
134+
)
135+
plt.scatter(tolerances, sd_vars, marker="x", lw=20)
136+
plt.xscale("log")
137+
plt.yscale("log")
119138
# plt.ylim(2E-5, 2E-4)
120139
plt.Axes.set_aspect(plt.gca(), 1)
121140
# plt.ylabel("Absolute Error", fontsize=fsize)
122-
plt.xlabel('Tolerance', fontsize=fsize)
141+
plt.xlabel("Tolerance", fontsize=fsize)
123142
plt.legend()
124143
plt.title(title, fontsize=1.25 * fsize)
125144
if save:
126145
_logger.info("Saving equipment experiments: variance convergence.")
127-
plt.savefig(f'figures/{_fname}/{prefix}_convergence_mud_std_var.png',
128-
bbox_inches='tight')
146+
plt.savefig(
147+
f"figures/{_fname}/{prefix}_convergence_mud_std_var.png",
148+
bbox_inches="tight",
149+
)
129150
else:
130151
plt.show()
131152

132153

133-
def plot_experiment_measurements(res, prefix,
134-
fsize=32, linewidth=5,
135-
xlabel='Number of Measurements',
136-
save=True, legend=True):
154+
def plot_experiment_measurements(
155+
res,
156+
prefix,
157+
fsize=32,
158+
linewidth=5,
159+
xlabel="Number of Measurements",
160+
save=True,
161+
legend=True,
162+
):
137163
print("Plotting experiments involving increasing # of measurements.")
138164
plt.figure(figsize=(10, 10))
139165
for _res in res:
140166
_example, _in, _rm, _re, _fname = _res
141167
solutions = _in[-1]
142168
measurements = list(solutions.keys())
143-
regression_mean, slope_mean, \
144-
regression_vars, slope_vars, \
145-
means, variances = _rm
146-
plt.plot(measurements[:len(regression_mean)], regression_mean,
147-
label=f"{_example.upper()} slope: {slope_mean:1.4f}",
148-
lw=linewidth)
149-
plt.scatter(measurements[:len(means)], means, marker='x', lw=20)
150-
plt.xscale('log')
151-
plt.yscale('log')
169+
regression_mean, slope_mean, regression_vars, slope_vars, means, variances = _rm
170+
plt.plot(
171+
measurements[: len(regression_mean)],
172+
regression_mean,
173+
label=f"{_example.upper()} slope: {slope_mean:1.4f}",
174+
lw=linewidth,
175+
)
176+
plt.scatter(measurements[: len(means)], means, marker="x", lw=20)
177+
plt.xscale("log")
178+
plt.yscale("log")
152179
plt.Axes.set_aspect(plt.gca(), 1)
153180
# plt.ylim(0.9 * min(means), 1.3 * max(means))
154181
# plt.ylim(2E-3, 2E-1)
@@ -159,38 +186,43 @@ def plot_experiment_measurements(res, prefix,
159186
title = "$\\mathrm{\\mathbb{E}}(|\\lambda^* - \\lambda^\\dagger|)$" # noqa E501
160187
plt.title(title, fontsize=1.25 * fsize)
161188
if save:
162-
fdir = '/'.join(prefix.split('/')[:-1])
163-
check_dir(f'figures/{_fname}/{fdir}')
189+
fdir = "/".join(prefix.split("/")[:-1])
190+
check_dir(f"figures/{_fname}/{fdir}")
164191
_logger.info("Saving measurement experiments: mean convergence.")
165-
plt.savefig(f'figures/{_fname}/{prefix}_convergence_obs_mean.png', bbox_inches='tight')
192+
plt.savefig(
193+
f"figures/{_fname}/{prefix}_convergence_obs_mean.png", bbox_inches="tight"
194+
)
166195
else:
167196
plt.show()
168197

169198
plt.figure(figsize=(10, 10))
170199
for _res in res:
171200
_example, _in, _rm, _re, _fname = _res
172-
regression_mean, slope_mean, \
173-
regression_vars, slope_vars, \
174-
means, variances = _rm
175-
plt.plot(measurements[:len(regression_vars)], regression_vars,
176-
label=f"{_example.upper()} slope: {slope_vars:1.4f}",
177-
lw=linewidth)
178-
plt.scatter(measurements[:len(variances)], variances,
179-
marker='x', lw=20)
180-
plt.xscale('log')
181-
plt.yscale('log')
201+
regression_mean, slope_mean, regression_vars, slope_vars, means, variances = _rm
202+
plt.plot(
203+
measurements[: len(regression_vars)],
204+
regression_vars,
205+
label=f"{_example.upper()} slope: {slope_vars:1.4f}",
206+
lw=linewidth,
207+
)
208+
plt.scatter(measurements[: len(variances)], variances, marker="x", lw=20)
209+
plt.xscale("log")
210+
plt.yscale("log")
182211
plt.Axes.set_aspect(plt.gca(), 1)
183-
# if not len(np.unique(variances)) == 1:
184-
# plt.ylim(0.9 * min(variances), 1.3 * max(variances))
212+
# if not len(np.unique(variances)) == 1:
213+
# plt.ylim(0.9 * min(variances), 1.3 * max(variances))
185214
# plt.ylim(5E-6, 5E-4)
186215
plt.xlabel(xlabel, fontsize=fsize)
187216
if legend:
188217
plt.legend(fontsize=fsize * 0.8)
189218
# plt.ylabel('Absolute Error in MUD', fontsize=fsize)
190-
plt.title("$\\mathrm{Var}(|\\lambda^* - \\lambda^\\dagger|)$",
191-
fontsize=1.25 * fsize)
219+
plt.title(
220+
"$\\mathrm{Var}(|\\lambda^* - \\lambda^\\dagger|)$", fontsize=1.25 * fsize
221+
)
192222
if save:
193223
_logger.info("Saving measurement experiments: variance convergence.")
194-
plt.savefig(f'figures/{_fname}/{prefix}_convergence_obs_var.png', bbox_inches='tight')
224+
plt.savefig(
225+
f"figures/{_fname}/{prefix}_convergence_obs_var.png", bbox_inches="tight"
226+
)
195227
else:
196228
plt.show()

0 commit comments

Comments
 (0)