1313 GAILDeep , GAILWide , GAILNarrow , GAILWithDropout ,
1414 GAILWideDeep , GAILWideDropout , GAILWideBatchNorm , GAILWideExtra , GAILWideBalanced
1515)
16+ from PROBESt .tokenization import tokenize_table
1617
1718MODELS = {
1819 "ShallowNet" : lambda n : TorchClassifier (ShallowNet (n ), weight_pos = 5 ),
2223 "TabTransformer" : lambda n : TorchClassifier (TabTransformer (n ), weight_pos = 5 ),
2324}
2425
25- def main ():
26- # Load data
27- data_path = 'data/databases/open/test_ML_database.csv'
28- data = pd .read_csv (data_path )
26+ def load_and_prepare_data (data_path , use_tokenized = False , add_tokens = 50 , force_regenerate = False ):
27+ """Load and prepare data, optionally with tokenization.
28+
29+ Args:
30+ data_path: Path to the CSV file
31+ use_tokenized: If True, tokenize sequences and add token columns
32+ add_tokens: Number of top k-mers to add as columns per sequence column (if use_tokenized=True)
33+ force_regenerate: If True, regenerate tokenized file even if it exists
34+
35+ Returns:
36+ Tuple of (train_data, val_data, test_data)
37+ """
38+ if use_tokenized :
39+ print (f"\n { '=' * 60 } " )
40+ print ("Tokenizing sequences..." )
41+ print (f"{ '=' * 60 } " )
42+ # Tokenize the table
43+ tokenized_path = data_path .rsplit ('.' , 1 )[0 ] + '_tokenized.csv'
44+ if force_regenerate or not os .path .exists (tokenized_path ):
45+ print (f"Generating tokenized file: { tokenized_path } " )
46+ tokenize_table (data_path , output_csv = tokenized_path , add_tokens = add_tokens ,
47+ drop_original_sequences = True )
48+ else :
49+ print (f"Using existing tokenized file: { tokenized_path } " )
50+ data = pd .read_csv (tokenized_path )
51+ print (f"Loaded tokenized data from { tokenized_path } " )
52+ else :
53+ data = pd .read_csv (data_path )
2954
3055 # Convert boolean 'type' column to numeric
3156 data ['type' ] = data ['type' ].astype (int )
@@ -47,6 +72,14 @@ def main():
4772 print (f"Validation set size: { len (val_data )} " )
4873 print (f"Test set size: { len (test_data )} " )
4974
75+ return train_data , val_data , test_data
76+
77+
78+ def main ():
79+ # Load data
80+ data_path = 'data/databases/open/test_ML_database.csv'
81+ train_data , val_data , test_data = load_and_prepare_data (data_path , use_tokenized = False )
82+
5083 # Get input size
5184 input_size = train_data .shape [1 ] - 1
5285
@@ -352,6 +385,95 @@ def main():
352385 os .makedirs (output_dir , exist_ok = True )
353386 test_predictions .to_csv (os .path .join (output_dir , 'test_predictions.csv' ), index = False )
354387 print (f"\n Predictions saved to { os .path .join (output_dir , 'test_predictions.csv' )} " )
388+
389+ # Test GAIL_Wide_Custom2 on both tokenized and non-tokenized data
390+ print ("\n " + "=" * 60 )
391+ print ("Testing GAIL_Wide_Custom2 on tokenized vs non-tokenized data" )
392+ print ("=" * 60 )
393+
394+ # Test on non-tokenized data (already loaded)
395+ print ("\n " + "-" * 60 )
396+ print ("Testing GAIL_Wide_Custom2 on NON-TOKENIZED data" )
397+ print ("-" * 60 )
398+ gail_wide_custom2_non_tokenized = TorchClassifier (
399+ GAILWide (input_size , hidden1 = 384 , hidden2 = 192 ), weight_pos = 5
400+ )
401+ X_train = train_data .drop (columns = ['type' ])
402+ y_train = train_data ['type' ]
403+ gail_wide_custom2_non_tokenized .train (
404+ X_train , y_train , epochs = 150 , batch_size = 32 ,
405+ val_data = val_data , track_curves = True
406+ )
407+ non_tokenized_metrics = validate_filtration_AI (
408+ gail_wide_custom2_non_tokenized , val_data ,
409+ output_name = 'GAIL_Wide_Custom2_non_tokenized.png'
410+ )
411+ print ("\n GAIL_Wide_Custom2 (non-tokenized) validation metrics:" )
412+ for metric , value in non_tokenized_metrics .items ():
413+ print (f" { metric } : { value :.4f} " )
414+
415+ # Plot learning curves for non-tokenized
416+ if hasattr (gail_wide_custom2_non_tokenized , 'train_losses' ) and len (gail_wide_custom2_non_tokenized .train_losses ) > 0 :
417+ plot_learning_curves (
418+ gail_wide_custom2_non_tokenized , output_dir = output_dir ,
419+ output_name = 'learning_curves_GAIL_Wide_Custom2_non_tokenized.png'
420+ )
421+ print (f"Learning curves saved to { os .path .join (output_dir , 'learning_curves_GAIL_Wide_Custom2_non_tokenized.png' )} " )
422+
423+ # Test on tokenized data
424+ print ("\n " + "-" * 60 )
425+ print ("Testing GAIL_Wide_Custom2 on TOKENIZED data" )
426+ print ("-" * 60 )
427+ train_data_tokenized , val_data_tokenized , test_data_tokenized = load_and_prepare_data (
428+ data_path , use_tokenized = True , add_tokens = 50 , force_regenerate = True
429+ )
430+ input_size_tokenized = train_data_tokenized .shape [1 ] - 1
431+ print (f"Tokenized data input size: { input_size_tokenized } (vs { input_size } for non-tokenized)" )
432+
433+ gail_wide_custom2_tokenized = TorchClassifier (
434+ GAILWide (input_size_tokenized , hidden1 = 384 , hidden2 = 192 ), weight_pos = 5
435+ )
436+ X_train_tokenized = train_data_tokenized .drop (columns = ['type' ])
437+ y_train_tokenized = train_data_tokenized ['type' ]
438+ gail_wide_custom2_tokenized .train (
439+ X_train_tokenized , y_train_tokenized , epochs = 150 , batch_size = 32 ,
440+ val_data = val_data_tokenized , track_curves = True
441+ )
442+ tokenized_metrics = validate_filtration_AI (
443+ gail_wide_custom2_tokenized , val_data_tokenized ,
444+ output_name = 'GAIL_Wide_Custom2_tokenized.png'
445+ )
446+ print ("\n GAIL_Wide_Custom2 (tokenized) validation metrics:" )
447+ for metric , value in tokenized_metrics .items ():
448+ print (f" { metric } : { value :.4f} " )
449+
450+ # Plot learning curves for tokenized
451+ if hasattr (gail_wide_custom2_tokenized , 'train_losses' ) and len (gail_wide_custom2_tokenized .train_losses ) > 0 :
452+ plot_learning_curves (
453+ gail_wide_custom2_tokenized , output_dir = output_dir ,
454+ output_name = 'learning_curves_GAIL_Wide_Custom2_tokenized.png'
455+ )
456+ print (f"Learning curves saved to { os .path .join (output_dir , 'learning_curves_GAIL_Wide_Custom2_tokenized.png' )} " )
457+
458+ # Compare results
459+ print ("\n " + "=" * 60 )
460+ print ("COMPARISON: GAIL_Wide_Custom2 - Tokenized vs Non-Tokenized" )
461+ print ("=" * 60 )
462+ print (f"{ 'Metric' :<20} { 'Non-Tokenized' :<15} { 'Tokenized' :<15} { 'Difference' :<15} " )
463+ print ("-" * 60 )
464+ for metric in non_tokenized_metrics .keys ():
465+ non_val = non_tokenized_metrics [metric ]
466+ tok_val = tokenized_metrics [metric ]
467+ diff = tok_val - non_val
468+ print (f"{ metric :<20} { non_val :<15.4f} { tok_val :<15.4f} { diff :+.4f} " )
469+
470+ # Determine winner
471+ if tokenized_metrics ['f1' ] > non_tokenized_metrics ['f1' ]:
472+ print (f"\n ✓ Tokenized version performs better (F1: { tokenized_metrics ['f1' ]:.4f} vs { non_tokenized_metrics ['f1' ]:.4f} )" )
473+ elif non_tokenized_metrics ['f1' ] > tokenized_metrics ['f1' ]:
474+ print (f"\n ✓ Non-tokenized version performs better (F1: { non_tokenized_metrics ['f1' ]:.4f} vs { tokenized_metrics ['f1' ]:.4f} )" )
475+ else :
476+ print (f"\n = Both versions perform equally (F1: { non_tokenized_metrics ['f1' ]:.4f} )" )
355477
356478if __name__ == '__main__' :
357479 main ()
0 commit comments