@@ -360,7 +360,7 @@ def rmsd(V, W, eps=0):
360360 px0_ [~ atom_mask ] = float ("nan" )
361361 return torch .Tensor (px0_ )
362362
363- def get_potential_gradients (self , xyz , diffusion_mask ):
363+ def get_potential_gradients (self , xyz , diffusion_mask , t , predicted = False ):
364364 """
365365 This could be moved into potential manager if desired - NRB
366366
@@ -375,7 +375,8 @@ def get_potential_gradients(self, xyz, diffusion_mask):
375375 Ca_grads (torch.tensor): [L,3] The gradient at each Ca atom
376376 """
377377
378- if self .potential_manager == None or self .potential_manager .is_empty ():
378+ if (self .potential_manager == None or
379+ sum ([potential .predicted == predicted for potential in self .potential_manager .potentials_to_apply ])< 1 ):
379380 return torch .zeros (xyz .shape [0 ], 3 )
380381
381382 use_Cb = False
@@ -386,7 +387,7 @@ def get_potential_gradients(self, xyz, diffusion_mask):
386387 if not xyz .grad is None :
387388 xyz .grad .zero_ ()
388389
389- current_potential = self .potential_manager .compute_all_potentials (xyz )
390+ current_potential = self .potential_manager .compute_all_potentials (xyz , predicted )
390391 current_potential .backward ()
391392
392393 # Since we are not moving frames, Cb grads are same as Ca grads
@@ -395,6 +396,10 @@ def get_potential_gradients(self, xyz, diffusion_mask):
395396
396397 if not diffusion_mask == None :
397398 Ca_grads [diffusion_mask , :] = 0
399+
400+ # clamp gradients
401+ dist = (Ca_grads ** 2 ).sum (- 1 , keepdim = True ).sqrt ()
402+ Ca_grads = torch .where (dist <= 1 / self .alphabar_schedule [t - 1 ], Ca_grads , Ca_grads / dist / self .alphabar_schedule [t - 1 ]).detach ()
398403
399404 # check for NaN's
400405 if torch .isnan (Ca_grads ).any ():
@@ -452,16 +457,7 @@ def get_next_pose(
452457 # Now done with diffusion mask. if fix motif is False, just set diffusion mask to be all True, and all coordinates can diffuse
453458 if not fix_motif :
454459 diffusion_mask [:] = False
455-
456- grad_ca = self .get_potential_gradients (
457- px0 .clone (), diffusion_mask = diffusion_mask
458- )
459- # clamp gradients
460- dist = (grad_ca ** 2 ).sum (- 1 , keepdim = True ).sqrt ()
461- grad_ca = torch .where (dist <= 1 / self .alphabar_schedule [t - 1 ], grad_ca , grad_ca / dist / self .alphabar_schedule [t - 1 ]).detach ()
462-
463- px0 += self .potential_manager .get_guide_scale (t ) * grad_ca [:, None , :]
464-
460+
465461 # get the next set of CA coordinates
466462 noise_scale_ca = self .noise_schedule_ca (t )
467463 _ , ca_deltas = get_next_ca (
@@ -489,16 +485,16 @@ def get_next_pose(
489485
490486 # Apply gradient step from guiding potentials
491487 # This can be moved to below where the full atom representation is calculated to allow for potentials involving sidechains
492- '''
488+
493489 grad_ca = self .get_potential_gradients (
494- xt.clone(), diffusion_mask=diffusion_mask
490+ xt .clone (), diffusion_mask = diffusion_mask , t = t
491+ ) + self .get_potential_gradients (
492+ px0 .clone (), diffusion_mask = diffusion_mask ,
493+ t = t , predicted = True
495494 )
496- # clamp gradients
497- dist=(grad_ca**2).sum(-1, keepdim=True).sqrt()
498- grad_ca=torch.where(dist<=1/self.alphabar_schedule[t-1], grad_ca, grad_ca/dist/self.alphabar_schedule[t-1]).detach()
499-
495+
500496 ca_deltas += self .potential_manager .get_guide_scale (t ) * grad_ca
501- '''
497+
502498 # add the delta to the new frames
503499 frames_next = torch .from_numpy (frames_next ) + ca_deltas [:, None , :] # translate
504500
0 commit comments