@@ -84,7 +84,8 @@ def plot_performance_curves(model_performance: pd.DataFrame,
8484 path : str = None ,
8585 colors : dict = {"train" : "#0099bf" ,
8686 "selection" : "#ff9500" ,
87- "validation" : "#8064a2" }):
87+ "validation" : "#8064a2" },
88+ metric_name : str = None ):
8889 """Plot performance curves generated by the forward feature selection
8990 for the train-selection-validation sets.
9091
@@ -97,14 +98,21 @@ def plot_performance_curves(model_performance: pd.DataFrame,
9798 Width and length of the plot.
9899 path : str, optional
99100 Path to store the figure.
101+ colors : dict, optional
102+ Map with colors for train-selection-validation curves.
103+ metric_name : str, optional
104+ Name to indicate the metric used in model_performance.
105+ Defaults to RMSE in case of regression and AUC in case of
106+ classification.
100107 """
101108
102109 model_type = model_performance ["model_type" ][0 ]
103110
104- if model_type == "classification" :
105- metric = "AUC"
106- elif model_type == "regression" :
107- metric = "RMSE"
111+ if metric_name is None :
112+ if model_type == "classification" :
113+ metric_name = "AUC"
114+ elif model_type == "regression" :
115+ metric_name = "RMSE"
108116
109117 max_metric = np .round (max (max (model_performance ['train_performance' ]),
110118 max (model_performance ['selection_performance' ]),
@@ -115,13 +123,13 @@ def plot_performance_curves(model_performance: pd.DataFrame,
115123 fig , ax = plt .subplots (figsize = dim )
116124
117125 plt .plot (model_performance ['train_performance' ], marker = "." ,
118- markersize = 20 , linewidth = 3 , label = metric + " train" ,
126+ markersize = 20 , linewidth = 3 , label = " train" ,
119127 color = colors ["train" ])
120128 plt .plot (model_performance ['selection_performance' ], marker = "." ,
121- markersize = 20 , linewidth = 3 , label = metric + " selection" ,
129+ markersize = 20 , linewidth = 3 , label = " selection" ,
122130 color = colors ["selection" ])
123131 plt .plot (model_performance ['validation_performance' ], marker = "." ,
124- markersize = 20 , linewidth = 3 , label = metric + " validation" ,
132+ markersize = 20 , linewidth = 3 , label = " validation" ,
125133 color = colors ["validation" ])
126134
127135 # Set x- and y-ticks
@@ -141,6 +149,7 @@ def plot_performance_curves(model_performance: pd.DataFrame,
141149 ax .legend (loc = 'lower right' )
142150 fig .suptitle ('Performance curves forward feature selection' ,
143151 fontsize = 20 )
152+ plt .title ("Metric: " + metric_name , fontsize = 15 , loc = "left" )
144153 plt .ylabel ('Model performance' )
145154
146155 if path is not None :
0 commit comments