@@ -123,21 +123,77 @@ def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
123123
124124
def rigid_rotation_from_grads(Cas, Ca_grads, eps=1e-8):
    """
    Estimate the best-fit infinitesimal rigid motion (translation + rotation)
    from per-residue gradients ``Ca_grads`` at positions ``Cas``.

    Intuition
    ---------
    The gradient field on the Ca atoms is decomposed into:
      - a global translation (mean of the gradients over residues), and
      - a small rigid rotation around the geometric center of ``Cas``.

    The rotation is obtained from the 3x3 linear system
    ``(J + eps * I3) @ omega = tau`` where
      - ``J   = sum_i(||r_i||^2) * I3 - sum_i(r_i r_i^T)`` is an inertia-like
        matrix of the centered point set,
      - ``tau = sum_i(r_i x d_i)`` is a torque-like vector,
      - ``r_i = Cas_i - center`` and ``d_i = Ca_grads_i - trans``,
      - ``I3`` is the 3x3 identity.

    Args:
        Cas: (L, 3) tensor of Ca coordinates.
        Ca_grads: (L, 3) tensor of per-residue gradients. It is cast to the
            device and dtype of ``Cas`` so that every output lives on the same
            device/dtype (as the empty-input path already guarantees).
        eps: diagonal regularizer added to ``J`` before solving; it keeps the
            system solvable for degenerate geometries (single point,
            collinear points).

    Returns:
        trans  (1, 3): mean translation component shared by all residues
        omega  (3,):   angular-velocity vector defining the small rotation
        center (3,):   geometric center of ``Cas``
        rot    (L, 3): per-residue rotational component ``omega x (Cas - center)``
    """
    device, dtype = Cas.device, Cas.dtype
    L = Cas.shape[0]

    # Guard empty input: means over zero rows would produce NaNs.
    if L == 0:
        return (
            torch.zeros(1, 3, device=device, dtype=dtype),
            torch.zeros(3, device=device, dtype=dtype),
            torch.zeros(3, device=device, dtype=dtype),
            torch.zeros(0, 3, device=device, dtype=dtype),
        )

    # torch.linalg.solve needs both operands in one dtype on one device;
    # this is a no-op when Ca_grads already matches Cas.
    Ca_grads = Ca_grads.to(device=device, dtype=dtype)

    # Geometric center and centered positions r_i.
    center = Cas.mean(dim=0)  # (3,)
    r = Cas - center  # (L, 3)

    # Mean translation across residues, then remove it from the gradients.
    trans = Ca_grads.mean(dim=0, keepdim=True)  # (1, 3)
    d = Ca_grads - trans  # (L, 3)

    eye = torch.eye(3, device=device, dtype=dtype)

    # Inertia-like matrix J = sum(||r_i||^2) * I3 - sum(r_i r_i^T), computed
    # from two reductions instead of materializing an (L, 3, 3) tensor.
    r2_sum = (r * r).sum()  # scalar
    rr_sum = r.T @ r  # (3, 3)
    inertia = r2_sum * eye - rr_sum  # (3, 3)

    # Torque-like vector tau = sum(r_i x d_i).
    tau = torch.cross(r, d, dim=1).sum(dim=0)  # (3,)

    # Solve (J + eps * I3) @ omega = tau for the small-rotation vector.
    system = inertia + eps * eye
    try:
        omega = torch.linalg.solve(system, tau)  # (3,)
    except RuntimeError:
        # The solver rejected the system (e.g. numerically singular):
        # fall back to the least-squares solution.
        omega = torch.linalg.lstsq(system, tau.unsqueeze(-1)).solution.squeeze(-1)

    # Per-residue rotational component: omega x r_i.
    rot = torch.cross(omega.unsqueeze(0).expand_as(r), r, dim=1)  # (L, 3)

    return trans, omega, center, rot
@@ -652,7 +708,7 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
652708 )
653709 xyz_het .append ([float (l [30 :38 ]), float (l [38 :46 ]), float (l [46 :54 ])])
654710
655- out ["xyz_het" ] = np .array (xyz_het )
711+ out ["xyz_het" ] = np .array (xyz_het ). reshape (( len ( xyz_het ), 3 ))
656712 out ["info_het" ] = info_het
657713
658714 # nucleic acids
@@ -671,7 +727,7 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
671727 ] # chain letter, res num
672728
673729 # 3 BB + up to 20 SC atoms
674- xyz = np .full ((len (res ), 23 , 3 ), np .nan , dtype = np .float32 )
730+ xyz = np .full ((len (res ), 23 , 3 ), np .nan , dtype = np .float64 )
675731 atom_id = np .full ((len (res ), 23 ), np .nan , dtype = np .object )
676732 atom_type = np .full ((len (res ), 23 ), np .nan , dtype = np .object )
677733 for l in lines :
@@ -751,7 +807,7 @@ def process_target(pdb_path, parse_hetatom=False, parse_na=False, center=True):
751807 "pdb_idx" : target_struct ["pdb_idx" ],
752808 }
753809 if parse_hetatom :
754- out ["xyz_het" ] = target_struct ["xyz_het" ]
810+ out ["xyz_het" ] = target_struct ["xyz_het" ] - ca_center
755811 out ["info_het" ] = target_struct ["info_het" ]
756812
757813 if parse_na :
@@ -760,7 +816,7 @@ def process_target(pdb_path, parse_hetatom=False, parse_na=False, center=True):
760816 'atom_type' :target_struct ['na_atom_type' ],
761817 'seq' :target_struct ["na_seq" ],
762818 'pdb_idx' :target_struct ["na_pdb_idx" ]}
763- out ["na_xyz" ]= target_struct ["na_xyz" ]
819+ out ["na_xyz" ]= target_struct ["na_xyz" ] - ca_center
764820
765821 return out
766822
0 commit comments