Skip to content

Commit e2bc8fa

Browse files
committed
ssr truncate ssr loss to 10
1 parent 4757718 commit e2bc8fa

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

torch_molecule/predictor/ssr/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _augmented_graph_features(self, batched_data, h_rep):
227227
return h_rep
228228

229229

230-
def compute_loss(self, batched_data, criterion, coarse_ratios=[0.8, 0.9],cmd_coeff=0.1,fine_grained=True,n_moments=5):
230+
def compute_loss(self, batched_data, criterion, coarse_ratios=[0.8, 0.9], cmd_coeff=0.1, fine_grained=True, n_moments=5):
231231
"""Compute loss with SSR regularization"""
232232
# Original forward pass
233233
h_node, _ = self.graph_encoder(batched_data)
@@ -268,6 +268,11 @@ def compute_loss(self, batched_data, criterion, coarse_ratios=[0.8, 0.9],cmd_coe
268268
ssr_loss = ssr_loss + torch.norm(h_rep - coarse_h_rep, dim=1).mean()
269269

270270
# Compute total loss
271+
ssr_loss = cmd_coeff * ssr_loss
272+
if ssr_loss > 10:
273+
import warnings
274+
warnings.warn(f"SSR loss is too large: {ssr_loss}, truncating to 10")
275+
ssr_loss = 10
271276
total_loss = pred_loss + cmd_coeff * ssr_loss
272277

273278
return total_loss, pred_loss, ssr_loss

0 commit comments

Comments
 (0)