@@ -122,6 +122,28 @@ def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
122122 return mu , sigma
123123
124124
def rigid_rotation_from_grads(Cas, Ca_grads, eps=1e-8):
    """Fit a rigid-body motion (translation + rotation) to per-residue gradients.

    The gradients are split into their mean (the translation) and a residual.
    The residual is then approximated by ``omega x r``, where ``r`` is each
    coordinate's offset from the centroid and ``omega`` solves the 3x3 system
    ``(M + eps * Id) @ omega = sum_i r_i x d_i`` with
    ``M = sum_i (|r_i|^2 * Id - r_i r_i^T)``.

    Args:
        Cas: coordinates, indexed as a rank-2 tensor of shape (L, 3).
        Ca_grads: gradients with the same shape as ``Cas``.
        eps: diagonal regularisation added to ``M`` before solving.

    Returns:
        Tuple ``(trans, omega, center, rot)``:
            trans  -- mean gradient, shape (1, 3) (kept 2-D via ``keepdim=True``).
            omega  -- fitted rotation vector, shape (3,).
            center -- centroid of ``Cas``, shape (3,).
            rot    -- rotational component ``omega x r`` per row, shape (L, 3).
    """
    # Offsets of every coordinate from the centroid.
    center = Cas.mean(dim=0)  # (3,)
    offsets = Cas - center  # (L, 3)

    # Translation is the mean gradient; what is left is fitted as a rotation.
    trans = Ca_grads.mean(dim=0, keepdim=True)  # (1, 3)
    residual = Ca_grads - trans  # (L, 3)

    identity = torch.eye(3, device=Cas.device, dtype=Cas.dtype)

    # Per-row |r|^2 * Id - r r^T, summed over rows into a single 3x3 matrix.
    sq_norms = offsets.pow(2).sum(dim=1)  # (L,)
    outer = offsets.unsqueeze(2) * offsets.unsqueeze(1)  # (L, 3, 3)
    inertia = (sq_norms[:, None, None] * identity[None, :, :] - outer).sum(dim=0)  # (3, 3)

    # Right-hand side: sum of r x d over all rows.
    torque = torch.cross(offsets, residual, dim=1).sum(dim=0)  # (3,)

    # eps on the diagonal keeps the system solvable for degenerate geometries.
    system = inertia + eps * identity
    try:
        omega = torch.linalg.solve(system, torque)  # (3,)
    except RuntimeError:
        # Fall back to least squares if the direct solve is rejected.
        omega = torch.linalg.lstsq(system, torque.unsqueeze(-1)).solution.squeeze(-1)

    # Rotational part of the gradient at each coordinate: omega x r.
    rot = torch.cross(omega[None, :].expand_as(offsets), offsets, dim=1)  # (L, 3)

    return trans, omega, center, rot
