Skip to content

Commit dd5aefb

Browse files
committed
add averaged conditioning
1 parent 9d0ff16 commit dd5aefb

7 files changed

Lines changed: 120 additions & 53 deletions

File tree

rfdiffusion/inference/model_runners.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,12 +428,14 @@ def sample_init(self, return_forward_trajectory=False):
428428
het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']])
429429
xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate]
430430
xyz_het = torch.from_numpy(xyz_het)
431+
info_het={x: self.target_feats['info_het'][x][het_names == self._conf.potentials.substrate] for x in self.target_feats['info_het']}
431432
assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}'
432433
xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()]
433434
motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0)
434435
xyz_het_com = xyz_het.mean(dim=0)
435436
for pot in self.potential_manager.potentials_to_apply:
436437
pot.motif_substrate_atoms = xyz_het
438+
pot.substrate_info = info_het
437439
pot.diffusion_mask = self.diffusion_mask.squeeze()
438440
pot.xyz_motif = xyz_motif_prealign
439441
pot.diffuser = self.diffuser
@@ -444,9 +446,9 @@ def sample_init(self, return_forward_trajectory=False):
444446

445447
if self.potential_conf.guiding_potentials is not None:
446448
if any(list(filter(lambda x: "na_" in x, self.potential_conf.guiding_potentials))):
447-
assert len(self.target_feats['xyz_na']) > 0, "If you're using the NA Contact potential, \
449+
assert len(self.target_feats['na_xyz']) > 0, "If you're using the NA Contact potential, \
448450
you need to make sure there's a NA in the input_pdb file!"
449-
info_na = self.target_feats["na_info"],
451+
info_na = self.target_feats["na_info"]
450452
xyz_het = self.target_feats['na_xyz']
451453
xyz_het = torch.from_numpy(xyz_het)
452454
xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()]

rfdiffusion/inference/utils.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,28 @@ def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
122122
return mu, sigma
123123

124124

125+
def rigid_rotation_from_grads(Cas, Ca_grads, eps=1e-8):
    """Fit a rigid-body motion (translation + rotation) to per-atom gradients.

    The mean gradient is taken as the translation. The remaining per-atom
    deviations ``d`` are fitted by a rotation about the centroid: the angular
    vector ``omega`` solves ``(inertia + eps * I3) @ omega = sum_i r_i x d_i``,
    where ``r_i`` are the centroid-relative coordinates and ``inertia`` is
    ``sum_i (|r_i|^2 * I3 - r_i r_i^T)``.

    Args:
        Cas: ``(L, 3)`` tensor of coordinates.
        Ca_grads: ``(L, 3)`` tensor of gradients, one row per coordinate row.
        eps: diagonal regularisation added to the inertia matrix so the linear
            system stays solvable for degenerate (e.g. collinear) inputs.

    Returns:
        Tuple ``(trans, omega, center, rot)``:
            trans:  ``(1, 3)`` mean gradient (kept 2-D so it broadcasts over L).
            omega:  ``(3,)`` fitted angular vector.
            center: ``(3,)`` centroid of ``Cas``.
            rot:    ``(L, 3)`` rotational component ``omega x r_i`` per row.

        For empty input (``L == 0``) all outputs are zeros with the shapes
        above, instead of the NaNs that averaging zero rows would produce.
    """
    if Cas.shape[0] == 0:
        # Averaging zero rows gives NaN, which would then poison (or crash)
        # the linear solve below; an empty group simply has no rigid motion.
        return (
            Ca_grads.new_zeros((1, 3)),
            Cas.new_zeros(3),
            Cas.new_zeros(3),
            Cas.new_zeros((0, 3)),
        )

    center = Cas.mean(dim=0)  # (3,)
    r = Cas - center  # (L, 3)

    trans = Ca_grads.mean(dim=0, keepdim=True)  # (1, 3)
    d = Ca_grads - trans  # (L, 3)

    eye = torch.eye(3, device=Cas.device, dtype=Cas.dtype)
    r2 = (r ** 2).sum(dim=1)  # (L,)
    rrT = r[:, :, None] * r[:, None, :]  # (L, 3, 3)
    inertia = (r2[:, None, None] * eye[None, :, :] - rrT).sum(dim=0)  # (3, 3)
    tau = torch.cross(r, d, dim=1).sum(dim=0)  # (3,)

    regularised = inertia + eps * eye
    try:
        omega = torch.linalg.solve(regularised, tau)  # (3,)
    except RuntimeError:
        # Fall back to least squares if the direct solve rejects the matrix
        # (torch reports linear-algebra failures as RuntimeError subclasses).
        omega = torch.linalg.lstsq(regularised, tau.unsqueeze(-1)).solution.squeeze(-1)

    rot = torch.cross(omega.unsqueeze(0).expand_as(r), r, dim=1)  # (L, 3)

    return trans, omega, center, rot
