|
2 | 2 | import pandas as pd |
3 | 3 | import numpy as np |
4 | 4 | from sklearn.model_selection import train_test_split |
5 | | -from PROBESt.AI import LogisticRegressionModel, PerceptronModel, DeepNeuralNetworkModel |
6 | | -from PROBESt.filtration import train_filtration_AI, validate_filtration_AI, apply_filtration_AI |
7 | | -from models_registry import ShallowNet, WideNet, ResidualNet, GAILDiscriminator, TabTransformer |
8 | | -from PROBESt.AI import TorchClassifier |
9 | | -from PROBESt.filtration import train_filtration_AI, validate_filtration_AI |
| 5 | +from PROBESt.AI import LogisticRegressionModel, DeepNeuralNetworkModel, TorchClassifier |
| 6 | +from PROBESt.filtration import ( |
| 7 | + train_filtration_AI, validate_filtration_AI, apply_filtration_AI, |
| 8 | + plot_combined_roc_curves, plot_combined_metrics, plot_learning_curves |
| 9 | +) |
| 10 | +from PROBESt.models_registry import ( |
| 11 | + ShallowNet, WideNet, ResidualNet, GAILDiscriminator, TabTransformer, |
| 12 | + GAILDeep, GAILWide, GAILNarrow, GAILWithDropout |
| 13 | +) |
10 | 14 |
|
11 | 15 | MODELS = { |
12 | 16 | "ShallowNet": lambda n: TorchClassifier(ShallowNet(n), weight_pos=5), |
@@ -41,55 +45,161 @@ def main(): |
41 | 45 | print(f"Validation set size: {len(val_data)}") |
42 | 46 | print(f"Test set size: {len(test_data)}") |
43 | 47 |
|
44 | | - # Train and evaluate deep neural network model |
45 | | - print("\nTraining Deep Neural Network model...") |
46 | | - dnn_model = DeepNeuralNetworkModel(input_size=train_data.shape[1] - 1) |
| 48 | + # Get input size |
| 49 | + input_size = train_data.shape[1] - 1 |
| 50 | + |
| 51 | + # Train and validate all models from MODELS dictionary |
| 52 | + results = {} |
| 53 | + trained_models = {} |
| 54 | + models_for_plots = {} # For combined plots: (model, val_data) |
| 55 | + |
| 56 | + for name, constructor in MODELS.items(): |
| 57 | + print(f"\n===== Training {name} =====") |
| 58 | + model = constructor(input_size) |
| 59 | + |
| 60 | + # For GAIL, track learning curves during training |
| 61 | + if name == "GAIL": |
| 62 | + X_train = train_data.drop(columns=['type']) |
| 63 | + y_train = train_data['type'] |
| 64 | + model.train(X_train, y_train, epochs=100, batch_size=32, |
| 65 | + val_data=val_data, track_curves=True) |
| 66 | + trained = model |
| 67 | + # Get training metrics from validation (we'll use val_metrics for results) |
| 68 | + metrics = {} |
| 69 | + else: |
| 70 | + trained, metrics = train_filtration_AI(model, train_data) |
| 71 | + |
| 72 | + val_metrics = validate_filtration_AI(trained, val_data, output_name=f"{name}.png") |
| 73 | + results[name] = val_metrics |
| 74 | + trained_models[name] = trained |
| 75 | + models_for_plots[name] = (trained, val_data) |
| 76 | + |
| 77 | + print(f"\n{name} validation metrics:") |
| 78 | + for metric, value in val_metrics.items(): |
| 79 | + print(f" {metric}: {value:.4f}") |
| 80 | + |
| 81 | + # Train and validate additional baseline models |
| 82 | + print("\n===== Training Deep Neural Network =====") |
| 83 | + dnn_model = DeepNeuralNetworkModel(input_size=input_size) |
47 | 84 | trained_dnn_model, dnn_metrics = train_filtration_AI(dnn_model, train_data) |
48 | | - |
49 | | - print("\nDeep Neural Network metrics:") |
50 | | - for metric, value in dnn_metrics.items(): |
51 | | - print(f"{metric}: {value:.4f}") |
52 | | - |
53 | | - # Validate deep neural network model |
54 | | - print("\nValidating Deep Neural Network model...") |
55 | 85 | dnn_val_metrics = validate_filtration_AI(trained_dnn_model, val_data, output_name='DNN.png') |
| 86 | + results["DeepNeuralNetwork"] = dnn_val_metrics |
| 87 | + trained_models["DeepNeuralNetwork"] = trained_dnn_model |
| 88 | + models_for_plots["DeepNeuralNetwork"] = (trained_dnn_model, val_data) |
56 | 89 |
|
57 | 90 | print("\nDeep Neural Network validation metrics:") |
58 | 91 | for metric, value in dnn_val_metrics.items(): |
59 | | - print(f"{metric}: {value:.4f}") |
| 92 | + print(f" {metric}: {value:.4f}") |
60 | 93 |
|
61 | | - # Train and evaluate logistic regression model |
62 | | - print("\nTraining Logistic Regression model...") |
| 94 | + print("\n===== Training Logistic Regression =====") |
63 | 95 | lr_model = LogisticRegressionModel() |
64 | 96 | trained_lr_model, lr_metrics = train_filtration_AI(lr_model, train_data) |
65 | | - |
66 | | - print("\nLogistic Regression metrics:") |
67 | | - for metric, value in lr_metrics.items(): |
68 | | - print(f"{metric}: {value:.4f}") |
69 | | - |
70 | | - # Validate logistic regression model |
71 | | - print("\nValidating Logistic Regression model...") |
72 | 97 | lr_val_metrics = validate_filtration_AI(trained_lr_model, val_data, output_name='LR.png') |
| 98 | + results["LogisticRegression"] = lr_val_metrics |
| 99 | + trained_models["LogisticRegression"] = trained_lr_model |
| 100 | + models_for_plots["LogisticRegression"] = (trained_lr_model, val_data) |
73 | 101 |
|
74 | 102 | print("\nLogistic Regression validation metrics:") |
75 | 103 | for metric, value in lr_val_metrics.items(): |
76 | | - print(f"{metric}: {value:.4f}") |
| 104 | + print(f" {metric}: {value:.4f}") |
| 105 | + |
| 106 | + # Generate combined plots |
| 107 | + print("\n===== Generating combined plots =====") |
| 108 | + output_dir = 'tests_outs' |
| 109 | + plot_combined_roc_curves(models_for_plots, output_dir=output_dir) |
| 110 | + print(f"Combined ROC curves saved to {os.path.join(output_dir, 'combined_roc_curves.png')}") |
| 111 | + |
| 112 | + plot_combined_metrics(results, output_dir=output_dir) |
| 113 | + print(f"Combined metrics plot saved to {os.path.join(output_dir, 'combined_metrics.png')}") |
| 114 | + |
| 115 | + # Select best model based on F1 score |
| 116 | + best_model_name = max(results.keys(), key=lambda m: results[m]["f1"]) |
| 117 | + best_model = trained_models[best_model_name] |
| 118 | + |
| 119 | + print(f"\n{'='*60}") |
| 120 | + print(f"Best model: {best_model_name} (F1: {results[best_model_name]['f1']:.4f})") |
| 121 | + print(f"{'='*60}") |
| 122 | + |
| 123 | + # Architecture search for GAIL if it's the best model |
| 124 | + if best_model_name == "GAIL": |
| 125 | + print("\n" + "="*60) |
| 126 | + print("GAIL is the best model. Performing architecture search...") |
| 127 | + print("="*60) |
| 128 | + |
| 129 | + # Define GAIL architecture variations |
| 130 | + gail_variations = { |
| 131 | + "GAIL_Deep": lambda n: TorchClassifier(GAILDeep(n), weight_pos=5), |
| 132 | + "GAIL_Wide": lambda n: TorchClassifier(GAILWide(n), weight_pos=5), |
| 133 | + "GAIL_Narrow": lambda n: TorchClassifier(GAILNarrow(n), weight_pos=5), |
| 134 | + "GAIL_Dropout": lambda n: TorchClassifier(GAILWithDropout(n), weight_pos=5), |
| 135 | + "GAIL_Custom1": lambda n: TorchClassifier(GAILDiscriminator(n, hidden1=384, hidden2=192), weight_pos=5), |
| 136 | + "GAIL_Custom2": lambda n: TorchClassifier(GAILDiscriminator(n, hidden1=192, hidden2=96), weight_pos=5), |
| 137 | + } |
| 138 | + |
| 139 | + gail_search_results = {} |
| 140 | + gail_trained_models = {} |
| 141 | + |
| 142 | + for variant_name, constructor in gail_variations.items(): |
| 143 | + print(f"\n===== Training {variant_name} =====") |
| 144 | + variant_model = constructor(input_size) |
| 145 | + |
| 146 | + # Train with learning curve tracking |
| 147 | + X_train = train_data.drop(columns=['type']) |
| 148 | + y_train = train_data['type'] |
| 149 | + variant_model.train(X_train, y_train, epochs=100, batch_size=32, |
| 150 | + val_data=val_data, track_curves=True) |
| 151 | + |
| 152 | + # Validate |
| 153 | + variant_val_metrics = validate_filtration_AI( |
| 154 | + variant_model, val_data, output_name=f"{variant_name}.png" |
| 155 | + ) |
| 156 | + gail_search_results[variant_name] = variant_val_metrics |
| 157 | + gail_trained_models[variant_name] = variant_model |
| 158 | + |
| 159 | + print(f"\n{variant_name} validation metrics:") |
| 160 | + for metric, value in variant_val_metrics.items(): |
| 161 | + print(f" {metric}: {value:.4f}") |
| 162 | + |
| 163 | + # Compare with original GAIL |
| 164 | + gail_search_results["GAIL_Original"] = results["GAIL"] |
| 165 | + gail_trained_models["GAIL_Original"] = trained_models["GAIL"] |
| 166 | + |
| 167 | + # Find best GAIL variant |
| 168 | + best_gail_variant = max(gail_search_results.keys(), |
| 169 | + key=lambda m: gail_search_results[m]["f1"]) |
| 170 | + best_gail_model = gail_trained_models[best_gail_variant] |
| 171 | + |
| 172 | + print(f"\n{'='*60}") |
| 173 | + print(f"Best GAIL variant: {best_gail_variant} (F1: {gail_search_results[best_gail_variant]['f1']:.4f})") |
| 174 | + print(f"{'='*60}") |
| 175 | + |
| 176 | + # Plot learning curves for best GAIL variant |
| 177 | + if hasattr(best_gail_model, 'train_losses') and len(best_gail_model.train_losses) > 0: |
| 178 | + print("\nGenerating learning curves for best GAIL variant...") |
| 179 | + plot_learning_curves(best_gail_model, output_dir=output_dir, |
| 180 | + output_name=f'learning_curves_{best_gail_variant}.png') |
| 181 | + print(f"Learning curves saved to {os.path.join(output_dir, f'learning_curves_{best_gail_variant}.png')}") |
| 182 | + |
| 183 | + # Update best model if variant is better |
| 184 | + if gail_search_results[best_gail_variant]["f1"] > results["GAIL"]["f1"]: |
| 185 | + print(f"\nBest GAIL variant ({best_gail_variant}) outperforms original GAIL!") |
| 186 | + best_model = best_gail_model |
| 187 | + best_model_name = best_gail_variant |
| 188 | + else: |
| 189 | + print(f"\nOriginal GAIL remains the best.") |
| 190 | + |
| 191 | + # Plot learning curves for the best model (if it has learning curve data) |
| 192 | + if hasattr(best_model, 'train_losses') and len(best_model.train_losses) > 0: |
| 193 | + print("\nGenerating learning curves for best model...") |
| 194 | + plot_learning_curves(best_model, output_dir=output_dir, |
| 195 | + output_name='learning_curves_best_model.png') |
| 196 | + print(f"Learning curves saved to {os.path.join(output_dir, 'learning_curves_best_model.png')}") |
77 | 197 |
|
78 | 198 | # Apply best model to test set |
79 | 199 | print("\nApplying best model to test set...") |
80 | | - models = { |
81 | | - "Deep Neural Network": (trained_dnn_model, dnn_val_metrics), |
82 | | - "Logistic Regression": (trained_lr_model, lr_val_metrics) |
83 | | - } |
84 | | - |
85 | | - best_model_name = max(models.keys(), key=lambda k: models[k][1]['f1']) |
86 | | - best_model = models[best_model_name][0] |
87 | | - |
88 | | - print(f"\nUsing {best_model_name} model for final predictions") |
89 | 200 | test_predictions = apply_filtration_AI(best_model, test_data) |
90 | 201 |
|
91 | 202 | # Save predictions |
92 | | - output_dir = 'tests_outs' |
93 | 203 | os.makedirs(output_dir, exist_ok=True) |
94 | 204 | test_predictions.to_csv(os.path.join(output_dir, 'test_predictions.csv'), index=False) |
95 | 205 | print(f"\nPredictions saved to {os.path.join(output_dir, 'test_predictions.csv')}") |
|
0 commit comments