Skip to content

Commit 581f7b4

Browse files
Add option to save figure to plotting_utils
1 parent eeddace commit 581f7b4

2 files changed

Lines changed: 18 additions & 5 deletions

File tree

cobra/evaluation/evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,8 @@ def _compute_optimal_cutoff(fpr: np.ndarray, tpr: np.ndarray,
381381
@staticmethod
382382
def _compute_cumulative_gains(y_true: np.ndarray,
383383
y_pred: np.ndarray) -> tuple:
384-
"""Compute lift of the model per decile, returns x-labels, lifts and
385-
the target incidence to create cummulative response curves
384+
"""Compute cumulative gains of the model, returns percentages and
385+
gains cummulative gains curves
386386
387387
Code from (https://github.com/reiinakano/scikit-plot/blob/
388388
2dd3e6a76df77edcbd724c4db25575f70abb57cb/
@@ -398,7 +398,7 @@ def _compute_cumulative_gains(y_true: np.ndarray,
398398
Returns
399399
-------
400400
tuple
401-
x-labels, lifts per decile and target incidence
401+
x-labels, gains
402402
"""
403403

404404
# make y_true a boolean vector

cobra/evaluation/plotting_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def plot_univariate_predictor_quality(df_auc: pd.DataFrame,
4545

4646

4747
def plot_correlation_matrix(df_corr: pd.DataFrame,
48-
dim: tuple=(12, 8)):
48+
dim: tuple=(12, 8),
49+
path: str=None):
4950
"""Plot correlation matrix amongst the predictors
5051
5152
Parameters
@@ -54,10 +55,16 @@ def plot_correlation_matrix(df_corr: pd.DataFrame,
5455
Correlation matrix
5556
dim : tuple, optional
5657
tuple with width and lentgh of the plot
58+
path : str, optional
59+
path to store the figure
5760
"""
5861
fig, ax = plt.subplots(figsize=dim)
5962
ax = sns.heatmap(df_corr, cmap='Blues')
6063
ax.set_title('Correlation Matrix')
64+
65+
if path is not None:
66+
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
67+
6168
plt.show()
6269

6370

@@ -111,7 +118,8 @@ def plot_performance_curves(model_performance: pd.DataFrame,
111118

112119
def plot_variable_importance(df_variable_importance: pd.DataFrame,
113120
title: str=None,
114-
dim: tuple=(12, 8)):
121+
dim: tuple=(12, 8),
122+
path: str=None):
115123
"""Plot variable importance of a given model
116124
117125
Parameters
@@ -122,6 +130,8 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
122130
Title of the plot
123131
dim : tuple, optional
124132
tuple with width and lentgh of the plot
133+
path : str, optional
134+
path to store the figure
125135
"""
126136
with plt.style.context("seaborn-ticks"):
127137
fig, ax = plt.subplots(figsize=dim)
@@ -139,4 +149,7 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
139149
# Remove white lines from the second axis
140150
ax.grid(False)
141151

152+
if path is not None:
153+
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
154+
142155
plt.show()

0 commit comments

Comments
 (0)