Skip to content

Commit 30e3a84

Browse files
committed
Scipy fitting works can't get astropy fitting to work
1 parent a722920 commit 30e3a84

2 files changed

Lines changed: 90 additions & 66 deletions

File tree

examples/fitting_simulated_data.py

Lines changed: 74 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import astropy.units as u
2323
from astropy.modeling import fitting
24+
from astropy.modeling.functional_models import Gaussian1D, Linear1D
25+
from astropy.visualization import quantity_support
2426

2527
from sunkit_spex.data.simulated_data import simulate_square_response_matrix
2628
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
@@ -51,10 +53,12 @@
5153
# use a straight line model for a continuum, Gaussian for a line
5254
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)
5355

54-
plt.figure()
55-
plt.plot(ph_energies, ph_model(ph_energies))
56-
plt.title("Simulated Photon Spectrum")
57-
plt.show()
56+
with quantity_support():
57+
plt.figure()
58+
plt.plot(ph_energies, ph_model(ph_energies))
59+
plt.xlabel(f"Energy [{ph_energies.unit}]")
60+
plt.title("Simulated Photon Spectrum")
61+
plt.show()
5862

5963
#####################################################
6064
#
@@ -65,17 +69,23 @@
6569
matrix=srm * u.ct / u.ph, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies)
6670
)
6771

68-
plt.figure()
69-
plt.imshow(
70-
srm,
71-
origin="lower",
72-
extent=(ph_energies[0].value, ph_energies[-1].value, ph_energies[0].value, ph_energies[-1].value),
73-
norm=LogNorm(),
74-
)
75-
plt.ylabel("Photon Energies [keV]")
76-
plt.xlabel("Count Energies [keV]")
77-
plt.title("Simulated SRM")
78-
plt.show()
72+
with quantity_support():
73+
plt.figure()
74+
plt.imshow(
75+
srm_model.matrix.value,
76+
origin="lower",
77+
extent=(
78+
srm_model.inputs_axis[0].value,
79+
srm_model.inputs_axis[-1].value,
80+
srm_model.output_axis[0].value,
81+
srm_model.output_axis[-1].value,
82+
),
83+
norm=LogNorm(),
84+
)
85+
plt.ylabel(f"Photon Energies [{srm_model.inputs_axis.unit}]")
86+
plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]")
87+
plt.title("Simulated SRM")
88+
plt.show()
7989

8090
#####################################################
8191
#
@@ -99,25 +109,25 @@
99109
sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct
100110
)
101111

102-
obs_spec = Spectrum(sim_count_model_wn, spectral_axis=ph_energies)
112+
obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies)
103113

104114
#####################################################
105115
#
106116
# Can plot all the different components in the simulated count spectrum
107117

108-
plt.figure()
109-
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
110-
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
111-
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
112-
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
113-
plt.xlabel("Energy [keV]")
114-
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
115-
plt.title("Simulated Count Spectrum")
116-
plt.legend()
118+
with quantity_support():
119+
plt.figure()
120+
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
121+
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
122+
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
123+
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
124+
plt.xlabel(f"Energy [{ph_energies.unit}]")
125+
plt.title("Simulated Count Spectrum")
126+
plt.legend()
117127

118-
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
119-
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
120-
plt.show()
128+
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
129+
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
130+
plt.show()
121131

122132
#####################################################
123133
#
@@ -142,14 +152,14 @@
142152

143153
opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared))
144154

145-
plt.figure()
146-
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
147-
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
148-
plt.xlabel("Energy [keV]")
149-
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
150-
plt.title("Simulated Count Spectrum Fit with Scipy")
151-
plt.legend()
152-
plt.show()
155+
with quantity_support():
156+
plt.figure()
157+
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
158+
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
159+
plt.xlabel(f"Energy [{ph_energies.unit}]")
160+
plt.title("Simulated Count Spectrum Fit with Scipy")
161+
plt.legend()
162+
plt.show()
153163

154164

155165
#####################################################
@@ -158,12 +168,12 @@
158168
#
159169
# Try and ensure we start fresh with new model definitions
160170

161-
ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
162-
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)
171+
ph_mod_4astropyfit = Linear1D(**guess_cont) + Gaussian1D(**guess_line)
172+
count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + Gaussian1D(**guess_gauss)
163173

164174
astropy_fit = fitting.LevMarLSQFitter()
165175

166-
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn)
176+
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit)
167177

168178
plt.figure()
169179
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
@@ -178,28 +188,28 @@
178188
#
179189
# Display a table of the fitted results
180190

181-
plt.figure(layout="constrained")
182-
183-
row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
184-
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
185-
true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
186-
guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
187-
scipy_vals = opt_res.x
188-
astropy_vals = astropy_fitted_result.parameters
189-
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
190-
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)
191-
192-
plt.axis("off")
193-
plt.table(
194-
cellText=cell_text,
195-
cellColours=None,
196-
cellLoc="center",
197-
rowLabels=row_labels,
198-
rowColours=None,
199-
colLabels=column_labels,
200-
colColours=None,
201-
colLoc="center",
202-
bbox=[0, 0, 1, 1],
203-
)
204-
205-
plt.show()
191+
# plt.figure(layout="constrained")
192+
#
193+
# row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
194+
# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
195+
# true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
196+
# guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
197+
# scipy_vals = opt_res.x
198+
# astropy_vals = astropy_fitted_result.parameters
199+
# cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
200+
# cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)
201+
#
202+
# plt.axis("off")
203+
# plt.table(
204+
# cellText=cell_text,
205+
# cellColours=None,
206+
# cellLoc="center",
207+
# rowLabels=row_labels,
208+
# rowColours=None,
209+
# colLabels=column_labels,
210+
# colColours=None,
211+
# colLoc="center",
212+
# bbox=[0, 0, 1, 1],
213+
# )
214+
#
215+
# plt.show()
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,31 @@
11
"""Module for model components required for instrument response models."""
22

3+
import astropy.units as u
34
from astropy.modeling import Fittable1DModel, Parameter
45

56
__all__ = ["MatrixModel"]
67

78

89
class MatrixModel(Fittable1DModel):
10+
# matrix = Parameter(description="The matrix with which to multiply the input.", fixed=True)
11+
912
def __init__(self, matrix, input_axis, output_axis):
1013
self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True)
1114
self.inputs_axis = input_axis
1215
self.output_axis = output_axis
1316
super().__init__()
1417

15-
def evaluate(self, model_y):
18+
def evaluate(self, x):
1619
# Requires input must have a specific dimensionality
17-
return model_y @ self.matrix
20+
return x @ self.matrix
21+
22+
@property
23+
def input_units(self):
24+
return {"x": u.ph}
25+
26+
@property
27+
def output_units(self):
28+
return {"y": u.ct}
29+
30+
def _parameter_units_for_data_units(self, inputs_unit, outputs_unit):
31+
return {"x": u.ph, "y": u.ct}

0 commit comments

Comments
 (0)