144+
145+
146+
125147
def get_next_ca(
126148
xt,
127149
px0,
@@ -392,7 +414,11 @@ def get_potential_gradients(self, xyz, diffusion_mask, t, predicted=False):
392414

393415
# Since we are not moving frames, Cb grads are same as Ca grads
394416
# Need access to calculated Cb coordinates to be able to get Cb grads though
395-
Ca_grads = xyz.grad[:, 1, :]
417+
if xyz.grad is None:
418+
print("WARNING: NaN in potential gradients, replacing with zero grad.")
419+
Ca_grads=torch.zeros_like(xyz[:, 1, :])
420+
else:
421+
Ca_grads = xyz.grad[:, 1, :]
396422

397423
if not diffusion_mask == None:
398424
Ca_grads[diffusion_mask, :] = 0
@@ -406,6 +432,20 @@ def get_potential_gradients(self, xyz, diffusion_mask, t, predicted=False):
406432
print("WARNING: NaN in potential gradients, replacing with zero grad.")
407433
Ca_grads[:] = 0
408434

435+
# smooth potential effects within protein subunits
436+
smooth_scale=max([potential.smooth for potential in self.potential_manager.potentials_to_apply])
437+
if smooth_scale>0:
438+
Cas=xyz[:, 1, :]
439+
binderlen=self.potential_manager.binderlen
440+
if binderlen<0:
441+
borders=[(0,Ca_grads.shape[0])]
442+
else:
443+
borders=[(0,binderlen),(binderlen,Ca_grads.shape[0])]
444+
for a, b in borders:
445+
with torch.no_grad():
446+
trans, omega, center, rot = rigid_rotation_from_grads(Cas[a:b],Ca_grads[a:b])
447+
Ca_grads[a:b]=Ca_grads[a:b]*(1-smooth_scale)+(trans+rot)*smooth_scale
448+
409449
return Ca_grads
410450

411451
def get_next_pose(
@@ -620,38 +660,40 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
620660
res, pdb_idx = [],[]
621661
for l in lines:
622662
if l[:4] == "ATOM" and l[12:16].strip() == "C1'":
623-
res.append((l[22:26], l[17:20]))
663+
res.append((l[22:26], l[17:20].strip()))
624664
# chain letter, res num
625665
pdb_idx.append((l[21:22].strip(), int(l[22:26].strip())))
626-
seq = [util.na2num[r[1]] if r[1] in util.na2num.keys() else 20 for r in res]
666+
seq = [util.na2num[r[1]] if r[1] in util.na2num else 20 for r in res]
627667
pdb_idx = [
628668
(l[21:22].strip(), int(l[22:26].strip()))
629669
for l in lines
630670
if l[:4] == "ATOM" and l[12:16].strip() == "C1'"
631671
] # chain letter, res num
632672

633-
# 4 BB + up to 10 SC atoms
673+
# 3 BB + up to 20 SC atoms
634674
xyz = np.full((len(res), 23, 3), np.nan, dtype=np.float32)
635-
xyz_names = np.full((len(res), 23, 3), np.nan)
675+
atom_id = np.full((len(res), 23), np.nan, dtype=np.object)
676+
atom_type = np.full((len(res), 23), np.nan, dtype=np.object)
636677
for l in lines:
637678
if l[:4] != "ATOM":
638679
continue
639680
chain, resNo, atom, aa = (
640681
l[21:22],
641682
int(l[22:26]),
642683
" " + l[12:16].strip().ljust(3),
643-
l[17:20],
684+
l[17:20].strip(),
644685
)
645686
if (chain,resNo) in pdb_idx:
646687
idx = pdb_idx.index((chain, resNo))
647688
for i_atm, tgtatm in enumerate(
648-
util.na2long[util.na2num[aa]][:14]
689+
util.na2long[util.na2num[aa]][:23]
649690
):
650691
if (
651692
tgtatm is not None and tgtatm.strip() == atom.strip()
652693
): # ignore whitespace
653694
xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
654-
xyz_names[idx, i_atm, :] = l[16:20]
695+
atom_id[idx, i_atm] = atom
696+
atom_type[idx, i_atm] = l[77]
655697
break
656698

657699
# save atom mask
@@ -674,7 +716,8 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
674716

675717
out["na_xyz"]= xyz # cartesian coordinates, [Lx23]
676718
out["na_mask"]= mask # mask showing which atoms are present in the PDB file, [Lx23]
677-
out['na_atom_names']= xyz_names
719+
out['na_atom_id']= atom_id
720+
out['na_atom_type']= atom_type
678721
out["na_seq"]= np.array(seq) # amino acid sequence, [L]
679722
out["na_pdb_idx"]= pdb_idx # list of (chain letter, residue number) in the pdb file, [L]
680723

@@ -713,7 +756,8 @@ def process_target(pdb_path, parse_hetatom=False, parse_na=False, center=True):
713756

714757
if parse_na:
715758
out['na_info']={'mask':target_struct["na_mask"],
716-
'atom_names':target_struct['na_atom_names'],
759+
'atom_id':target_struct['na_atom_id'],
760+
'atom_type':target_struct['na_atom_type'],
717761
'seq':target_struct["na_seq"],
718762
'pdb_idx':target_struct["na_pdb_idx"]}
719763
out["na_xyz"]= target_struct["na_xyz"]

rfdiffusion/potentials/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self,
9393
self.potentials_config = potentials_config
9494
self.ppi_config = ppi_config
9595
self.inference_config = inference_config
96+
self.binderlen=binderlen
9697

9798
self.guide_scale = potentials_config.guide_scale
9899
self.guide_decay = potentials_config.guide_decay

rfdiffusion/potentials/potentials.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ def __init__(self):
1414
self.sidechain=False
1515
self.current_substrate_atoms=None
1616
self.current_na_atoms=None
17+
self.smooth=0
1718

18-
def compute(self, xyz):
19+
def compute(self, xyz,**kwargs):
1920
'''
2021
Given the current structure of the model prediction, return the current
2122
potential as a PyTorch tensor with a single entry
@@ -434,7 +435,7 @@ def compute(self, xyz):
434435
assert abs(first_distance - second_distance) < 0.01, "Alignment seems to be bad"
435436

436437
if self.sidechain:
437-
d=self.get_sidechains(xyz, self.seq, self.mask_seq)
438+
d=self.get_sidechains(xyz, self.seq, self.mask_seq, substrate_atoms, self.substrate_info['atom_type'])
438439
dgram = torch.cdist(d['atom_xyz_p1'][None,...].contiguous(),
439440
substrate_atoms.float()[None].to(d['atom_xyz_p1'].device), p=2)[0] # [Lb,Lb]
440441
else:
@@ -506,14 +507,17 @@ def _grab_motif_residues(self, xyz) -> None:
506507

507508
class na_contacts(substrate_contacts):
508509

509-
def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1, sidechain=False):
510+
def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1,
511+
sidechain=False, smooth=0, predicted=False):
510512

511513
super().__init__()
512514
self.r_0 = r_0
513515
self.weight = weight
514516
self.d_0 = d_0
515517
self.eps = eps
516518
self.sidechain=sidechain
519+
self.predicted=predicted
520+
self.smooth=smooth
517521

518522
self.motif_frame = None # [4,3] xyz coordinates from 4 atoms of input motif
519523
self.motif_mapping = None # list of tuples giving positions of above atoms in design [(resi, atom_idx)]
@@ -540,7 +544,7 @@ def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, re
540544
def compute(self, xyz):
541545

542546
if self.xyz_motif==None or self.xyz_motif.shape[0]<3:
543-
substrate_atoms=(self.na_atoms-self.na_atoms[:11].view(-1,3).mean(dim=0)).detach()
547+
substrate_atoms=(self.na_atoms-self.na_atoms[:,:11,:].mean(dim=(0,1))[None,None,:]).detach()
544548

545549
else:
546550
self._grab_motif_residues(self.xyz_motif)
@@ -557,11 +561,12 @@ def compute(self, xyz):
557561

558562
self.current_na_atoms = substrate_atoms.clone().detach()
559563
substrate_atoms=substrate_atoms.view(-1,3)
560-
mask=self.na_info['mask'].view(-1,3)
564+
mask=torch.from_numpy(self.na_info['mask']).view(-1)
561565
substrate_atoms=substrate_atoms[mask,:]
562566

563567
if self.sidechain:
564-
d=self.get_sidechains(xyz, self.seq, self.mask_seq)
568+
aatypes=self.na_info['atom_type'].reshape(-1,3)[mask,:]
569+
d=self.get_sidechains(xyz, self.seq, self.mask_seq, substrate_atoms, aatypes)
565570
dgram = torch.cdist(d['atom_xyz_p1'][None,...].contiguous(),
566571
substrate_atoms.float()[None].to(d['atom_xyz_p1'].device), p=2)[0] # [Lb,Lb]
567572
else:
@@ -635,8 +640,9 @@ def compute(self, xyz):
635640
'interface_ncontacts': interface_ncontacts,
636641
'monomer_contacts': monomer_contacts,
637642
'olig_contacts': olig_contacts,
638-
'substrate_contacts': substrate_contacts,
639-
'dmasif_interactions': dmasif_interactions}
643+
'substrate_contacts': substrate_contacts,
644+
'na_contacts': na_contacts,
645+
'dmasif_interactions': dmasif_interactions}
640646

641647
require_binderlen = { 'binder_ROG',
642648
'binder_distance_ReLU',

rfdiffusion/recover_sidechains.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ def __init__(self, binderlen=-1, seq_model_type='protein_mpnn'):
218218

219219
import LigandMPNN
220220
from LigandMPNN.model_utils import ProteinMPNN
221-
from LigandMPNN.data_utils import restype_str_to_int, restype_1to3, alphabet
222-
221+
from LigandMPNN.data_utils import restype_str_to_int, restype_1to3, alphabet, featurize, element_list
223222

224223
restype_3to1={restype_1to3[x]: x for x in restype_1to3.keys()}
225224
self.renumber_aa_mpnn2rf=torch.tensor([restype_str_to_int[restype_3to1.get(x,'X')] for x in num2aa], dtype=int)
226225
self.renumber_aa_rf2mpnn=torch.tensor([aa2num[aa_123.get(x,'UNK')] for x in alphabet], dtype=int)
226+
self.element_dict = dict(zip(element_list, range(1, len(element_list))))
227227

228228
path_to_LigandMPNN=LigandMPNN.__path__._path[0]
229229

@@ -256,6 +256,7 @@ def __init__(self, binderlen=-1, seq_model_type='protein_mpnn'):
256256
self.seq_model.load_state_dict(seq_checkpoint["model_state_dict"])
257257
self.seq_model.to(self.device)
258258
self.seq_model.eval()
259+
self.featurize=featurize
259260
print('Load LigandMPNN model')
260261

261262
self.recover_sc=None
@@ -268,50 +269,64 @@ def init_recover_sc(self):
268269
device=self.device)
269270
self.recover_sc.eval()
270271

271-
272-
def run_LigandMPNN(self, xyz, seq, seq_mask):
272+
def run_LigandMPNN(self, xyz, seq, seq_mask, ligand_xyz=None, ligand_aatypes=None):
273273

274274
L=xyz.shape[0]
275275

276276
xyz=get_O_from_3_points(xyz)
277277

278-
feature_dict = {}
279-
feature_dict["batch_size"]=1
278+
input_dict = {}
279+
280+
input_dict["X"] = xyz[:,:4,:] # L*4*3 (bb atoms) ? normalize
281+
input_dict["mask"] = torch.ones([ L]).to(self.device)
282+
280283
if seq==None:
281-
feature_dict["S"] = torch.full((1, L),20,dtype=int).to(self.device) # encoded sequence
282-
feature_dict["chain_mask"] = torch.full((1, L),True,dtype=bool).to(self.device)
284+
input_dict["S"] = torch.full(( L),20,dtype=int).to(self.device) # encoded sequence
285+
input_dict["chain_mask"] = torch.full(( L),True,dtype=bool).to(self.device)
283286
raise AttributeError
284287
else:
285-
feature_dict["S"]=seq[None,:,self.renumber_aa_rf2mpnn].argmax(-1).detach()
286-
feature_dict["chain_mask"] = ~seq_mask
288+
input_dict["S"]=seq[:,self.renumber_aa_rf2mpnn].argmax(-1).squeeze().detach()
289+
input_dict["chain_mask"] = ~seq_mask.squeeze().detach()
290+
291+
if ligand_xyz==None:
292+
input_dict["Y"] = torch.zeros([1, 3]).to(self.device)
293+
input_dict["Y_t"] = torch.zeros([1]).to(self.device)
294+
input_dict["Y_m"] = torch.zeros([1]).to(self.device)
295+
else:
296+
input_dict["Y"] = ligand_xyz.to(self.device).detach()
297+
input_dict["Y_t"] = torch.tensor([self.element_dict.get(x,0) for x in ligand_aatypes],
298+
dtype=torch.int32, device=self.device)
299+
input_dict["Y_m"] = torch.ones_like(input_dict["Y_t"])
300+
301+
input_dict["R_idx"] = torch.arange(L).to(self.device) # L resnums
302+
303+
if self.binderlen>0:
304+
input_dict["chain_labels"] = torch.cat((torch.zeros((self.binderlen)),
305+
torch.ones((L-self.binderlen))),0).to(self.device) # L Chain indices
306+
else:
307+
input_dict["chain_labels"]=torch.zeros((L)).to(self.device)
287308

288-
feature_dict["X"] = xyz[None,:,:4,:] # B*L*4*3 (bb atoms) ? normalize
309+
feature_dict = self.featurize(input_dict,
310+
number_of_ligand_atoms=(self.seq_model.features.atom_context_num
311+
if self.seq_model.model_type=='ligand_mpnn'
312+
else 1) ,
313+
model_type=self.seq_model.model_type)
289314

290-
feature_dict["mask"] = torch.ones([1, L]).to(self.device)
315+
feature_dict["batch_size"]=1
291316
feature_dict["temperature"] = 0.1
292317
feature_dict["bias"] = torch.zeros((1,L,21)).to(self.device)
293318
feature_dict["randn"]=torch.randn((1,L)).to(self.device)
294319
feature_dict["symmetry_residues"] = [[]]
295320
feature_dict["symmetry_weights"]=[[]]
296-
feature_dict["Y"] = torch.zeros([1, L, 16, 3]).to(self.device)
297-
feature_dict["Y_t"] = torch.zeros([1, L, 16]).to(self.device)
298-
feature_dict["Y_m"] = torch.zeros([1, L, 16]).to(self.device)
299321

300-
feature_dict["R_idx"] = torch.arange(L)[None,:].to(self.device) # B*L resnums
301-
302-
if self.binderlen>0:
303-
feature_dict["chain_labels"] = torch.cat((torch.zeros((1,self.binderlen)),
304-
torch.ones((1,L-self.binderlen))),1).to(self.device) # B*L Chain indices
305-
else:
306-
feature_dict["chain_labels"]=torch.zeros((1,L)).to(self.device)
307-
308-
output_dict = self.seq_model.score(feature_dict, use_sequence=False)
322+
output_dict = self.seq_model.score(feature_dict, use_sequence=True)
309323

310324
return output_dict
311325

312-
def get_aa_probs(self, xyz, seq, seq_mask):
313326

314-
output_dict=self.run_LigandMPNN(xyz, seq, seq_mask)
327+
def get_aa_probs(self, xyz, seq, seq_mask, ligand_xyz=None, ligand_aatypes=None):
328+
329+
output_dict=self.run_LigandMPNN(xyz, seq, seq_mask, ligand_xyz, ligand_aatypes)
315330
probs=torch.nn.functional.softmax(output_dict['logits'], dim=-1)
316331
probs=probs[0,:,self.renumber_aa_mpnn2rf]
317332
probs[seq_mask.squeeze()]=seq[seq_mask.squeeze()]
@@ -334,7 +349,7 @@ def bb2martini(self, xyz, seq):
334349
return self.recover_sc(feature_dict)
335350

336351

337-
def __call__(self, xyz, seq=None, seq_mask=None):
352+
def __call__(self, xyz, seq=None, seq_mask=None, ligand_xyz=None, ligand_aatypes=None):
338353

339354
xyz=xyz.clone().to(self.device)
340355

@@ -348,7 +363,7 @@ def __call__(self, xyz, seq=None, seq_mask=None):
348363
if self.recover_sc==None:
349364
self.init_recover_sc()
350365

351-
seq=self.get_aa_probs(xyz, seq, seq_mask)
366+
seq=self.get_aa_probs(xyz, seq, seq_mask, ligand_xyz, ligand_aatypes)
352367

353368
d=self.bb2martini(xyz, seq)
354369

rfdiffusion/util.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,17 +422,16 @@ def writena(
422422
atomscpu = atoms.cpu().squeeze()
423423
if bfacts is None:
424424
bfacts = torch.zeros(atomscpu.shape[0])
425-
if idx_pdb is None:
426-
idx_pdb = 1 + torch.arange(atomscpu.shape[0])
427425

428426
Bfacts = torch.clamp(bfacts.cpu(), 0, 1)
429427
for i, s in enumerate(info['pdb_idx']):
430428
chain=s[0]
431429
idx_pdb=s[1]
432430
s=info['seq'][i]
433-
for j, atm_j in enumerate(info['atom_names'][i]):
431+
atms = na2long[s][:23]
432+
for j, atm_j in enumerate(atms):
434433
if (
435-
j < sum(info['mask'][i]) and atm_j is not None
434+
info['mask'][i][j]>0
436435
): # and not torch.isnan(atomscpu[i,j,:]).any()):
437436
f.write(
438437
"%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"

0 commit comments

Comments
 (0)