Skip to content

Commit 35bd114

Browse files
committed
add set_plot_defaults function
1 parent b147e68 commit 35bd114

3 files changed

Lines changed: 247 additions & 136 deletions

File tree

src/diffpy/srfit/fitbase/fitrecipe.py

Lines changed: 130 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,32 @@ def __init__(self, name="fit"):
155155
self._contributions = OrderedDict()
156156
self._manage(self._contributions)
157157

158+
self.plot_options = {
159+
"show_observed": True,
160+
"show_fit": True,
161+
"show_diff": True,
162+
"offset_scale": 1.0,
163+
"figsize": (8, 6),
164+
"data_style": "o",
165+
"fit_style": "-",
166+
"diff_style": "-",
167+
"data_color": None,
168+
"fit_color": None,
169+
"diff_color": None,
170+
"data_label": "Observed",
171+
"fit_label": "Calculated",
172+
"diff_label": "Difference",
173+
"xlabel": None,
174+
"ylabel": None,
175+
"title": None,
176+
"legend": True,
177+
"legend_loc": "best",
178+
"grid": False,
179+
"markersize": None,
180+
"linewidth": None,
181+
"alpha": 1.0,
182+
"show": True,
183+
}
158184
return
159185

160186
def pushFitHook(self, fithook, index=None):
@@ -875,35 +901,30 @@ def getBounds2(self):
875901
ub = array([b[1] for b in bounds])
876902
return lb, ub
877903

