Skip to content

Commit 076f9ef

Browse files
authored
Goodness of fit fix (#473)
* fix mean of residuals in plot_goodness_of_fit * add possibility to choose between normalized and unnormalized errors
1 parent 7a60b2e commit 076f9ef

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

petab/v1/visualize/plot_residuals.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def plot_goodness_of_fit(
136136
size: tuple = (10, 7),
137137
color=None,
138138
ax: plt.Axes | None = None,
139+
normalized_error: bool = True,
139140
) -> matplotlib.axes.Axes:
140141
"""
141142
Plot goodness of fit.
@@ -154,6 +155,10 @@ def plot_goodness_of_fit(
154155
`matplotlib.pyplot.scatter`.
155156
ax:
156157
Axis object.
158+
normalized_error:
159+
Type of error to display.
160+
If True, mean of squared normalized residuals is shown,
161+
otherwise mean of squared residuals.
157162
158163
Returns
159164
-------
@@ -168,12 +173,26 @@ def plot_goodness_of_fit(
168173
"are needed for goodness_of_fit"
169174
)
170175

171-
residual_df = calculate_residuals(
172-
measurement_dfs=petab_problem.measurement_df,
173-
simulation_dfs=simulations_df,
174-
observable_dfs=petab_problem.observable_df,
175-
parameter_dfs=petab_problem.parameter_df,
176-
)[0]
176+
if normalized_error:
177+
residual_df = calculate_residuals(
178+
measurement_dfs=petab_problem.measurement_df,
179+
simulation_dfs=simulations_df,
180+
observable_dfs=petab_problem.observable_df,
181+
parameter_dfs=petab_problem.parameter_df,
182+
normalize=True,
183+
)[0]
184+
error_name = "mean of squared\nnormalized residuals"
185+
else:
186+
residual_df = calculate_residuals(
187+
measurement_dfs=petab_problem.measurement_df,
188+
simulation_dfs=simulations_df,
189+
observable_dfs=petab_problem.observable_df,
190+
parameter_dfs=petab_problem.parameter_df,
191+
normalize=False,
192+
)[0]
193+
error_name = "mean of squared residuals"
194+
error = np.mean(np.power(residual_df["residual"], 2))
195+
177196
slope, intercept, r_value, p_value, std_err = stats.linregress(
178197
simulations_df["simulation"],
179198
petab_problem.measurement_df["measurement"],
@@ -199,15 +218,14 @@ def plot_goodness_of_fit(
199218
ax.plot(x, x, linestyle="--", color="gray")
200219
ax.plot(x, intercept + slope * x, "r", label="fitted line")
201220

202-
mse = np.mean(np.abs(residual_df["residual"]))
203221
ax.text(
204222
0.1,
205223
0.70,
206224
f"$R^2$: {r_value**2:.2f}\n"
207225
f"slope: {slope:.2f}\n"
208226
f"intercept: {intercept:.2f}\n"
209227
f"p-value: {p_value:.2e}\n"
210-
f"mean squared error: {mse:.2e}\n",
228+
f"{error_name}: {error:.2e}\n",
211229
transform=ax.transAxes,
212230
)
213231

0 commit comments

Comments
 (0)