Skip to content

Commit f7a44bb

Browse files
committed
Update AI module
1 parent c3d2de9 commit f7a44bb

5 files changed

Lines changed: 440 additions & 52 deletions

File tree

scripts/generator/ML_filtration.py

Lines changed: 146 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,15 @@
22
import pandas as pd
33
import numpy as np
44
from sklearn.model_selection import train_test_split
5-
from PROBESt.AI import LogisticRegressionModel, PerceptronModel, DeepNeuralNetworkModel
6-
from PROBESt.filtration import train_filtration_AI, validate_filtration_AI, apply_filtration_AI
7-
from models_registry import ShallowNet, WideNet, ResidualNet, GAILDiscriminator, TabTransformer
8-
from PROBESt.AI import TorchClassifier
9-
from PROBESt.filtration import train_filtration_AI, validate_filtration_AI
5+
from PROBESt.AI import LogisticRegressionModel, DeepNeuralNetworkModel, TorchClassifier
6+
from PROBESt.filtration import (
7+
train_filtration_AI, validate_filtration_AI, apply_filtration_AI,
8+
plot_combined_roc_curves, plot_combined_metrics, plot_learning_curves
9+
)
10+
from PROBESt.models_registry import (
11+
ShallowNet, WideNet, ResidualNet, GAILDiscriminator, TabTransformer,
12+
GAILDeep, GAILWide, GAILNarrow, GAILWithDropout
13+
)
1014

