Skip to content

Commit 2e31e64

Browse files
committed
Add more advanced GAILs
1 parent f7a44bb commit 2e31e64

2 files changed

Lines changed: 219 additions & 4 deletions

File tree

scripts/generator/ML_filtration.py

Lines changed: 135 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
)
1010
from PROBESt.models_registry import (
1111
ShallowNet, WideNet, ResidualNet, GAILDiscriminator, TabTransformer,
12-
GAILDeep, GAILWide, GAILNarrow, GAILWithDropout
12+
GAILDeep, GAILWide, GAILNarrow, GAILWithDropout,
13+
GAILWideDeep, GAILWideDropout, GAILWideBatchNorm, GAILWideExtra, GAILWideBalanced
1314
)
1415

1516
MODELS = {
@@ -120,6 +121,10 @@ def main():
120121
print(f"Best model: {best_model_name} (F1: {results[best_model_name]['f1']:.4f})")
121122
print(f"{'='*60}")
122123

124+
# Store GAIL search results for potential use later
125+
gail_search_results = {}
126+
gail_trained_models = {}
127+
123128
# Architecture search for GAIL if it's the best model
124129
if best_model_name == "GAIL":
125130
print("\n" + "="*60)
@@ -136,9 +141,6 @@ def main():
136141
"GAIL_Custom2": lambda n: TorchClassifier(GAILDiscriminator(n, hidden1=192, hidden2=96), weight_pos=5),
137142
}
138143

139-
gail_search_results = {}
140-
gail_trained_models = {}
141-
142144
for variant_name, constructor in gail_variations.items():
143145
print(f"\n===== Training {variant_name} =====")
144146
variant_model = constructor(input_size)
@@ -188,6 +190,135 @@ def main():
188190
else:
189191
print(f"\nOriginal GAIL remains the best.")
190192

