Skip to content

Commit d440483

Browse files
committed
compute potentials from either partially denoised or predicted structures
1 parent d3b95ca commit d440483

3 files changed

Lines changed: 34 additions & 25 deletions

File tree

rfdiffusion/inference/utils.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def rmsd(V, W, eps=0):
360360
px0_[~atom_mask] = float("nan")
361361
return torch.Tensor(px0_)
362362

363-
def get_potential_gradients(self, xyz, diffusion_mask):
363+
def get_potential_gradients(self, xyz, diffusion_mask, t, predicted=False):
364364
"""
365365
This could be moved into potential manager if desired - NRB
366366
@@ -375,7 +375,8 @@ def get_potential_gradients(self, xyz, diffusion_mask):
375375
Ca_grads (torch.tensor): [L,3] The gradient at each Ca atom
376376
"""
377377

378-
if self.potential_manager == None or self.potential_manager.is_empty():
378+
if (self.potential_manager == None or
379+
sum([potential.predicted==predicted for potential in self.potential_manager.potentials_to_apply])<1):
379380
return torch.zeros(xyz.shape[0], 3)
380381

381382
use_Cb = False
@@ -386,7 +387,7 @@ def get_potential_gradients(self, xyz, diffusion_mask):
386387
if not xyz.grad is None:
387388
xyz.grad.zero_()
388389

389-
current_potential = self.potential_manager.compute_all_potentials(xyz)
390+
current_potential = self.potential_manager.compute_all_potentials(xyz, predicted)
390391
current_potential.backward()
391392

392393
# Since we are not moving frames, Cb grads are same as Ca grads
@@ -395,6 +396,10 @@ def get_potential_gradients(self, xyz, diffusion_mask):
395396

396397
if not diffusion_mask == None:
397398
Ca_grads[diffusion_mask, :] = 0
399+
400+
# clamp gradients
401+
dist=(Ca_grads**2).sum(-1, keepdim=True).sqrt()
402+
Ca_grads=torch.where(dist<=1/self.alphabar_schedule[t-1], Ca_grads, Ca_grads/dist/self.alphabar_schedule[t-1]).detach()
398403