144+
145+
146+
125147def get_next_ca (
126148 xt ,
127149 px0 ,
@@ -392,7 +414,11 @@ def get_potential_gradients(self, xyz, diffusion_mask, t, predicted=False):
392414
393415 # Since we are not moving frames, Cb grads are same as Ca grads
394416 # Need access to calculated Cb coordinates to be able to get Cb grads though
395- Ca_grads = xyz .grad [:, 1 , :]
417+ if xyz .grad is None :
418+ print ("WARNING: NaN in potential gradients, replacing with zero grad." )
419+ Ca_grads = torch .zeros_like (xyz [:, 1 , :])
420+ else :
421+ Ca_grads = xyz .grad [:, 1 , :]
396422
397423 if not diffusion_mask == None :
398424 Ca_grads [diffusion_mask , :] = 0
@@ -406,6 +432,20 @@ def get_potential_gradients(self, xyz, diffusion_mask, t, predicted=False):
406432 print ("WARNING: NaN in potential gradients, replacing with zero grad." )
407433 Ca_grads [:] = 0
408434
435+ # smooth potential effects within protein subunits
436+ smooth_scale = max ([potential .smooth for potential in self .potential_manager .potentials_to_apply ])
437+ if smooth_scale > 0 :
438+ Cas = xyz [:, 1 , :]
439+ binderlen = self .potential_manager .binderlen
440+ if binderlen < 0 :
441+ borders = [(0 ,Ca_grads .shape [0 ])]
442+ else :
443+ borders = [(0 ,binderlen ),(binderlen ,Ca_grads .shape [0 ])]
444+ for a , b in borders :
445+ with torch .no_grad ():
446+ trans , omega , center , rot = rigid_rotation_from_grads (Cas [a :b ],Ca_grads [a :b ])
447+ Ca_grads [a :b ]= Ca_grads [a :b ]* (1 - smooth_scale )+ (trans + rot )* smooth_scale
448+
409449 return Ca_grads
410450
411451 def get_next_pose (
@@ -620,38 +660,40 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
620660 res , pdb_idx = [],[]
621661 for l in lines :
622662 if l [:4 ] == "ATOM" and l [12 :16 ].strip () == "C1'" :
623- res .append ((l [22 :26 ], l [17 :20 ]))
663+ res .append ((l [22 :26 ], l [17 :20 ]. strip () ))
624664 # chain letter, res num
625665 pdb_idx .append ((l [21 :22 ].strip (), int (l [22 :26 ].strip ())))
626- seq = [util .na2num [r [1 ]] if r [1 ] in util .na2num . keys () else 20 for r in res ]
666+ seq = [util .na2num [r [1 ]] if r [1 ] in util .na2num else 20 for r in res ]
627667 pdb_idx = [
628668 (l [21 :22 ].strip (), int (l [22 :26 ].strip ()))
629669 for l in lines
630670 if l [:4 ] == "ATOM" and l [12 :16 ].strip () == "C1'"
631671 ] # chain letter, res num
632672
633- # 4 BB + up to 10 SC atoms
673+ # 3 BB + up to 20 SC atoms
634674 xyz = np .full ((len (res ), 23 , 3 ), np .nan , dtype = np .float32 )
635- xyz_names = np .full ((len (res ), 23 , 3 ), np .nan )
675+ atom_id = np .full ((len (res ), 23 ), np .nan , dtype = np .object )
676+ atom_type = np .full ((len (res ), 23 ), np .nan , dtype = np .object )
636677 for l in lines :
637678 if l [:4 ] != "ATOM" :
638679 continue
639680 chain , resNo , atom , aa = (
640681 l [21 :22 ],
641682 int (l [22 :26 ]),
642683 " " + l [12 :16 ].strip ().ljust (3 ),
643- l [17 :20 ],
684+ l [17 :20 ]. strip () ,
644685 )
645686 if (chain ,resNo ) in pdb_idx :
646687 idx = pdb_idx .index ((chain , resNo ))
647688 for i_atm , tgtatm in enumerate (
648- util .na2long [util .na2num [aa ]][:14 ]
689+ util .na2long [util .na2num [aa ]][:23 ]
649690 ):
650691 if (
651692 tgtatm is not None and tgtatm .strip () == atom .strip ()
652693 ): # ignore whitespace
653694 xyz [idx , i_atm , :] = [float (l [30 :38 ]), float (l [38 :46 ]), float (l [46 :54 ])]
654- xyz_names [idx , i_atm , :] = l [16 :20 ]
695+ atom_id [idx , i_atm ] = atom
696+ atom_type [idx , i_atm ] = l [77 ]
655697 break
656698
657699 # save atom mask
@@ -674,7 +716,8 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
674716
675717 out ["na_xyz" ]= xyz # cartesian coordinates, [Lx23]
676718 out ["na_mask" ]= mask # mask showing which atoms are present in the PDB file, [Lx23]
677- out ['na_atom_names' ]= xyz_names
719+ out ['na_atom_id' ]= atom_id
720+ out ['na_atom_type' ]= atom_type
678721 out ["na_seq" ]= np .array (seq ) # amino acid sequence, [L]
679722 out ["na_pdb_idx" ]= pdb_idx # list of (chain letter, residue number) in the pdb file, [L]
680723
@@ -713,7 +756,8 @@ def process_target(pdb_path, parse_hetatom=False, parse_na=False, center=True):
713756
714757 if parse_na :
715758 out ['na_info' ]= {'mask' :target_struct ["na_mask" ],
716- 'atom_names' :target_struct ['na_atom_names' ],
759+ 'atom_id' :target_struct ['na_atom_id' ],
760+ 'atom_type' :target_struct ['na_atom_type' ],
717761 'seq' :target_struct ["na_seq" ],
718762 'pdb_idx' :target_struct ["na_pdb_idx" ]}
719763 out ["na_xyz" ]= target_struct ["na_xyz" ]
0 commit comments