@@ -345,29 +345,36 @@ def to_numpy(loss_dict):
345345 return loss_dict
346346
347347 # Compute between residue backbone violations of bonds and angles.
348- connection_violations = to_numpy (
349- functional .between_residue_bond_loss (
350- pred_points , points_mask , residue_index , aatypes
348+ violations = {}
349+ violations .update (
350+ to_numpy (
351+ functional .between_residue_bond_loss (
352+ pred_points , points_mask , residue_index , aatypes
353+ )
351354 )
352355 )
353- between_residue_violations = to_numpy (
354- functional .between_residue_clash_loss (
355- pred_points , points_mask , residue_index , aatypes
356+ violations .update (
357+ to_numpy (
358+ functional .between_residue_clash_loss (
359+ pred_points , points_mask , residue_index , aatypes
360+ )
356361 )
357362 )
358- within_residue_violations = to_numpy (
359- functional .within_residue_clash_loss (
360- pred_points , points_mask , residue_index , aatypes
363+ violations .update (
364+ to_numpy (
365+ functional .within_residue_clash_loss (
366+ pred_points , points_mask , residue_index , aatypes
367+ )
361368 )
362369 )
363370
364371 # Combine them to a single per-residue violation mask (used later for LDDT).
365372 per_residue_violation_mask = np .max (
366373 np .stack (
367374 [
368- connection_violations ['per_residue_violation_mask' ],
369- np .max (between_residue_violations [ 'per_atom_clash_mask ' ], axis = - 1 ),
370- np .max (within_residue_violations [ 'per_atom_clash_mask ' ], axis = - 1 )
375+ violations ['per_residue_violation_mask' ],
376+ np .max (violations [ 'between_residue_per_atom_clash_mask ' ], axis = - 1 ),
377+ np .max (violations [ 'within_residue_per_atom_clash_mask ' ], axis = - 1 )
371378 ]
372379 ),
373380 axis = 0
0 commit comments