193+
# Further architecture search for GAIL_Wide if it's the best model
194+
if best_model_name == "GAIL_Wide" or best_model_name.startswith("GAIL_Wide"):
195+
print("\n" + "="*60)
196+
print("GAIL_Wide is the best model. Performing further architecture search...")
197+
print("="*60)
198+
199+
# Define GAIL_Wide architecture variations
200+
gail_wide_variations = {
201+
"GAIL_Wide_Deep": lambda n: TorchClassifier(GAILWideDeep(n), weight_pos=5),
202+
"GAIL_Wide_Dropout": lambda n: TorchClassifier(GAILWideDropout(n), weight_pos=5),
203+
"GAIL_Wide_BatchNorm": lambda n: TorchClassifier(GAILWideBatchNorm(n), weight_pos=5),
204+
"GAIL_Wide_Extra": lambda n: TorchClassifier(GAILWideExtra(n), weight_pos=5),
205+
"GAIL_Wide_Balanced": lambda n: TorchClassifier(GAILWideBalanced(n), weight_pos=5),
206+
"GAIL_Wide_Custom1": lambda n: TorchClassifier(GAILWide(n, hidden1=640, hidden2=320), weight_pos=5),
207+
"GAIL_Wide_Custom2": lambda n: TorchClassifier(GAILWide(n, hidden1=384, hidden2=192), weight_pos=5),
208+
"GAIL_Wide_Custom3": lambda n: TorchClassifier(GAILWideDeep(n, hidden1=512, hidden2=256, hidden3=128), weight_pos=5),
209+
}
210+
211+
gail_wide_search_results = {}
212+
gail_wide_trained_models = {}
213+
214+
for variant_name, constructor in gail_wide_variations.items():
215+
print(f"\n===== Training {variant_name} =====")
216+
variant_model = constructor(input_size)
217+
218+
# Train with learning curve tracking (more epochs for better training)
219+
X_train = train_data.drop(columns=['type'])
220+
y_train = train_data['type']
221+
variant_model.train(X_train, y_train, epochs=150, batch_size=32,
222+
val_data=val_data, track_curves=True)
223+
224+
# Validate
225+
variant_val_metrics = validate_filtration_AI(
226+
variant_model, val_data, output_name=f"{variant_name}.png"
227+
)
228+
gail_wide_search_results[variant_name] = variant_val_metrics
229+
gail_wide_trained_models[variant_name] = variant_model
230+
231+
print(f"\n{variant_name} validation metrics:")
232+
for metric, value in variant_val_metrics.items():
233+
print(f" {metric}: {value:.4f}")
234+
235+
# Compare with original GAIL_Wide
236+
# Get the current best model's metrics
237+
if best_model_name == "GAIL_Wide":
238+
# GAIL_Wide came from first GAIL search
239+
original_metrics = gail_search_results.get("GAIL_Wide", {})
240+
elif best_model_name in gail_search_results:
241+
# It's from the first GAIL search
242+
original_metrics = gail_search_results[best_model_name]
243+
else:
244+
# Fallback to results dict
245+
original_metrics = results.get(best_model_name, {})
246+
247+
gail_wide_search_results["GAIL_Wide_Original"] = original_metrics
248+
gail_wide_trained_models["GAIL_Wide_Original"] = best_model
249+
250+
# Find best GAIL_Wide variant
251+
best_gail_wide_variant = max(gail_wide_search_results.keys(),
252+
key=lambda m: gail_wide_search_results[m].get("f1", 0))
253+
best_gail_wide_model = gail_wide_trained_models[best_gail_wide_variant]
254+
255+
print(f"\n{'='*60}")
256+
print(f"Best GAIL_Wide variant: {best_gail_wide_variant} (F1: {gail_wide_search_results[best_gail_wide_variant].get('f1', 0):.4f})")
257+
print(f"{'='*60}")
258+
259+
# Plot learning curves for best GAIL_Wide variant
260+
if hasattr(best_gail_wide_model, 'train_losses') and len(best_gail_wide_model.train_losses) > 0:
261+
print("\nGenerating learning curves for best GAIL_Wide variant...")
262+
plot_learning_curves(best_gail_wide_model, output_dir=output_dir,
263+
output_name=f'learning_curves_{best_gail_wide_variant}.png')
264+
print(f"Learning curves saved to {os.path.join(output_dir, f'learning_curves_{best_gail_wide_variant}.png')}")
265+
266+
# Update best model if variant is better
267+
current_f1 = gail_wide_search_results.get("GAIL_Wide_Original", {}).get("f1", 0)
268+
if gail_wide_search_results[best_gail_wide_variant].get("f1", 0) > current_f1:
269+
print(f"\nBest GAIL_Wide variant ({best_gail_wide_variant}) outperforms original!")
270+
best_model = best_gail_wide_model
271+
best_model_name = best_gail_wide_variant
272+
273+
# Train the best model for even more epochs
274+
print(f"\n{'='*60}")
275+
print(f"Training {best_model_name} for extended epochs (300 epochs)...")
276+
print(f"{'='*60}")
277+
X_train = train_data.drop(columns=['type'])
278+
y_train = train_data['type']
279+
best_model.train(X_train, y_train, epochs=300, batch_size=32,
280+
val_data=val_data, track_curves=True)
281+
282+
# Re-validate after extended training
283+
final_val_metrics = validate_filtration_AI(
284+
best_model, val_data, output_name=f"{best_model_name}_final.png"
285+
)
286+
print(f"\nFinal {best_model_name} validation metrics after extended training:")
287+
for metric, value in final_val_metrics.items():
288+
print(f" {metric}: {value:.4f}")
289+
290+
# Plot final learning curves
291+
if hasattr(best_model, 'train_losses') and len(best_model.train_losses) > 0:
292+
print("\nGenerating final learning curves...")
293+
plot_learning_curves(best_model, output_dir=output_dir,
294+
output_name=f'learning_curves_{best_model_name}_final.png')
295+
print(f"Final learning curves saved to {os.path.join(output_dir, f'learning_curves_{best_model_name}_final.png')}")
296+
else:
297+
print(f"\nOriginal GAIL_Wide remains the best.")
298+
# Still train the best model for more epochs
299+
print(f"\n{'='*60}")
300+
print(f"Training {best_model_name} for extended epochs (300 epochs)...")
301+
print(f"{'='*60}")
302+
X_train = train_data.drop(columns=['type'])
303+
y_train = train_data['type']
304+
best_model.train(X_train, y_train, epochs=300, batch_size=32,
305+
val_data=val_data, track_curves=True)
306+
307+
# Re-validate after extended training
308+
final_val_metrics = validate_filtration_AI(
309+
best_model, val_data, output_name=f"{best_model_name}_final.png"
310+
)
311+
print(f"\nFinal {best_model_name} validation metrics after extended training:")
312+
for metric, value in final_val_metrics.items():
313+
print(f" {metric}: {value:.4f}")
314+
315+
# Plot final learning curves
316+
if hasattr(best_model, 'train_losses') and len(best_model.train_losses) > 0:
317+
print("\nGenerating final learning curves...")
318+
plot_learning_curves(best_model, output_dir=output_dir,
319+
output_name=f'learning_curves_{best_model_name}_final.png')
320+
print(f"Final learning curves saved to {os.path.join(output_dir, f'learning_curves_{best_model_name}_final.png')}")
321+
191322
# Plot learning curves for the best model (if it has learning curve data)
192323
if hasattr(best_model, 'train_losses') and len(best_model.train_losses) > 0:
193324
print("\nGenerating learning curves for best model...")

