Skip to content

Commit 8fcd0f2

Browse files
committed
revise var warn in grea
1 parent 9d6dda1 commit 8fcd0f2

2 files changed

Lines changed: 17 additions & 9 deletions

File tree

tests/predictor/run_grea.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def test_grea_predictor():
1010
'CNC[C@@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@H]1C',
1111
'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F',
1212
'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
13-
]
14-
properties = np.array([0, 0, 1, 1]) # Binary classification
13+
] * 10
14+
properties = np.array([0, 0, 1, 1] * 10) # Binary classification
1515

1616
# 1. Basic initialization test
1717
print("\n=== Testing GREA model initialization ===")
@@ -29,7 +29,7 @@ def test_grea_predictor():
2929

3030
# 2. Basic fitting test
3131
print("\n=== Testing GREA model fitting ===")
32-
model.fit(smiles_list[:3], properties[:3])
32+
model.fit(smiles_list, properties, X_val=smiles_list[3:], y_val=properties[3:])
3333
print("GREA model fitting completed")
3434

3535
# 3. Prediction test
@@ -70,13 +70,15 @@ def test_grea_predictor():
7070
model_auto = GREAMolecularPredictor(
7171
num_task=1,
7272
task_type="classification",
73-
epochs=3,
73+
epochs=50,
7474
verbose=True
7575
)
7676

7777
model_auto.autofit(
7878
smiles_list,
7979
properties,
80+
X_val=smiles_list[3:],
81+
y_val=properties[3:],
8082
search_parameters=search_parameters,
8183
n_trials=2
8284
)
@@ -92,13 +94,15 @@ def test_grea_predictor():
9294
model_partial = GREAMolecularPredictor(
9395
num_task=1,
9496
task_type="classification",
95-
epochs=3,
97+
epochs=50,
9698
verbose=True
9799
)
98100

99101
model_partial.autofit(
100102
smiles_list,
101103
properties,
104+
X_val=smiles_list[3:],
105+
y_val=properties[3:],
102106
search_parameters=partial_search,
103107
n_trials=2
104108
)
@@ -109,7 +113,7 @@ def test_grea_predictor():
109113
model_default = GREAMolecularPredictor(
110114
num_task=1,
111115
task_type="classification",
112-
epochs=3,
116+
epochs=50,
113117
verbose=True
114118
)
115119

@@ -202,7 +206,7 @@ def test_grea_upload():
202206
)
203207

204208
# Fit the model with sample data
205-
model_for_upload.autofit(smiles_list[:3], properties[:3])
209+
model_for_upload.autofit(smiles_list, properties, X_val=smiles_list[3:], y_val=properties[3:])
206210

207211
# Push to Hugging Face Hub
208212
# Note: HF_TOKEN should be set in environment variables
@@ -242,5 +246,5 @@ def test_grea_upload():
242246
print("Cleaned up test_grea_model.pt")
243247

244248
if __name__ == "__main__":
245-
# test_grea_predictor()
249+
test_grea_predictor()
246250
test_grea_upload()

torch_molecule/predictor/grea/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,11 @@ def forward(self, batched_data):
138138
h_rep = (h_r.unsqueeze(1) + h_env.unsqueeze(0)).view(-1, self.hidden_size)
139139
h_r, h_rep = self._augment_graph_features(batched_data, h_r, h_rep)
140140
prediction = self.predictor(h_r)
141-
variance = self.predictor(h_rep).view(h_r.size(0), -1).var(dim=-1, keepdim=True)
141+
pred_rep = self.predictor(h_rep).view(h_r.size(0), -1)
142+
if pred_rep.size(1) > 1:
143+
variance = pred_rep.var(dim=-1, keepdim=True)
144+
else:
145+
variance = torch.zeros_like(pred_rep)
142146
num_graphs = batched_data.batch.max().item() + 1
143147
score_by_graph = [node_score[batched_data.batch == i].view(-1).tolist() for i in range(num_graphs)]
144148
return {"prediction": prediction, "variance": variance, "score": score_by_graph, "representation": h_r}

0 commit comments

Comments
 (0)