@@ -389,7 +389,7 @@ def compute(self, xyz):
389389 # This operates on self.xyz_motif, which is assigned to this class in the model runner (for horrible plumbing reasons)
390390 self ._grab_motif_residues (self .xyz_motif )
391391
392- # for checking affine transformation is corect
392+ # for checking affine transformation is correct
393393 first_distance = torch .sqrt (torch .sqrt (torch .sum (torch .square (self .motif_substrate_atoms [0 ] - self .motif_frame [0 ]), dim = - 1 )))
394394
395395 # grab the coordinates of the corresponding atoms in the new frame using mapping
@@ -472,30 +472,44 @@ class dmasif_interactions(Potential):
472472 Differentiable way to optinize binding and non-binding surface
473473 '''
474474
475- def __init__ (self , binderlen , int_weight = 1 , non_int_weight = 1 ,
476- pos_threshold = 3 , neg_threshold = 5 , seq_model_type = 'protein_mpnn' ):
475+ def __init__ (self , binderlen , int_weight = 1 , non_int_weight = 1 , disable = False ,
476+ pos_threshold = 3 , neg_threshold = 3 , seq_model_type = 'protein_mpnn' ):
477477
478478 super ().__init__ ()
479- self .predicted = True
479+
480+ self .disable = disable
480481
481482 submodule_path = '/' .join (__file__ .split ('/' )[:- 4 ])
482483 import sys
483484 sys .path .append (submodule_path )
485+
486+ from rfdiffusion .recover_sidechains import GetMartiniSidechains
487+
488+ self .get_sidechains = GetMartiniSidechains (binderlen = binderlen ,
489+ seq_model_type = seq_model_type )
484490
485- from masif_martini .potential import RFdiff_potential_from_bb
491+ from masif_martini .potential import dmasif_potential
486492
487- self .potential = RFdiff_potential_from_bb (binderlen = binderlen ,
488- int_weight = int_weight ,
489- non_int_weight = non_int_weight ,
490- pos_threshold = pos_threshold ,
491- neg_threshold = neg_threshold ,
492- seq_model_type = seq_model_type )
493+ self .potential = dmasif_potential (binderlen = binderlen ,
494+ int_weight = int_weight ,
495+ non_int_weight = non_int_weight ,
496+ pos_threshold = pos_threshold ,
497+ neg_threshold = neg_threshold )
493498
494499 self .allatom = ComputeAllAtomCoords ()
495500
496501 def compute (self , xyz ):
497502
498- return self .potential (xyz .squeeze ()).to ('cpu' )
503+ d = self .get_sidechains (xyz , self .seq , self .mask_seq )
504+
505+ potential = self .potential (d )
506+
507+ if self .disable :
508+ potential .detach_ ()
509+ potential .requires_grad_ ()
510+ xyz .grad = torch .zeros_like (xyz )
511+
512+ return potential
499513
500514# Dictionary of types of potentials indexed by name of potential. Used by PotentialManager.
501515# If you implement a new potential you must add it to this dictionary for it to be used by
0 commit comments