Skip to content

Commit fb9e594

Browse files
authored
fix:rafi hotspot fix for inference.py for rfd3na
1 parent 9bfbb34 commit fb9e594

1 file changed

Lines changed: 10 additions & 5 deletions

File tree

models/rfd3na/src/rfd3na/utils/inference.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,21 +405,26 @@ def infer_ori_from_hotspots(atom_array: struc.AtomArray):
405405

406406
# We can only perform distance computations on atoms with non-NaN coordinates
407407
nan_coords_mask = np.any(np.isnan(atom_array.coord), axis=1)
408-
non_nan_atom_array = atom_array[~nan_coords_mask]
409-
408+
motif_mask = atom_array.is_motif_atom_with_fixed_coord.astype(bool)
409+
non_nan_motif_atom_array = atom_array[~nan_coords_mask & motif_mask]
410+
if non_nan_motif_atom_array.array_length() == 0:
411+
raise ValueError(
412+
"infer_ori_from_hotspots requires at least one fixed motif atom "
413+
"(is_motif_atom_with_fixed_coord=True) to compute nearby atoms COM."
414+
)
410415
# Perform the distance computation
411416
# RFD2 used 10 Angstroms instead of 12, but was for residue-level hotspots
412417
DISTANCE_CUTOFF = 12.0
413-
cell_list = struc.CellList(non_nan_atom_array, cell_size=DISTANCE_CUTOFF)
418+
cell_list = struc.CellList(non_nan_motif_atom_array, cell_size=DISTANCE_CUTOFF)
414419
nearby_atoms_mask = get_atom_mask_from_cell_list(
415420
hotspot_atom_array.coord,
416421
cell_list,
417-
len(non_nan_atom_array),
422+
len(non_nan_motif_atom_array),
418423
cutoff=DISTANCE_CUTOFF,
419424
) # (n_query, n_cell_list)
420425

421426
nearby_atoms_mask = np.any(nearby_atoms_mask, axis=0) # (n_cell_list,)
422-
nearby_atoms_com = non_nan_atom_array.coord[nearby_atoms_mask].mean(axis=0)
427+
nearby_atoms_com = non_nan_motif_atom_array.coord[nearby_atoms_mask].mean(axis=0)
423428

424429
vector_from_core_to_hotspot = hotspot_com - nearby_atoms_com
425430
vector_from_core_to_hotspot = vector_from_core_to_hotspot / np.linalg.norm(

0 commit comments

Comments
 (0)