|
21 | 21 |
|
22 | 22 | import astropy.units as u |
23 | 23 | from astropy.modeling import fitting |
| 24 | +from astropy.modeling.functional_models import Gaussian1D, Linear1D |
| 25 | +from astropy.visualization import quantity_support |
24 | 26 |
|
25 | 27 | from sunkit_spex.data.simulated_data import simulate_square_response_matrix |
26 | 28 | from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func |
|
51 | 53 | # use a straight line model for a continuum, Gaussian for a line |
52 | 54 | ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line) |
53 | 55 |
|
54 | | -plt.figure() |
55 | | -plt.plot(ph_energies, ph_model(ph_energies)) |
56 | | -plt.title("Simulated Photon Spectrum") |
57 | | -plt.show() |
| 56 | +with quantity_support(): |
| 57 | + plt.figure() |
| 58 | + plt.plot(ph_energies, ph_model(ph_energies)) |
| 59 | + plt.xlabel(f"Energy [{ph_energies.unit}]") |
| 60 | + plt.title("Simulated Photon Spectrum") |
| 61 | + plt.show() |
58 | 62 |
|
59 | 63 | ##################################################### |
60 | 64 | # |
|
65 | 69 | matrix=srm * u.ct / u.ph, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies) |
66 | 70 | ) |
67 | 71 |
|
68 | | -plt.figure() |
69 | | -plt.imshow( |
70 | | - srm, |
71 | | - origin="lower", |
72 | | - extent=(ph_energies[0].value, ph_energies[-1].value, ph_energies[0].value, ph_energies[-1].value), |
73 | | - norm=LogNorm(), |
74 | | -) |
75 | | -plt.ylabel("Photon Energies [keV]") |
76 | | -plt.xlabel("Count Energies [keV]") |
77 | | -plt.title("Simulated SRM") |
78 | | -plt.show() |
| 72 | +with quantity_support(): |
| 73 | + plt.figure() |
| 74 | + plt.imshow( |
| 75 | + srm_model.matrix.value, |
| 76 | + origin="lower", |
| 77 | + extent=( |
| 78 | + srm_model.inputs_axis[0].value, |
| 79 | + srm_model.inputs_axis[-1].value, |
| 80 | + srm_model.output_axis[0].value, |
| 81 | + srm_model.output_axis[-1].value, |
| 82 | + ), |
| 83 | + norm=LogNorm(), |
| 84 | + ) |
| 85 | + plt.ylabel(f"Photon Energies [{srm_model.inputs_axis.unit}]") |
| 86 | + plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]") |
| 87 | + plt.title("Simulated SRM") |
| 88 | + plt.show() |
79 | 89 |
|
80 | 90 | ##################################################### |
81 | 91 | # |
|
99 | 109 | sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct |
100 | 110 | ) |
101 | 111 |
|
102 | | -obs_spec = Spectrum(sim_count_model_wn, spectral_axis=ph_energies) |
| 112 | +obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies) |
103 | 113 |
|
104 | 114 | ##################################################### |
105 | 115 | # |
106 | 116 | # Can plot all the different components in the simulated count spectrum |
107 | 117 |
|
108 | | -plt.figure() |
109 | | -plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") |
110 | | -plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") |
111 | | -plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") |
112 | | -plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5) |
113 | | -plt.xlabel("Energy [keV]") |
114 | | -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") |
115 | | -plt.title("Simulated Count Spectrum") |
116 | | -plt.legend() |
| 118 | +with quantity_support(): |
| 119 | + plt.figure() |
| 120 | + plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") |
| 121 | + plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") |
| 122 | + plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") |
| 123 | + plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5) |
| 124 | + plt.xlabel(f"Energy [{ph_energies.unit}]") |
| 125 | + plt.title("Simulated Count Spectrum") |
| 126 | + plt.legend() |
117 | 127 |
|
118 | | -plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") |
119 | | -plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") |
120 | | -plt.show() |
| 128 | + plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") |
| 129 | + plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") |
| 130 | + plt.show() |
121 | 131 |
|
122 | 132 | ##################################################### |
123 | 133 | # |
|
142 | 152 |
|
143 | 153 | opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared)) |
144 | 154 |
|
145 | | -plt.figure() |
146 | | -plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") |
147 | | -plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit") |
148 | | -plt.xlabel("Energy [keV]") |
149 | | -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") |
150 | | -plt.title("Simulated Count Spectrum Fit with Scipy") |
151 | | -plt.legend() |
152 | | -plt.show() |
| 155 | +with quantity_support(): |
| 156 | + plt.figure() |
| 157 | + plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") |
| 158 | + plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit") |
| 159 | + plt.xlabel(f"Energy [{ph_energies.unit}]") |
| 160 | + plt.title("Simulated Count Spectrum Fit with Scipy") |
| 161 | + plt.legend() |
| 162 | + plt.show() |
153 | 163 |
|
154 | 164 |
|
155 | 165 | ##################################################### |
|
158 | 168 | # |
159 | 169 | # Try and ensure we start fresh with new model definitions |
160 | 170 |
|
161 | | -ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) |
162 | | -count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) |
| 171 | +ph_mod_4astropyfit = Linear1D(**guess_cont) + Gaussian1D(**guess_line) |
| 172 | +count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + Gaussian1D(**guess_gauss) |
163 | 173 |
|
164 | 174 | astropy_fit = fitting.LevMarLSQFitter() |
165 | 175 |
|
166 | | -astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn) |
| 176 | +astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit) |
167 | 177 |
|
168 | 178 | plt.figure() |
169 | 179 | plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") |
|
178 | 188 | # |
179 | 189 | # Display a table of the fitted results |
180 | 190 |
|
181 | | -plt.figure(layout="constrained") |
182 | | - |
183 | | -row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss)) |
184 | | -column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") |
185 | | -true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values())) |
186 | | -guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values())) |
187 | | -scipy_vals = opt_res.x |
188 | | -astropy_vals = astropy_fitted_result.parameters |
189 | | -cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T |
190 | | -cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) |
191 | | - |
192 | | -plt.axis("off") |
193 | | -plt.table( |
194 | | - cellText=cell_text, |
195 | | - cellColours=None, |
196 | | - cellLoc="center", |
197 | | - rowLabels=row_labels, |
198 | | - rowColours=None, |
199 | | - colLabels=column_labels, |
200 | | - colColours=None, |
201 | | - colLoc="center", |
202 | | - bbox=[0, 0, 1, 1], |
203 | | -) |
204 | | - |
205 | | -plt.show() |
| 191 | +# plt.figure(layout="constrained") |
| 192 | +# |
| 193 | +# row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss)) |
| 194 | +# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") |
| 195 | +# true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values())) |
| 196 | +# guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values())) |
| 197 | +# scipy_vals = opt_res.x |
| 198 | +# astropy_vals = astropy_fitted_result.parameters |
| 199 | +# cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T |
| 200 | +# cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) |
| 201 | +# |
| 202 | +# plt.axis("off") |
| 203 | +# plt.table( |
| 204 | +# cellText=cell_text, |
| 205 | +# cellColours=None, |
| 206 | +# cellLoc="center", |
| 207 | +# rowLabels=row_labels, |
| 208 | +# rowColours=None, |
| 209 | +# colLabels=column_labels, |
| 210 | +# colColours=None, |
| 211 | +# colLoc="center", |
| 212 | +# bbox=[0, 0, 1, 1], |
| 213 | +# ) |
| 214 | +# |
| 215 | +# plt.show() |
0 commit comments