Skip to content

Commit 438ef86

Browse files
author
sborms
committed
metric_name arg in plot_performance_curves
1 parent eccb344 commit 438ef86

2 files changed

Lines changed: 151 additions & 142 deletions

File tree

cobra/evaluation/plotting_utils.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def plot_performance_curves(model_performance: pd.DataFrame,
8484
path: str=None,
8585
colors: dict={"train": "#0099bf",
8686
"selection": "#ff9500",
87-
"validation": "#8064a2"}):
87+
"validation": "#8064a2"},
88+
metric_name: str=None):
8889
"""Plot performance curves generated by the forward feature selection
8990
for the train-selection-validation sets.
9091
@@ -97,14 +98,21 @@ def plot_performance_curves(model_performance: pd.DataFrame,
9798
Width and length of the plot.
9899
path : str, optional
99100
Path to store the figure.
101+
colors : dict, optional
102+
Map with colors for train-selection-validation curves.
103+
metric_name : str, optional
104+
Name to indicate the metric used in model_performance.
105+
Defaults to RMSE in case of regression and AUC in case of
106+
classification.
100107
"""
101108

102109
model_type = model_performance["model_type"][0]
103110

104-
if model_type == "classification":
105-
metric = "AUC"
106-
elif model_type == "regression":
107-
metric = "RMSE"
111+
if metric_name is None:
112+
if model_type == "classification":
113+
metric_name = "AUC"
114+
elif model_type == "regression":
115+
metric_name = "RMSE"
108116

109117
max_metric = np.round(max(max(model_performance['train_performance']),
110118
max(model_performance['selection_performance']),
@@ -115,13 +123,13 @@ def plot_performance_curves(model_performance: pd.DataFrame,
115123
fig, ax = plt.subplots(figsize=dim)
116124

117125
plt.plot(model_performance['train_performance'], marker=".",
118-
markersize=20, linewidth=3, label=metric+" train",
126+
markersize=20, linewidth=3, label="train",
119127
color=colors["train"])
120128
plt.plot(model_performance['selection_performance'], marker=".",
121-
markersize=20, linewidth=3, label=metric+" selection",
129+
markersize=20, linewidth=3, label="selection",
122130
color=colors["selection"])
123131
plt.plot(model_performance['validation_performance'], marker=".",
124-
markersize=20, linewidth=3, label=metric+" validation",
132+
markersize=20, linewidth=3, label="validation",
125133
color=colors["validation"])
126134

127135
# Set x- and y-ticks
@@ -141,6 +149,7 @@ def plot_performance_curves(model_performance: pd.DataFrame,
141149
ax.legend(loc='lower right')
142150
fig.suptitle('Performance curves forward feature selection',
143151
fontsize=20)
152+
plt.title("Metric: "+metric_name, fontsize=15, loc="left")
144153
plt.ylabel('Model performance')
145154

146155
if path is not None:

0 commit comments

Comments
 (0)