878-
def plot_recipe(
879-
self,
880-
show_observed=True,
881-
show_fit=True,
882-
show_diff=True,
883-
offset_scale=1.0,
884-
figsize=(8, 6),
885-
data_style="o",
886-
fit_style="-",
887-
diff_style="-",
888-
data_color=None,
889-
fit_color=None,
890-
diff_color=None,
891-
data_label="Observed",
892-
fit_label="Calculated",
893-
diff_label="Difference",
894-
xlabel=None,
895-
ylabel=None,
896-
title=None,
897-
legend=True,
898-
legend_loc="best",
899-
grid=False,
900-
markersize=None,
901-
linewidth=None,
902-
alpha=1.0,
903-
show=True,
904-
ax=None,
905-
return_fig=False,
906-
):
904+
def set_plot_defaults(self, **kwargs):
905+
"""Set default plotting options for all future plots.
906+
907+
Any keyword argument accepted by plot_recipe() can be set here.
908+
909+
Examples
910+
--------
911+
>>> recipe.set_plot_defaults(
912+
... xlabel='r (Å)',
913+
... ylabel='G(r) (Å⁻²)',
914+
... data_color='black',
915+
... fit_color='red'
916+
... )
917+
"""
918+
919+
for key in kwargs:
920+
if key not in self.plot_options:
921+
print(
922+
f"Warning: '{key}' is not a valid "
923+
"plot_recipe option and will be ignored."
924+
)
925+
self.plot_options.update(kwargs)
926+
927+
def plot_recipe(self, ax=None, return_fig=False, **kwargs):
907928
"""The fit recipe data, calculated fit, and difference curve are
908929
plotted.
909930
@@ -912,6 +933,17 @@ def plot_recipe(
912933
913934
Parameters
914935
----------
936+
ax : matplotlib.axes.Axes or None, optional
937+
The axes object to plot on. If None, creates a new figure.
938+
Default is None.
939+
return_fig : bool, optional
940+
The figure and axes objects are returned if True. Default is False.
941+
**kwargs : dict
942+
Any plotting option can be passed to override the defaults in
943+
recipe.plot_options.
944+
945+
Keyword Arguments
946+
-----------------
915947
show_observed : bool, optional
916948
The observed data is plotted if True. Default is True.
917949
show_fit : bool, optional
@@ -980,25 +1012,51 @@ def plot_recipe(
9801012
9811013
Examples
9821014
--------
983-
The plot is created with default settings:
1015+
Plot with default settings:
9841016
9851017
>>> recipe.plot_recipe()
9861018
987-
The data and fit are plotted (no difference curve):
1019+
Override defaults for one plot:
1020+
1021+
>>> recipe.plot_recipe(show_diff=False, title='My Custom Title')
1022+
1023+
Set defaults once, use everywhere:
1024+
1025+
>>> recipe.set_plot_defaults(xlabel='r (Å)', ylabel='G(r)')
1026+
>>> recipe.plot_recipe() # Uses xlabel and ylabel
1027+
>>> recipe.plot_recipe() # Still uses them
1028+
1029+
Override a default for one plot:
9881030
989-
>>> recipe.plot_recipe(show_diff=False)
1031+
>>> recipe.set_plot_defaults(figsize=(10, 7))
1032+
>>> recipe.plot_recipe() # Uses (10, 7)
1033+
>>> recipe.plot_recipe(figsize=(12, 8)) # Temporarily uses (12, 8)
1034+
>>> recipe.plot_recipe() # Back to (10, 7)
9901035
991-
The data is plotted to check before refinement:
1036+
Notes
1037+
-----
1038+
Default values are taken from recipe.plot_options. You can modify
1039+
these defaults in three ways:
9921040
993-
>>> recipe.plot_recipe(show_fit=False, show_diff=False)
1041+
1. Using set_plot_defaults():
1042+
recipe.set_plot_defaults(xlabel='r (Å)')
9941043
995-
The figure object is retrieved for further customization:
1044+
2. Direct attribute access:
1045+
recipe.plot_options['xlabel'] = 'r (Å)'
9961046
997-
>>> fig, axes = recipe.plot_recipe(show=False, return_fig=True)
998-
>>> axes[0].set_yscale('log')
999-
>>> plt.savefig('my_fit.png', dpi=300)
1047+
3. Using update():
1048+
recipe.plot_options.update({'xlabel': 'r (Å)', 'ylabel': 'G(r)'})
10001049
"""
1001-
if not any([show_observed, show_fit, show_diff]):
1050+
plot_params = self.plot_options.copy()
1051+
plot_params.update(kwargs)
1052+
1053+
if not any(
1054+
[
1055+
plot_params["show_observed"],
1056+
plot_params["show_fit"],
1057+
plot_params["show_diff"],
1058+
]
1059+
):
10021060
raise ValueError(
10031061
"At least one of show_observed, show_fit, "
10041062
"or show_diff must be True"
@@ -1009,85 +1067,83 @@ def plot_recipe(
10091067
"No contributions found in recipe. "
10101068
"Add contributions before plotting."
10111069
)
1012-
10131070
figures = []
10141071
axes_list = []
1015-
10161072
for name, contrib in self._contributions.items():
10171073
profile = contrib.profile
10181074
x = profile.x
10191075
yobs = profile.y
10201076
ycalc = profile.ycalc
10211077
if ycalc is None:
1022-
if show_fit or show_diff:
1078+
if plot_params["show_fit"] or plot_params["show_diff"]:
10231079
print(
10241080
f"Contribution '{name}' has no calculated values "
10251081
"(ycalc is None). "
10261082
"Only observed data will be plotted."
10271083
)
1028-
show_fit = False
1029-
show_diff = False
1084+
plot_params["show_fit"] = False
1085+
plot_params["show_diff"] = False
10301086
else:
10311087
diff = yobs - ycalc
10321088
y_min = min(yobs.min(), ycalc.min())
10331089
y_max = max(yobs.max(), ycalc.max())
10341090
y_range = y_max - y_min
10351091
base_offset = y_min - 0.1 * y_range
1036-
offset = base_offset * offset_scale
1037-
1092+
offset = base_offset * plot_params["offset_scale"]
10381093
if ax is None:
1039-
fig = plt.figure(figsize=figsize)
1094+
fig = plt.figure(figsize=plot_params["figsize"])
10401095
current_ax = fig.add_subplot(111)
10411096
else:
10421097
current_ax = ax
10431098
fig = current_ax.figure
1044-
if show_observed:
1099+
if plot_params["show_observed"]:
10451100
current_ax.plot(
10461101
x,
10471102
yobs,
1048-
data_style,
1049-
label=data_label,
1050-
color=data_color,
1051-
markersize=markersize,
1052-
alpha=alpha,
1103+
plot_params["data_style"],
1104+
label=plot_params["data_label"],
1105+
color=plot_params["data_color"],
1106+
markersize=plot_params["markersize"],
1107+
alpha=plot_params["alpha"],
10531108
)
1054-
if show_fit:
1109+
if plot_params["show_fit"]:
10551110
current_ax.plot(
10561111
x,
10571112
ycalc,
1058-
fit_style,
1059-
label=fit_label,
1060-
color=fit_color,
1061-
linewidth=linewidth,
1062-
alpha=alpha,
1113+
plot_params["fit_style"],
1114+
label=plot_params["fit_label"],
1115+
color=plot_params["fit_color"],
1116+
linewidth=plot_params["linewidth"],
1117+
alpha=plot_params["alpha"],
10631118
)
1064-
if show_diff:
1119+
if plot_params["show_diff"]:
10651120
current_ax.plot(
10661121
x,
10671122
diff + offset,
1068-
diff_style,
1069-
label=diff_label,
1070-
color=diff_color,
1071-
linewidth=linewidth,
1072-
alpha=alpha,
1123+
plot_params["diff_style"],
1124+
label=plot_params["diff_label"],
1125+
color=plot_params["diff_color"],
1126+
linewidth=plot_params["linewidth"],
1127+
alpha=plot_params["alpha"],
10731128
)
10741129
current_ax.axhline(
10751130
offset,
10761131
color="black",
10771132
)
1078-
current_ax.set_xlabel(xlabel)
1079-
current_ax.set_ylabel(ylabel)
1080-
1081-
if title is not None:
1082-
current_ax.set_title(title)
1083-
if legend:
1084-
current_ax.legend(loc=legend_loc, frameon=True)
1085-
if grid:
1133+
if plot_params["xlabel"] is not None:
1134+
current_ax.set_xlabel(plot_params["xlabel"])
1135+
if plot_params["ylabel"] is not None:
1136+
current_ax.set_ylabel(plot_params["ylabel"])
1137+
if plot_params["title"] is not None:
1138+
current_ax.set_title(plot_params["title"])
1139+
if plot_params["legend"]:
1140+
current_ax.legend(loc=plot_params["legend_loc"], frameon=True)
1141+
if plot_params["grid"]:
10861142
current_ax.grid(True)
10871143
fig.tight_layout()
10881144
figures.append(fig)
10891145
axes_list.append(current_ax)
1090-
if show and ax is None:
1146+
if plot_params["show"] and ax is None:
10911147
plt.show()
10921148
if return_fig:
10931149
if len(figures) == 1:

tests/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
import pytest
77
import six
8+
from numpy import linspace, pi, sin
89

910
import diffpy.srfit.equation.literals as literals
11+
from diffpy.srfit.fitbase import FitContribution, FitRecipe, Profile
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -142,3 +144,51 @@ def _capturestdout(f, *args, **kwargs):
142144
return fp.getvalue()
143145

144146
return _capturestdout
147+
148+
149+
@pytest.fixture(scope="session")
150+
def build_recipe_one_contribution():
151+
"helper to build a simple recipe"
152+
profile = Profile()
153+
x = linspace(0, pi, 10)
154+
y = sin(x)
155+
profile.setObservedProfile(x, y)
156+
contribution = FitContribution("c1")
157+
contribution.setProfile(profile)
158+
contribution.setEquation("A*sin(k*x + c)")
159+
recipe = FitRecipe()
160+
recipe.addContribution(contribution)
161+
recipe.addVar(contribution.A, 1)
162+
recipe.addVar(contribution.k, 1)
163+
recipe.addVar(contribution.c, 1)
164+
return recipe
165+
166+
167+
@pytest.fixture(scope="session")
168+
def build_recipe_two_contributions():
169+
"helper to build a recipe with two contributions"
170+
profile1 = Profile()
171+
x = linspace(0, pi, 10)
172+
y1 = sin(x)
173+
profile1.setObservedProfile(x, y1)
174+
contribution1 = FitContribution("c1")
175+
contribution1.setProfile(profile1)
176+
contribution1.setEquation("A*sin(k*x + c)")
177+
178+
profile2 = Profile()
179+
y2 = 0.5 * sin(2 * x)
180+
profile2.setObservedProfile(x, y2)
181+
contribution2 = FitContribution("c2")
182+
contribution2.setProfile(profile2)
183+
contribution2.setEquation("B*sin(m*x + d)")
184+
recipe = FitRecipe()
185+
recipe.addContribution(contribution1)
186+
recipe.addContribution(contribution2)
187+
recipe.addVar(contribution1.A, 1)
188+
recipe.addVar(contribution1.k, 1)
189+
recipe.addVar(contribution1.c, 1)
190+
recipe.addVar(contribution2.B, 0.5)
191+
recipe.addVar(contribution2.m, 2)
192+
recipe.addVar(contribution2.d, 0)
193+
194+
return recipe

0 commit comments

Comments
 (0)