@@ -10,8 +10,8 @@ def test_grea_predictor():
1010 'CNC[C@@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@H]1C' ,
1111 'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F' ,
1212 'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
13- ]
14- properties = np .array ([0 , 0 , 1 , 1 ]) # Binary classification
13+ ] * 10
14+ properties = np .array ([0 , 0 , 1 , 1 ] * 10 ) # Binary classification
1515
1616 # 1. Basic initialization test
1717 print ("\n === Testing GREA model initialization ===" )
@@ -29,7 +29,7 @@ def test_grea_predictor():
2929
3030 # 2. Basic fitting test
3131 print ("\n === Testing GREA model fitting ===" )
32- model .fit (smiles_list [: 3 ], properties [: 3 ])
32+ model .fit (smiles_list , properties , X_val = smiles_list [ 3 : ], y_val = properties [3 : ])
3333 print ("GREA model fitting completed" )
3434
3535 # 3. Prediction test
@@ -70,13 +70,15 @@ def test_grea_predictor():
7070 model_auto = GREAMolecularPredictor (
7171 num_task = 1 ,
7272 task_type = "classification" ,
73- epochs = 3 ,
73+ epochs = 50 ,
7474 verbose = True
7575 )
7676
7777 model_auto .autofit (
7878 smiles_list ,
7979 properties ,
80+ X_val = smiles_list [3 :],
81+ y_val = properties [3 :],
8082 search_parameters = search_parameters ,
8183 n_trials = 2
8284 )
@@ -92,13 +94,15 @@ def test_grea_predictor():
9294 model_partial = GREAMolecularPredictor (
9395 num_task = 1 ,
9496 task_type = "classification" ,
95- epochs = 3 ,
97+ epochs = 50 ,
9698 verbose = True
9799 )
98100
99101 model_partial .autofit (
100102 smiles_list ,
101103 properties ,
104+ X_val = smiles_list [3 :],
105+ y_val = properties [3 :],
102106 search_parameters = partial_search ,
103107 n_trials = 2
104108 )
@@ -109,7 +113,7 @@ def test_grea_predictor():
109113 model_default = GREAMolecularPredictor (
110114 num_task = 1 ,
111115 task_type = "classification" ,
112- epochs = 3 ,
116+ epochs = 50 ,
113117 verbose = True
114118 )
115119
@@ -202,7 +206,7 @@ def test_grea_upload():
202206 )
203207
204208 # Fit the model with sample data
205- model_for_upload .autofit (smiles_list [: 3 ], properties [: 3 ])
209+ model_for_upload .autofit (smiles_list , properties , X_val = smiles_list [ 3 : ], y_val = properties [3 : ])
206210
207211 # Push to Hugging Face Hub
208212 # Note: HF_TOKEN should be set in environment variables
@@ -242,5 +246,5 @@ def test_grea_upload():
242246 print ("Cleaned up test_grea_model.pt" )
243247
244248if __name__ == "__main__" :
245- # test_grea_predictor()
249+ test_grea_predictor ()
246250 test_grea_upload ()
0 commit comments