@@ -45,7 +45,8 @@ def plot_univariate_predictor_quality(df_auc: pd.DataFrame,
4545
4646
4747def plot_correlation_matrix (df_corr : pd .DataFrame ,
48- dim : tuple = (12 , 8 )):
48+ dim : tuple = (12 , 8 ),
49+ path : str = None ):
4950 """Plot correlation matrix amongst the predictors
5051
5152 Parameters
@@ -54,10 +55,16 @@ def plot_correlation_matrix(df_corr: pd.DataFrame,
5455 Correlation matrix
5556 dim : tuple, optional
5657 tuple with width and lentgh of the plot
58+ path : str, optional
59+ path to store the figure
5760 """
5861 fig , ax = plt .subplots (figsize = dim )
5962 ax = sns .heatmap (df_corr , cmap = 'Blues' )
6063 ax .set_title ('Correlation Matrix' )
64+
65+ if path is not None :
66+ plt .savefig (path , format = "png" , dpi = 300 , bbox_inches = "tight" )
67+
6168 plt .show ()
6269
6370
@@ -111,7 +118,8 @@ def plot_performance_curves(model_performance: pd.DataFrame,
111118
112119def plot_variable_importance (df_variable_importance : pd .DataFrame ,
113120 title : str = None ,
114- dim : tuple = (12 , 8 )):
121+ dim : tuple = (12 , 8 ),
122+ path : str = None ):
115123 """Plot variable importance of a given model
116124
117125 Parameters
@@ -122,6 +130,8 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
122130 Title of the plot
123131 dim : tuple, optional
124132 tuple with width and lentgh of the plot
133+ path : str, optional
134+ path to store the figure
125135 """
126136 with plt .style .context ("seaborn-ticks" ):
127137 fig , ax = plt .subplots (figsize = dim )
@@ -139,4 +149,7 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
139149 # Remove white lines from the second axis
140150 ax .grid (False )
141151
152+ if path is not None :
153+ plt .savefig (path , format = "png" , dpi = 300 , bbox_inches = "tight" )
154+
142155 plt .show ()
0 commit comments