99)
1010from PROBESt .models_registry import (
1111 ShallowNet , WideNet , ResidualNet , GAILDiscriminator , TabTransformer ,
12- GAILDeep , GAILWide , GAILNarrow , GAILWithDropout
12+ GAILDeep , GAILWide , GAILNarrow , GAILWithDropout ,
13+ GAILWideDeep , GAILWideDropout , GAILWideBatchNorm , GAILWideExtra , GAILWideBalanced
1314)
1415
1516MODELS = {
@@ -120,6 +121,10 @@ def main():
120121 print (f"Best model: { best_model_name } (F1: { results [best_model_name ]['f1' ]:.4f} )" )
121122 print (f"{ '=' * 60 } " )
122123
124+ # Store GAIL search results for potential use later
125+ gail_search_results = {}
126+ gail_trained_models = {}
127+
123128 # Architecture search for GAIL if it's the best model
124129 if best_model_name == "GAIL" :
125130 print ("\n " + "=" * 60 )
@@ -136,9 +141,6 @@ def main():
136141 "GAIL_Custom2" : lambda n : TorchClassifier (GAILDiscriminator (n , hidden1 = 192 , hidden2 = 96 ), weight_pos = 5 ),
137142 }
138143
139- gail_search_results = {}
140- gail_trained_models = {}
141-
142144 for variant_name , constructor in gail_variations .items ():
143145 print (f"\n ===== Training { variant_name } =====" )
144146 variant_model = constructor (input_size )
@@ -188,6 +190,135 @@ def main():
188190 else :
189191 print (f"\n Original GAIL remains the best." )
190192
193+ # Further architecture search for GAIL_Wide if it's the best model
194+ if best_model_name == "GAIL_Wide" or best_model_name .startswith ("GAIL_Wide" ):
195+ print ("\n " + "=" * 60 )
196+ print ("GAIL_Wide is the best model. Performing further architecture search..." )
197+ print ("=" * 60 )
198+
199+ # Define GAIL_Wide architecture variations
200+ gail_wide_variations = {
201+ "GAIL_Wide_Deep" : lambda n : TorchClassifier (GAILWideDeep (n ), weight_pos = 5 ),
202+ "GAIL_Wide_Dropout" : lambda n : TorchClassifier (GAILWideDropout (n ), weight_pos = 5 ),
203+ "GAIL_Wide_BatchNorm" : lambda n : TorchClassifier (GAILWideBatchNorm (n ), weight_pos = 5 ),
204+ "GAIL_Wide_Extra" : lambda n : TorchClassifier (GAILWideExtra (n ), weight_pos = 5 ),
205+ "GAIL_Wide_Balanced" : lambda n : TorchClassifier (GAILWideBalanced (n ), weight_pos = 5 ),
206+ "GAIL_Wide_Custom1" : lambda n : TorchClassifier (GAILWide (n , hidden1 = 640 , hidden2 = 320 ), weight_pos = 5 ),
207+ "GAIL_Wide_Custom2" : lambda n : TorchClassifier (GAILWide (n , hidden1 = 384 , hidden2 = 192 ), weight_pos = 5 ),
208+ "GAIL_Wide_Custom3" : lambda n : TorchClassifier (GAILWideDeep (n , hidden1 = 512 , hidden2 = 256 , hidden3 = 128 ), weight_pos = 5 ),
209+ }
210+
211+ gail_wide_search_results = {}
212+ gail_wide_trained_models = {}
213+
214+ for variant_name , constructor in gail_wide_variations .items ():
215+ print (f"\n ===== Training { variant_name } =====" )
216+ variant_model = constructor (input_size )
217+
218+ # Train with learning curve tracking (more epochs for better training)
219+ X_train = train_data .drop (columns = ['type' ])
220+ y_train = train_data ['type' ]
221+ variant_model .train (X_train , y_train , epochs = 150 , batch_size = 32 ,
222+ val_data = val_data , track_curves = True )
223+
224+ # Validate
225+ variant_val_metrics = validate_filtration_AI (
226+ variant_model , val_data , output_name = f"{ variant_name } .png"
227+ )
228+ gail_wide_search_results [variant_name ] = variant_val_metrics
229+ gail_wide_trained_models [variant_name ] = variant_model
230+
231+ print (f"\n { variant_name } validation metrics:" )
232+ for metric , value in variant_val_metrics .items ():
233+ print (f" { metric } : { value :.4f} " )
234+
235+ # Compare with original GAIL_Wide
236+ # Get the current best model's metrics
237+ if best_model_name == "GAIL_Wide" :
238+ # GAIL_Wide came from first GAIL search
239+ original_metrics = gail_search_results .get ("GAIL_Wide" , {})
240+ elif best_model_name in gail_search_results :
241+ # It's from the first GAIL search
242+ original_metrics = gail_search_results [best_model_name ]
243+ else :
244+ # Fallback to results dict
245+ original_metrics = results .get (best_model_name , {})
246+
247+ gail_wide_search_results ["GAIL_Wide_Original" ] = original_metrics
248+ gail_wide_trained_models ["GAIL_Wide_Original" ] = best_model
249+
250+ # Find best GAIL_Wide variant
251+ best_gail_wide_variant = max (gail_wide_search_results .keys (),
252+ key = lambda m : gail_wide_search_results [m ].get ("f1" , 0 ))
253+ best_gail_wide_model = gail_wide_trained_models [best_gail_wide_variant ]
254+
255+ print (f"\n { '=' * 60 } " )
256+ print (f"Best GAIL_Wide variant: { best_gail_wide_variant } (F1: { gail_wide_search_results [best_gail_wide_variant ].get ('f1' , 0 ):.4f} )" )
257+ print (f"{ '=' * 60 } " )
258+
259+ # Plot learning curves for best GAIL_Wide variant
260+ if hasattr (best_gail_wide_model , 'train_losses' ) and len (best_gail_wide_model .train_losses ) > 0 :
261+ print ("\n Generating learning curves for best GAIL_Wide variant..." )
262+ plot_learning_curves (best_gail_wide_model , output_dir = output_dir ,
263+ output_name = f'learning_curves_{ best_gail_wide_variant } .png' )
264+ print (f"Learning curves saved to { os .path .join (output_dir , f'learning_curves_{ best_gail_wide_variant } .png' )} " )
265+
266+ # Update best model if variant is better
267+ current_f1 = gail_wide_search_results .get ("GAIL_Wide_Original" , {}).get ("f1" , 0 )
268+ if gail_wide_search_results [best_gail_wide_variant ].get ("f1" , 0 ) > current_f1 :
269+ print (f"\n Best GAIL_Wide variant ({ best_gail_wide_variant } ) outperforms original!" )
270+ best_model = best_gail_wide_model
271+ best_model_name = best_gail_wide_variant
272+
273+ # Train the best model for even more epochs
274+ print (f"\n { '=' * 60 } " )
275+ print (f"Training { best_model_name } for extended epochs (300 epochs)..." )
276+ print (f"{ '=' * 60 } " )
277+ X_train = train_data .drop (columns = ['type' ])
278+ y_train = train_data ['type' ]
279+ best_model .train (X_train , y_train , epochs = 300 , batch_size = 32 ,
280+ val_data = val_data , track_curves = True )
281+
282+ # Re-validate after extended training
283+ final_val_metrics = validate_filtration_AI (
284+ best_model , val_data , output_name = f"{ best_model_name } _final.png"
285+ )
286+ print (f"\n Final { best_model_name } validation metrics after extended training:" )
287+ for metric , value in final_val_metrics .items ():
288+ print (f" { metric } : { value :.4f} " )
289+
290+ # Plot final learning curves
291+ if hasattr (best_model , 'train_losses' ) and len (best_model .train_losses ) > 0 :
292+ print ("\n Generating final learning curves..." )
293+ plot_learning_curves (best_model , output_dir = output_dir ,
294+ output_name = f'learning_curves_{ best_model_name } _final.png' )
295+ print (f"Final learning curves saved to { os .path .join (output_dir , f'learning_curves_{ best_model_name } _final.png' )} " )
296+ else :
297+ print (f"\n Original GAIL_Wide remains the best." )
298+ # Still train the best model for more epochs
299+ print (f"\n { '=' * 60 } " )
300+ print (f"Training { best_model_name } for extended epochs (300 epochs)..." )
301+ print (f"{ '=' * 60 } " )
302+ X_train = train_data .drop (columns = ['type' ])
303+ y_train = train_data ['type' ]
304+ best_model .train (X_train , y_train , epochs = 300 , batch_size = 32 ,
305+ val_data = val_data , track_curves = True )
306+
307+ # Re-validate after extended training
308+ final_val_metrics = validate_filtration_AI (
309+ best_model , val_data , output_name = f"{ best_model_name } _final.png"
310+ )
311+ print (f"\n Final { best_model_name } validation metrics after extended training:" )
312+ for metric , value in final_val_metrics .items ():
313+ print (f" { metric } : { value :.4f} " )
314+
315+ # Plot final learning curves
316+ if hasattr (best_model , 'train_losses' ) and len (best_model .train_losses ) > 0 :
317+ print ("\n Generating final learning curves..." )
318+ plot_learning_curves (best_model , output_dir = output_dir ,
319+ output_name = f'learning_curves_{ best_model_name } _final.png' )
320+ print (f"Final learning curves saved to { os .path .join (output_dir , f'learning_curves_{ best_model_name } _final.png' )} " )
321+
191322 # Plot learning curves for the best model (if it has learning curve data)
192323 if hasattr (best_model , 'train_losses' ) and len (best_model .train_losses ) > 0 :
193324 print ("\n Generating learning curves for best model..." )
0 commit comments