Skip to content

Commit 5d76667

Browse files
committed
use sequence mask in potentials
1 parent 5735041 commit 5d76667

3 files changed

Lines changed: 523 additions & 12 deletions

File tree

rfdiffusion/inference/model_runners.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ def sample_init(self, return_forward_trajectory=False):
384384
self.pair_prev = None
385385
self.state_prev = None
386386

387+
#########################################
388+
### Parse seq for potentials using sidechains ###
389+
#########################################
390+
391+
if self.potential_conf.guiding_potentials is not None:
392+
if True: #any(list(filter(lambda x: "sidechain" in x, self.potential_conf.guiding_potentials))):
393+
for pot in self.potential_manager.potentials_to_apply:
394+
pot.seq = seq_t
395+
pot.mask_seq = self.mask_seq
396+
397+
387398
#########################################
388399
### Parse ligand for ligand potential ###
389400
#########################################

rfdiffusion/potentials/potentials.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def compute(self, xyz):
389389
# This operates on self.xyz_motif, which is assigned to this class in the model runner (for horrible plumbing reasons)
390390
self._grab_motif_residues(self.xyz_motif)
391391

392-
# for checking affine transformation is corect
392+
# for checking affine transformation is correct
393393
first_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(self.motif_substrate_atoms[0] - self.motif_frame[0]), dim=-1)))
394394

395395
# grab the coordinates of the corresponding atoms in the new frame using mapping
@@ -472,30 +472,44 @@ class dmasif_interactions(Potential):
472472
Differentiable way to optinize binding and non-binding surface
473473
'''
474474

475-
def __init__(self, binderlen, int_weight=1, non_int_weight=1,
476-
pos_threshold=3, neg_threshold=5, seq_model_type='protein_mpnn'):
475+
def __init__(self, binderlen, int_weight=1, non_int_weight=1, disable=False,
476+
pos_threshold=3, neg_threshold=3, seq_model_type='protein_mpnn'):
477477

478478
super().__init__()
479-
self.predicted=True
479+
480+
self.disable=disable
480481

481482
submodule_path='/'.join(__file__.split('/')[:-4])
482483
import sys
483484
sys.path.append(submodule_path)
485+
486+
from rfdiffusion.recover_sidechains import GetMartiniSidechains
487+
488+
self.get_sidechains=GetMartiniSidechains(binderlen=binderlen,
489+
seq_model_type=seq_model_type)
484490

485-
from masif_martini.potential import RFdiff_potential_from_bb
491+
from masif_martini.potential import dmasif_potential
486492

487-
self.potential=RFdiff_potential_from_bb(binderlen=binderlen,
488-
int_weight=int_weight,
489-
non_int_weight=non_int_weight,
490-
pos_threshold=pos_threshold,
491-
neg_threshold=neg_threshold,
492-
seq_model_type=seq_model_type)
493+
self.potential=dmasif_potential(binderlen=binderlen,
494+
int_weight=int_weight,
495+
non_int_weight=non_int_weight,
496+
pos_threshold=pos_threshold,
497+
neg_threshold=neg_threshold)
493498

494499
self.allatom=ComputeAllAtomCoords()
495500

496501
def compute(self, xyz):
497502

498-
return self.potential(xyz.squeeze()).to('cpu')
503+
d=self.get_sidechains(xyz, self.seq, self.mask_seq)
504+
505+
potential=self.potential(d)
506+
507+
if self.disable:
508+
potential.detach_()
509+
potential.requires_grad_()
510+
xyz.grad=torch.zeros_like(xyz)
511+
512+
return potential
499513

500514
# Dictionary of types of potentials indexed by name of potential. Used by PotentialManager.
501515
# If you implement a new potential you must add it to this dictionary for it to be used by

0 commit comments

Comments
 (0)