openfold/utils/loss.py (5 changes: 3 additions & 2 deletions)
@@ -1649,7 +1649,8 @@ def chain_center_of_mass_loss(
         clamp_distance:
             Cutoff above which distance errors are disregarded
         weight:
-            Weight for loss
+            Accepted for config/backward compatibility. The top-level loss
+            weight is applied by AlphaFoldLoss.
         eps:
             Small value used to regularize denominators
     Returns:
@@ -1675,7 +1676,7 @@ def get_chain_center_of_mass(pos):

     pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
     true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
-    losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
+    losses = torch.clamp(pred_dists - true_dists - clamp_distance, max=0) ** 2
     loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]

     loss = masked_mean(loss_mask, losses, dim=(-1, -2))
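For context, the updated docstring defers weighting to AlphaFoldLoss. Below is a minimal sketch of that pattern, assuming a dict of per-term scalar losses and a matching dict of configured weights (combine_weighted_losses and both dicts are illustrative, not OpenFold's actual API):

import torch

def combine_weighted_losses(loss_terms, weights):
    # Each term function returns an unweighted scalar; the wrapper applies the
    # configured weight exactly once while accumulating the total loss.
    cum_loss = torch.zeros(())
    for name, value in loss_terms.items():
        cum_loss = cum_loss + weights[name] * value
    return cum_loss

# Example: the chain_center_of_mass term is weighted here, not inside the term itself.
total = combine_weighted_losses(
    {"chain_center_of_mass": torch.tensor(8.0)},
    {"chain_center_of_mass": 0.05},
)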
tests/test_loss.py (64 changes: 64 additions & 0 deletions)
@@ -1112,6 +1112,67 @@ def run_tm_loss(representations, batch, value):

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    def _make_chain_center_of_mass_inputs(self, requires_grad=False):
        ca_pos = residue_constants.atom_order["CA"]
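        # Two single-residue chains: ground-truth CA atoms are 10 apart and
        # predicted CA atoms 2 apart, so only the inter-chain center distance differs.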
        all_atom_positions = torch.zeros((1, 2, 37, 3), dtype=torch.float32)
        all_atom_positions[0, 0, ca_pos] = torch.tensor([0.0, 0.0, 0.0])
        all_atom_positions[0, 1, ca_pos] = torch.tensor([10.0, 0.0, 0.0])

        all_atom_pred_pos = torch.zeros_like(all_atom_positions)
        all_atom_pred_pos[0, 0, ca_pos] = torch.tensor([0.0, 0.0, 0.0])
        all_atom_pred_pos[0, 1, ca_pos] = torch.tensor([2.0, 0.0, 0.0])
        if requires_grad:
            all_atom_pred_pos.requires_grad_()

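        # Mask in only the CA atoms and assign each residue to its own chain.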
        all_atom_mask = torch.zeros((1, 2, 37), dtype=torch.float32)
        all_atom_mask[:, :, ca_pos] = 1.0
        asym_id = torch.tensor([[1, 2]], dtype=torch.float32)

        return all_atom_pred_pos, all_atom_positions, all_atom_mask, asym_id

    def test_chain_center_of_mass_loss_is_unweighted(self):
        inputs = self._make_chain_center_of_mass_inputs()
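        # The weight argument should not scale the result; the configured loss
        # weight is applied by AlphaFoldLoss, so both calls below must agree.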
        loss = chain_center_of_mass_loss(
            all_atom_pred_pos=inputs[0],
            all_atom_positions=inputs[1],
            all_atom_mask=inputs[2],
            asym_id=inputs[3],
            clamp_distance=-4.0,
            weight=0.05,
        )
        larger_weight_loss = chain_center_of_mass_loss(
            all_atom_pred_pos=inputs[0],
            all_atom_positions=inputs[1],
            all_atom_mask=inputs[2],
            asym_id=inputs[3],
            clamp_distance=-4.0,
            weight=0.5,
        )

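        # True chain-center distance is 10 and predicted distance is 2, so each
        # off-diagonal term is clamp(2 - 10 - (-4), max=0) ** 2 = 16; the diagonal
        # terms clamp to 0, which is where 32 / (4 + 1e-4) comes from.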
        expected = torch.tensor([32.0 / (4.0 + 1e-4)], dtype=loss.dtype)
        self.assertTrue(torch.allclose(loss, expected, rtol=1e-5, atol=1e-5))
        self.assertTrue(
            torch.allclose(loss, larger_weight_loss, rtol=1e-6, atol=1e-6)
        )

    def test_chain_center_of_mass_loss_backpropagates(self):
        inputs = self._make_chain_center_of_mass_inputs(requires_grad=True)
        loss = chain_center_of_mass_loss(
            all_atom_pred_pos=inputs[0],
            all_atom_positions=inputs[1],
            all_atom_mask=inputs[2],
            asym_id=inputs[3],
            clamp_distance=-4.0,
            weight=0.05,
        )

        loss.sum().backward()

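        # Only the CA positions feed the chain centers, so their gradients should
        # be finite and non-zero.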
        ca_pos = residue_constants.atom_order["CA"]
        ca_grad = inputs[0].grad[:, :, ca_pos, :]
        self.assertTrue(torch.all(torch.isfinite(ca_grad)))
        self.assertGreater(torch.norm(ca_grad).item(), 0.0)

    @compare_utils.skip_unless_alphafold_installed()
    def test_chain_center_of_mass_loss(self):
        batch_size = consts.batch_size
@@ -1139,6 +1200,9 @@ def test_chain_center_of_mass_loss(self):
        )
        out_repro = out_repro.cpu()

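        # Sanity checks on the reproduced loss: finite and non-negative.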
        self.assertTrue(torch.all(torch.isfinite(out_repro)))
        self.assertTrue(torch.all(out_repro >= 0))


if __name__ == "__main__":
    unittest.main()