Skip to content

Commit 5513cd4

Browse files
committed
add possibility to choose between normalized and unnormalized errors
1 parent 80b83c2 commit 5513cd4

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

petab/v1/visualize/plot_residuals.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def plot_goodness_of_fit(
136136
size: tuple = (10, 7),
137137
color=None,
138138
ax: plt.Axes | None = None,
139+
normalized_error: bool = True,
139140
) -> matplotlib.axes.Axes:
140141
"""
141142
Plot goodness of fit.
@@ -154,6 +155,9 @@ def plot_goodness_of_fit(
154155
`matplotlib.pyplot.scatter`.
155156
ax:
156157
Axis object.
158+
normalized_error:
159+
Type of error to display. If True, mean of squared normalized residuals is shown,
160+
otherwise mean of squared residuals.
157161
158162
Returns
159163
-------
@@ -168,15 +172,25 @@ def plot_goodness_of_fit(
168172
"are needed for goodness_of_fit"
169173
)
170174

171-
residual_df = calculate_residuals(
172-
measurement_dfs=petab_problem.measurement_df,
173-
simulation_dfs=simulations_df,
174-
observable_dfs=petab_problem.observable_df,
175-
parameter_dfs=petab_problem.parameter_df,
176-
normalize=True
177-
)[0]
178-
# compute mean of squared normalized residuals
179-
msnr = np.mean(np.power(residual_df["residual"], 2))
175+
if normalized_error:
176+
residual_df = calculate_residuals(
177+
measurement_dfs=petab_problem.measurement_df,
178+
simulation_dfs=simulations_df,
179+
observable_dfs=petab_problem.observable_df,
180+
parameter_dfs=petab_problem.parameter_df,
181+
normalize=True,
182+
)[0]
183+
error_name = "mean of squared\nnormalized residuals"
184+
else:
185+
residual_df = calculate_residuals(
186+
measurement_dfs=petab_problem.measurement_df,
187+
simulation_dfs=simulations_df,
188+
observable_dfs=petab_problem.observable_df,
189+
parameter_dfs=petab_problem.parameter_df,
190+
normalize=False,
191+
)[0]
192+
error_name = "mean of squared residuals"
193+
error = np.mean(np.power(residual_df["residual"], 2))
180194

181195
slope, intercept, r_value, p_value, std_err = stats.linregress(
182196
simulations_df["simulation"],
@@ -210,7 +224,7 @@ def plot_goodness_of_fit(
210224
f"slope: {slope:.2f}\n"
211225
f"intercept: {intercept:.2f}\n"
212226
f"p-value: {p_value:.2e}\n"
213-
f"mean of squared\nnormalized residuals: {msnr:.2e}\n",
227+
f"{error_name}: {error:.2e}\n",
214228
transform=ax.transAxes,
215229
)
216230

0 commit comments

Comments
 (0)