1115
MODELS = {
1216
"ShallowNet": lambda n: TorchClassifier(ShallowNet(n), weight_pos=5),
@@ -41,55 +45,161 @@ def main():
4145
print(f"Validation set size: {len(val_data)}")
4246
print(f"Test set size: {len(test_data)}")
4347

44-
# Train and evaluate deep neural network model
45-
print("\nTraining Deep Neural Network model...")
46-
dnn_model = DeepNeuralNetworkModel(input_size=train_data.shape[1] - 1)
48+
# Get input size
49+
input_size = train_data.shape[1] - 1
50+
51+
# Train and validate all models from MODELS dictionary
52+
results = {}
53+
trained_models = {}
54+
models_for_plots = {} # For combined plots: (model, val_data)
55+
56+
for name, constructor in MODELS.items():
57+
print(f"\n===== Training {name} =====")
58+
model = constructor(input_size)
59+
60+
# For GAIL, track learning curves during training
61+
if name == "GAIL":
62+
X_train = train_data.drop(columns=['type'])
63+
y_train = train_data['type']
64+
model.train(X_train, y_train, epochs=100, batch_size=32,
65+
val_data=val_data, track_curves=True)
66+
trained = model
67+
# Get training metrics from validation (we'll use val_metrics for results)
68+
metrics = {}
69+
else:
70+
trained, metrics = train_filtration_AI(model, train_data)
71+
72+
val_metrics = validate_filtration_AI(trained, val_data, output_name=f"{name}.png")
73+
results[name] = val_metrics
74+
trained_models[name] = trained
75+
models_for_plots[name] = (trained, val_data)
76+
77+
print(f"\n{name} validation metrics:")
78+
for metric, value in val_metrics.items():
79+
print(f" {metric}: {value:.4f}")
80+
81+
# Train and validate additional baseline models
82+
print("\n===== Training Deep Neural Network =====")
83+
dnn_model = DeepNeuralNetworkModel(input_size=input_size)
4784
trained_dnn_model, dnn_metrics = train_filtration_AI(dnn_model, train_data)
48-
49-
print("\nDeep Neural Network metrics:")
50-
for metric, value in dnn_metrics.items():
51-
print(f"{metric}: {value:.4f}")
52-
53-
# Validate deep neural network model
54-
print("\nValidating Deep Neural Network model...")
5585
dnn_val_metrics = validate_filtration_AI(trained_dnn_model, val_data, output_name='DNN.png')
86+
results["DeepNeuralNetwork"] = dnn_val_metrics
87+
trained_models["DeepNeuralNetwork"] = trained_dnn_model
88+
models_for_plots["DeepNeuralNetwork"] = (trained_dnn_model, val_data)
5689

5790
print("\nDeep Neural Network validation metrics:")
5891
for metric, value in dnn_val_metrics.items():
59-
print(f"{metric}: {value:.4f}")
92+
print(f" {metric}: {value:.4f}")
6093

61-
# Train and evaluate logistic regression model
62-
print("\nTraining Logistic Regression model...")
94+
print("\n===== Training Logistic Regression =====")
6395
lr_model = LogisticRegressionModel()
6496
trained_lr_model, lr_metrics = train_filtration_AI(lr_model, train_data)
65-
66-
print("\nLogistic Regression metrics:")
67-
for metric, value in lr_metrics.items():
68-
print(f"{metric}: {value:.4f}")
69-
70-
# Validate logistic regression model
71-
print("\nValidating Logistic Regression model...")
7297
lr_val_metrics = validate_filtration_AI(trained_lr_model, val_data, output_name='LR.png')
98+
results["LogisticRegression"] = lr_val_metrics
99+
trained_models["LogisticRegression"] = trained_lr_model
100+
models_for_plots["LogisticRegression"] = (trained_lr_model, val_data)
73101

74102
print("\nLogistic Regression validation metrics:")
75103
for metric, value in lr_val_metrics.items():
76-
print(f"{metric}: {value:.4f}")
104+
print(f" {metric}: {value:.4f}")
105+
106+
# Generate combined plots
107+
print("\n===== Generating combined plots =====")
108+
output_dir = 'tests_outs'
109+
plot_combined_roc_curves(models_for_plots, output_dir=output_dir)
110+
print(f"Combined ROC curves saved to {os.path.join(output_dir, 'combined_roc_curves.png')}")
111+
112+
plot_combined_metrics(results, output_dir=output_dir)
113+
print(f"Combined metrics plot saved to {os.path.join(output_dir, 'combined_metrics.png')}")
114+
115+
# Select best model based on F1 score
116+
best_model_name = max(results.keys(), key=lambda m: results[m]["f1"])
117+
best_model = trained_models[best_model_name]
118+
119+
print(f"\n{'='*60}")
120+
print(f"Best model: {best_model_name} (F1: {results[best_model_name]['f1']:.4f})")
121+
print(f"{'='*60}")
122+
123+
# Architecture search for GAIL if it's the best model
124+
if best_model_name == "GAIL":
125+
print("\n" + "="*60)
126+
print("GAIL is the best model. Performing architecture search...")
127+
print("="*60)
128+
129+
# Define GAIL architecture variations
130+
gail_variations = {
131+
"GAIL_Deep": lambda n: TorchClassifier(GAILDeep(n), weight_pos=5),
132+
"GAIL_Wide": lambda n: TorchClassifier(GAILWide(n), weight_pos=5),
133+
"GAIL_Narrow": lambda n: TorchClassifier(GAILNarrow(n), weight_pos=5),
134+
"GAIL_Dropout": lambda n: TorchClassifier(GAILWithDropout(n), weight_pos=5),
135+
"GAIL_Custom1": lambda n: TorchClassifier(GAILDiscriminator(n, hidden1=384, hidden2=192), weight_pos=5),
136+
"GAIL_Custom2": lambda n: TorchClassifier(GAILDiscriminator(n, hidden1=192, hidden2=96), weight_pos=5),
137+
}
138+
139+
gail_search_results = {}
140+
gail_trained_models = {}
141+
142+
for variant_name, constructor in gail_variations.items():
143+
print(f"\n===== Training {variant_name} =====")
144+
variant_model = constructor(input_size)
145+
146+
# Train with learning curve tracking
147+
X_train = train_data.drop(columns=['type'])
148+
y_train = train_data['type']
149+
variant_model.train(X_train, y_train, epochs=100, batch_size=32,
150+
val_data=val_data, track_curves=True)
151+
152+
# Validate
153+
variant_val_metrics = validate_filtration_AI(
154+
variant_model, val_data, output_name=f"{variant_name}.png"
155+
)
156+
gail_search_results[variant_name] = variant_val_metrics
157+
gail_trained_models[variant_name] = variant_model
158+
159+
print(f"\n{variant_name} validation metrics:")
160+
for metric, value in variant_val_metrics.items():
161+
print(f" {metric}: {value:.4f}")
162+
163+
# Compare with original GAIL
164+
gail_search_results["GAIL_Original"] = results["GAIL"]
165+
gail_trained_models["GAIL_Original"] = trained_models["GAIL"]
166+
167+
# Find best GAIL variant
168+
best_gail_variant = max(gail_search_results.keys(),
169+
key=lambda m: gail_search_results[m]["f1"])
170+
best_gail_model = gail_trained_models[best_gail_variant]
171+
172+
print(f"\n{'='*60}")
173+
print(f"Best GAIL variant: {best_gail_variant} (F1: {gail_search_results[best_gail_variant]['f1']:.4f})")
174+
print(f"{'='*60}")
175+
176+
# Plot learning curves for best GAIL variant
177+
if hasattr(best_gail_model, 'train_losses') and len(best_gail_model.train_losses) > 0:
178+
print("\nGenerating learning curves for best GAIL variant...")
179+
plot_learning_curves(best_gail_model, output_dir=output_dir,
180+
output_name=f'learning_curves_{best_gail_variant}.png')
181+
print(f"Learning curves saved to {os.path.join(output_dir, f'learning_curves_{best_gail_variant}.png')}")
182+
183+
# Update best model if variant is better
184+
if gail_search_results[best_gail_variant]["f1"] > results["GAIL"]["f1"]:
185+
print(f"\nBest GAIL variant ({best_gail_variant}) outperforms original GAIL!")
186+
best_model = best_gail_model
187+
best_model_name = best_gail_variant
188+
else:
189+
print(f"\nOriginal GAIL remains the best.")
190+
191+
# Plot learning curves for the best model (if it has learning curve data)
192+
if hasattr(best_model, 'train_losses') and len(best_model.train_losses) > 0:
193+
print("\nGenerating learning curves for best model...")
194+
plot_learning_curves(best_model, output_dir=output_dir,
195+
output_name='learning_curves_best_model.png')
196+
print(f"Learning curves saved to {os.path.join(output_dir, 'learning_curves_best_model.png')}")
77197

78198
# Apply best model to test set
79199
print("\nApplying best model to test set...")
80-
models = {
81-
"Deep Neural Network": (trained_dnn_model, dnn_val_metrics),
82-
"Logistic Regression": (trained_lr_model, lr_val_metrics)
83-
}
84-
85-
best_model_name = max(models.keys(), key=lambda k: models[k][1]['f1'])
86-
best_model = models[best_model_name][0]
87-
88-
print(f"\nUsing {best_model_name} model for final predictions")
89200
test_predictions = apply_filtration_AI(best_model, test_data)
90201

91202
# Save predictions
92-
output_dir = 'tests_outs'
93203
os.makedirs(output_dir, exist_ok=True)
94204
test_predictions.to_csv(os.path.join(output_dir, 'test_predictions.csv'), index=False)
95205
print(f"\nPredictions saved to {os.path.join(output_dir, 'test_predictions.csv')}")

src/PROBESt/AI.py

Lines changed: 47 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -136,15 +136,37 @@ def __init__(self, model: nn.Module, learning_rate=0.001, weight_pos=1.0):
136136

137137
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
138138

139-
def train(self, X, y, epochs=100, batch_size=32):
139+
# Track learning curves
140+
self.train_losses = []
141+
self.val_losses = []
142+
self.val_metrics_history = {'f1': [], 'accuracy': [], 'recall': [], 'precision': []}
143+
144+
def train(self, X, y, epochs=100, batch_size=32, val_data=None, track_curves=False):
140145
X_scaled = self.preprocess_data(X)
141146
X_tensor = torch.FloatTensor(X_scaled)
142147
y_tensor = torch.FloatTensor(y.values).reshape(-1, 1)
143148

144149
dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)
145150
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
146151