src/PROBESt/models_registry.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,90 @@ def __init__(self, input_size, hidden1=512, hidden2=256):
100100
def forward(self, x):
101101
return self.net(x)
102102

103+
# ---------- GAIL_Wide variations for further architecture search ----------
104+
class GAILWideDeep(nn.Module):
    """Wide GAIL discriminator with three hidden layers.

    The network is a plain feed-forward stack: three ``Linear`` +
    ``LeakyReLU(0.2)`` stages followed by a single-unit ``Linear`` and a
    ``Sigmoid``, so every output value lies in the range [0, 1].

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        hidden3: Width of the third hidden layer.
    """

    def __init__(self, input_size, hidden1=512, hidden2=256, hidden3=128):
        super().__init__()
        # Layer order is kept fixed so that ``state_dict`` keys
        # (``net.0.*``, ``net.2.*``, ``net.4.*``, ``net.6.*``) stay stable.
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden1),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden1, hidden2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden2, hidden3),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden3, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the sigmoid output."""
        return self.net(x)
120+
121+
class GAILWideDropout(nn.Module):
    """Wide GAIL discriminator with dropout regularization.

    Two ``Linear`` + ``LeakyReLU(0.2)`` + ``Dropout`` stages are followed
    by a single-unit ``Linear`` and a ``Sigmoid``, so every output value
    lies in the range [0, 1].

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        dropout: Drop probability used by both ``Dropout`` layers.
    """

    def __init__(self, input_size, hidden1=512, hidden2=256, dropout=0.3):
        super().__init__()
        # Dropout is only active in training mode; call ``.eval()`` on the
        # module to get deterministic predictions.
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden1),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden1, hidden2),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the sigmoid output."""
        return self.net(x)
137+
138+
class GAILWideBatchNorm(nn.Module):
    """Wide GAIL discriminator with batch normalization.

    Each of the two hidden stages is ``Linear`` -> ``BatchNorm1d`` ->
    ``LeakyReLU(0.2)``; a single-unit ``Linear`` and a ``Sigmoid`` follow,
    so every output value lies in the range [0, 1].

    Args:
        input_size: Number of input features per sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
    """

    def __init__(self, input_size, hidden1=512, hidden2=256):
        super().__init__()
        # NOTE(review): BatchNorm1d computes batch statistics in training
        # mode and needs more than one sample per batch -- confirm the
        # training loop never feeds a trailing batch of size 1.
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden1),
            nn.BatchNorm1d(hidden1),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden1, hidden2),
            nn.BatchNorm1d(hidden2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the sigmoid output."""
        return self.net(x)
154+
155+
class GAILWideExtra(nn.Module):
    """Extra-wide GAIL discriminator with larger default hidden layers.

    Two ``Linear`` + ``LeakyReLU(0.2)`` stages are followed by a
    single-unit ``Linear`` and a ``Sigmoid``, so every output value lies
    in the range [0, 1].

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
    """

    def __init__(self, input_size, hidden1=768, hidden2=384):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden1),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden1, hidden2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the sigmoid output."""
        return self.net(x)
169+
170+
class GAILWideBalanced(nn.Module):
    """Wide GAIL discriminator with three tapering hidden layers.

    Three ``Linear`` + ``LeakyReLU(0.2)`` stages are followed by a
    single-unit ``Linear`` and a ``Sigmoid``, so every output value lies
    in the range [0, 1].

    NOTE(review): with the default arguments this architecture is
    layer-for-layer identical to ``GAILWideDeep`` -- confirm whether
    different layer sizes were intended for the "balanced" variant.

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        hidden3: Width of the third hidden layer.
    """

    def __init__(self, input_size, hidden1=512, hidden2=256, hidden3=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden1),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden1, hidden2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden2, hidden3),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden3, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the sigmoid output."""
        return self.net(x)
186+
103187
class GAILNarrow(nn.Module):
104188
"""Narrower GAIL with smaller hidden layers"""
105189
def __init__(self, input_size, hidden1=128, hidden2=64):

0 commit comments

Comments
 (0)