|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +import logging |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from matplotlib import pyplot as plt |
| 8 | + |
| 9 | +from mud_examples.utils import check_dir |
| 10 | + |
| 11 | +plt.rcParams['figure.figsize'] = 10, 10 |
| 12 | +plt.rcParams['font.size'] = 16 |
| 13 | + |
| 14 | + |
| 15 | +__author__ = "Mathematical Michael" |
| 16 | +__copyright__ = "Mathematical Michael" |
| 17 | +__license__ = "mit" |
| 18 | + |
| 19 | +_logger = logging.getLogger(__name__) |
| 20 | +_mpl_logger = logging.getLogger('matplotlib') |
| 21 | +_mpl_logger.setLevel(logging.WARNING) |
| 22 | + |
| 23 | + |
| 24 | +def experiment_measurements(fun, num_measurements, |
| 25 | + sd, num_trials, seed=21): |
| 26 | + """ |
| 27 | + Fixed sensors, varying how much data is incorporated into the solution. |
| 28 | + """ |
| 29 | + experiments = {} |
| 30 | + solutions = {} |
| 31 | + for ns in num_measurements: |
| 32 | + _logger.debug(f'Measurement experiment. Num measurements: {ns}') |
| 33 | + discretizations = [] |
| 34 | + estimates = [] |
| 35 | + for t in range(num_trials): |
| 36 | + np.random.seed(seed + t) |
| 37 | + _d = fun(sd=sd, num_obs=ns) |
| 38 | + estimate = _d.estimate() |
| 39 | + discretizations.append(_d) |
| 40 | + estimates.append(estimate) |
| 41 | + experiments[ns] = discretizations |
| 42 | + solutions[ns] = estimates |
| 43 | + |
| 44 | + return experiments, solutions |
| 45 | + |
| 46 | + |
| 47 | +def experiment_equipment(fun, num_measure, |
| 48 | + sd_vals, num_trials, |
| 49 | + reference_value): |
| 50 | + """ |
| 51 | + Fixed number of sensors, varying the quality of equipment. |
| 52 | + """ |
| 53 | + sd_err = [] |
| 54 | + sd_var = [] |
| 55 | + for sd in sd_vals: |
| 56 | + _logger.debug(f'Equipment Experiment. Std Dev: {sd}') |
| 57 | + temp_err = [] |
| 58 | + for t in range(num_trials): |
| 59 | + _d = fun(sd=sd, num_obs=num_measure) |
| 60 | + estimate = _d.estimate() |
| 61 | + temp_err.append(np.linalg.norm(estimate - reference_value)) |
| 62 | + sd_err.append(np.mean(temp_err)) |
| 63 | + sd_var.append(np.var(temp_err)) |
| 64 | + |
| 65 | + return sd_err, sd_var |
| 66 | + |
| 67 | + |
| 68 | +def plot_experiment_equipment(tolerances, res, prefix, fsize=32, linewidth=5, |
| 69 | + title="Variance of MUD Error", save=True): |
| 70 | + print("Plotting experiments involving equipment differences...") |
| 71 | + plt.figure(figsize=(10, 10)) |
| 72 | + for _res in res: |
| 73 | + _prefix, _in, _rm, _re = _res |
| 74 | + regression_err_mean, slope_err_mean, \ |
| 75 | + regression_err_vars, slope_err_vars, \ |
| 76 | + sd_means, sd_vars, num_sensors = _re |
| 77 | + plt.plot(tolerances, regression_err_mean, |
| 78 | + label=f"{_prefix:10s} slope: {slope_err_mean:1.4f}", |
| 79 | + lw=linewidth) |
| 80 | + plt.scatter(tolerances, sd_means, marker='x', lw=20) |
| 81 | + |
| 82 | + plt.yscale('log') |
| 83 | + plt.xscale('log') |
| 84 | + plt.Axes.set_aspect(plt.gca(), 1) |
| 85 | + plt.ylim(2E-3, 2E-2) |
| 86 | + # plt.ylabel("Absolute Error", fontsize=fsize) |
| 87 | + plt.xlabel('Tolerance', fontsize=fsize) |
| 88 | + plt.legend() |
| 89 | + plt.title(f"Mean of MUD Error for N={num_sensors}", fontsize=1.25 * fsize) |
| 90 | + if save: |
| 91 | + fdir = ''.join(prefix.split('/')[::-1]) |
| 92 | + check_dir(fdir) |
| 93 | + _logger.info("Saving equipment experiments: mean convergence.") |
| 94 | + plt.savefig(f'{prefix}_convergence_mud_std_mean.png', |
| 95 | + bbox_inches='tight') |
| 96 | + else: |
| 97 | + plt.show() |
| 98 | + |
| 99 | + plt.figure(figsize=(10, 10)) |
| 100 | + for _res in res: |
| 101 | + _prefix, _in, _rm, _re = _res |
| 102 | + regression_err_mean, slope_err_mean, \ |
| 103 | + regression_err_vars, slope_err_vars, \ |
| 104 | + sd_means, sd_vars, num_sensors = _re |
| 105 | + plt.plot(tolerances, regression_err_vars, |
| 106 | + label=f"{_prefix:10s} slope: {slope_err_vars:1.4f}", |
| 107 | + lw=linewidth) |
| 108 | + plt.scatter(tolerances, sd_vars, marker='x', lw=20) |
| 109 | + plt.xscale('log') |
| 110 | + plt.yscale('log') |
| 111 | + plt.ylim(2E-5, 2E-4) |
| 112 | + plt.Axes.set_aspect(plt.gca(), 1) |
| 113 | + # plt.ylabel("Absolute Error", fontsize=fsize) |
| 114 | + plt.xlabel('Tolerance', fontsize=fsize) |
| 115 | + plt.legend() |
| 116 | + plt.title(title, fontsize=1.25 * fsize) |
| 117 | + if save: |
| 118 | + _logger.info("Saving equipment experiments: variance convergence.") |
| 119 | + plt.savefig(f'{prefix}_convergence_mud_std_var.png', |
| 120 | + bbox_inches='tight') |
| 121 | + else: |
| 122 | + plt.show() |
| 123 | + |
| 124 | + |
| 125 | +def plot_experiment_measurements(res, prefix, |
| 126 | + fsize=32, linewidth=5, |
| 127 | + xlabel='Number of Measurements', |
| 128 | + save=True, legend=False): |
| 129 | + print("Plotting experiments involving increasing # of measurements.") |
| 130 | + plt.figure(figsize=(10, 10)) |
| 131 | + for _res in res: |
| 132 | + _prefix, _in, _rm, _re = _res |
| 133 | + solutions = _in[-1] |
| 134 | + measurements = list(solutions.keys()) |
| 135 | + regression_mean, slope_mean, \ |
| 136 | + regression_vars, slope_vars, \ |
| 137 | + means, variances = _rm |
| 138 | + plt.plot(measurements[:len(regression_mean)], regression_mean, |
| 139 | + label=f"{_prefix:4s} slope: {slope_mean:1.4f}", |
| 140 | + lw=linewidth) |
| 141 | + plt.scatter(measurements[:len(means)], means, marker='x', lw=20) |
| 142 | + plt.xscale('log') |
| 143 | + plt.yscale('log') |
| 144 | + plt.Axes.set_aspect(plt.gca(), 1) |
| 145 | + plt.ylim(0.9 * min(means), 1.3 * max(means)) |
| 146 | + plt.ylim(2E-3, 2E-1) |
| 147 | + plt.xlabel(xlabel, fontsize=fsize) |
| 148 | + if legend: |
| 149 | + plt.legend(fontsize=fsize * 0.8) |
| 150 | + # plt.ylabel('Absolute Error in MUD', fontsize=fsize) |
| 151 | + title = "$\\mathrm{\\mathbb{E}}(|\\lambda^* - \\lambda^\\dagger|)$" # noqa E501 |
| 152 | + plt.title(title, fontsize=1.25 * fsize) |
| 153 | + if save: |
| 154 | + fdir = '/'.join(prefix.split('/')[:-1]) |
| 155 | + check_dir(fdir) |
| 156 | + _logger.info("Saving measurement experiments: mean convergence.") |
| 157 | + plt.savefig(f'{prefix}_convergence_obs_mean.png', bbox_inches='tight') |
| 158 | + else: |
| 159 | + plt.show() |
| 160 | + |
| 161 | + plt.figure(figsize=(10, 10)) |
| 162 | + for _res in res: |
| 163 | + _prefix, _in, _rm, _re = _res |
| 164 | + regression_mean, slope_mean, \ |
| 165 | + regression_vars, slope_vars, \ |
| 166 | + means, variances = _rm |
| 167 | + plt.plot(measurements[:len(regression_vars)], regression_vars, |
| 168 | + label=f"{_prefix:4s} slope: {slope_vars:1.4f}", |
| 169 | + lw=linewidth) |
| 170 | + plt.scatter(measurements[:len(variances)], variances, |
| 171 | + marker='x', lw=20) |
| 172 | + plt.xscale('log') |
| 173 | + plt.yscale('log') |
| 174 | + plt.Axes.set_aspect(plt.gca(), 1) |
| 175 | +# if not len(np.unique(variances)) == 1: |
| 176 | +# plt.ylim(0.9 * min(variances), 1.3 * max(variances)) |
| 177 | + plt.ylim(5E-6, 5E-4) |
| 178 | + plt.xlabel(xlabel, fontsize=fsize) |
| 179 | + if legend: |
| 180 | + plt.legend(fontsize=fsize * 0.8) |
| 181 | + # plt.ylabel('Absolute Error in MUD', fontsize=fsize) |
| 182 | + plt.title("$\\mathrm{Var}(|\\lambda^* - \\lambda^\\dagger|)$", |
| 183 | + fontsize=1.25 * fsize) |
| 184 | + if save: |
| 185 | + _logger.info("Saving measurement experiments: variance convergence.") |
| 186 | + plt.savefig(f'{prefix}_convergence_obs_var.png', bbox_inches='tight') |
| 187 | + else: |
| 188 | + plt.show() |
0 commit comments