399404
# check for NaN's
400405
if torch.isnan(Ca_grads).any():
@@ -452,16 +457,7 @@ def get_next_pose(
452457
# Now done with diffusion mask. if fix motif is False, just set diffusion mask to be all True, and all coordinates can diffuse
453458
if not fix_motif:
454459
diffusion_mask[:] = False
455-
456-
grad_ca = self.get_potential_gradients(
457-
px0.clone(), diffusion_mask=diffusion_mask
458-
)
459-
# clamp gradients
460-
dist=(grad_ca**2).sum(-1, keepdim=True).sqrt()
461-
grad_ca=torch.where(dist<=1/self.alphabar_schedule[t-1], grad_ca, grad_ca/dist/self.alphabar_schedule[t-1]).detach()
462-
463-
px0 += self.potential_manager.get_guide_scale(t) * grad_ca[:, None, :]
464-
460+
465461
# get the next set of CA coordinates
466462
noise_scale_ca = self.noise_schedule_ca(t)
467463
_, ca_deltas = get_next_ca(
@@ -489,16 +485,16 @@ def get_next_pose(
489485

490486
# Apply gradient step from guiding potentials
491487
# This can be moved to below where the full atom representation is calculated to allow for potentials involving sidechains
492-
'''
488+
493489
grad_ca = self.get_potential_gradients(
494-
xt.clone(), diffusion_mask=diffusion_mask
490+
xt.clone(), diffusion_mask=diffusion_mask, t=t
491+
) + self.get_potential_gradients(
492+
px0.clone(), diffusion_mask=diffusion_mask,
493+
t=t, predicted=True
495494
)
496-
# clamp gradients
497-
dist=(grad_ca**2).sum(-1, keepdim=True).sqrt()
498-
grad_ca=torch.where(dist<=1/self.alphabar_schedule[t-1], grad_ca, grad_ca/dist/self.alphabar_schedule[t-1]).detach()
499-
495+
500496
ca_deltas += self.potential_manager.get_guide_scale(t) * grad_ca
501-
'''
497+
502498
# add the delta to the new frames
503499
frames_next = torch.from_numpy(frames_next) + ca_deltas[:, None, :] # translate
504500

rfdiffusion/potentials/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,12 @@ def initialize_all_potentials(self, setting_list):
167167

168168
return to_apply
169169

170-
def compute_all_potentials(self, xyz):
170+
def compute_all_potentials(self, xyz, predicted=False):
171171
'''
172172
This is the money call. Take the current sequence and structure information and get the sum of all of the potentials that are being used
173173
'''
174174

175-
potential_list = [potential.compute(xyz) for potential in self.potentials_to_apply]
175+
potential_list = [potential.compute(xyz) for potential in self.potentials_to_apply if potential.predicted==predicted]
176176
potential_stack = torch.stack(potential_list, dim=0)
177177

178178
return torch.sum(potential_stack, dim=0)

rfdiffusion/potentials/potentials.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ class Potential:
88
'''
99
Interface class that defines the functions a potential must implement
1010
'''
11+
def __init__(self):
12+
13+
self.predicted=False
1114

1215
def compute(self, xyz):
1316
'''
@@ -31,7 +34,7 @@ class monomer_ROG(Potential):
3134
'''
3235

3336
def __init__(self, weight=1, min_dist=15):
34-
37+
super().__init__()
3538
self.weight = weight
3639
self.min_dist = min_dist
3740

@@ -57,6 +60,7 @@ class binder_ROG(Potential):
5760

5861
def __init__(self, binderlen, weight=1, min_dist=15):
5962

63+
super().__init__()
6064
self.binderlen = binderlen
6165
self.min_dist = min_dist
6266
self.weight = weight
@@ -87,6 +91,7 @@ class dimer_ROG(Potential):
8791

8892
def __init__(self, binderlen, weight=1, min_dist=15):
8993

94+
super().__init__()
9095
self.binderlen = binderlen
9196
self.min_dist = min_dist
9297
self.weight = weight
@@ -127,6 +132,7 @@ class binder_ncontacts(Potential):
127132

128133
def __init__(self, binderlen, weight=1, r_0=8, d_0=4):
129134

135+
super().__init__()
130136
self.binderlen = binderlen
131137
self.r_0 = r_0
132138
self.weight = weight
@@ -161,6 +167,7 @@ class interface_ncontacts(Potential):
161167

162168
def __init__(self, binderlen, weight=1, r_0=8, d_0=6):
163169

170+
super().__init__()
164171
self.binderlen = binderlen
165172
self.r_0 = r_0
166173
self.weight = weight
@@ -200,6 +207,7 @@ class monomer_contacts(Potential):
200207

201208
def __init__(self, weight=1, r_0=8, d_0=2, eps=1e-6):
202209

210+
super().__init__()
203211
self.r_0 = r_0
204212
self.weight = weight
205213
self.d_0 = d_0
@@ -245,6 +253,7 @@ def __init__(self,
245253
246254
weight (int/float, optional): Scaling/weighting factor
247255
"""
256+
super().__init__()
248257
self.contact_matrix = contact_matrix
249258
self.weight_intra = weight_intra
250259
self.weight_inter = weight_inter
@@ -354,6 +363,7 @@ class substrate_contacts(Potential):
354363

355364
def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1):
356365

366+
super().__init__()
357367
self.r_0 = r_0
358368
self.weight = weight
359369
self.d_0 = d_0
@@ -464,13 +474,16 @@ class dmasif_interactions(Potential):
464474

465475
def __init__(self, binderlen, int_weight=1, non_int_weight=1, threshold=3, seq_model_type='protein_mpnn'):
466476

477+
super().__init__()
478+
self.predicted=True
479+
467480
submodule_path='/'.join(__file__.split('/')[:-4])
468481
import sys
469482
sys.path.append(submodule_path)
470483

471-
from masif_martini.rfdiff_potential import Potential_from_bb
484+
from masif_martini.potential import RFdiff_potential_from_bb
472485

473-
self.potential=Potential_from_bb(binderlen=binderlen,
486+
self.potential=RFdiff_potential_from_bb(binderlen=binderlen,
474487
int_weight=int_weight,
475488
non_int_weight=non_int_weight,
476489
threshold=threshold,

0 commit comments

Comments
 (0)