152+
# Prepare validation data if provided
153+
val_X_tensor = None
154+
val_y_tensor = None
155+
if val_data is not None and track_curves:
156+
val_X = val_data.drop(columns=['type'])
157+
val_y = val_data['type']
158+
val_X_scaled = self.scaler.transform(val_X)
159+
val_X_tensor = torch.FloatTensor(val_X_scaled)
160+
val_y_tensor = torch.FloatTensor(val_y.values).reshape(-1, 1)
161+
162+
# Reset tracking if tracking curves
163+
if track_curves:
164+
self.train_losses = []
165+
self.val_losses = []
166+
self.val_metrics_history = {'f1': [], 'accuracy': [], 'recall': [], 'precision': []}
167+
147168
for e in range(epochs):
169+
self.model.train()
148170
total_loss = 0
149171
for bx, by in loader:
150172
self.optimizer.zero_grad()
@@ -154,8 +176,31 @@ def train(self, X, y, epochs=100, batch_size=32):
154176
self.optimizer.step()
155177
total_loss += loss.item()
156178

179+
avg_loss = total_loss / len(loader)
180+
if track_curves:
181+
self.train_losses.append(avg_loss)
182+
183+
# Evaluate on validation set if provided
184+
if val_X_tensor is not None and track_curves:
185+
self.model.eval()
186+
with torch.no_grad():
187+
val_logits = self.model(val_X_tensor)
188+
val_loss = self.criterion(val_logits, val_y_tensor).item()
189+
self.val_losses.append(val_loss)
190+
191+
# Calculate metrics
192+
val_pred_proba = torch.sigmoid(val_logits).numpy()
193+
val_pred = (val_pred_proba > 0.5).astype(int)
194+
val_y_np = val_y_tensor.numpy().flatten()
195+
196+
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score
197+
self.val_metrics_history['f1'].append(f1_score(val_y_np, val_pred))
198+
self.val_metrics_history['accuracy'].append(accuracy_score(val_y_np, val_pred))
199+
self.val_metrics_history['recall'].append(recall_score(val_y_np, val_pred))
200+
self.val_metrics_history['precision'].append(precision_score(val_y_np, val_pred, zero_division=0))
201+
157202
if e % 20 == 0:
158-
print(f"Epoch {e}: loss = {total_loss:.4f}")
203+
print(f"Epoch {e}: loss = {avg_loss:.4f}")
159204

160205
def predict(self, X):
161206
X_scaled = self.scaler.transform(X)

0 commit comments

Comments
 (0)