Skip to content

Commit 4781517

Browse files
committed
add smoothing
1 parent dd5aefb commit 4781517

3 files changed

Lines changed: 83 additions & 17 deletions

File tree

rfdiffusion/inference/model_runners.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def assemble_config_from_chk(self) -> None:
216216
print(f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?')
217217
mytype = type(self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]])
218218
self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]] = mytype(override.split("=")[1])
219+
print(override, self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]])
219220

220221
def load_model(self):
221222
"""Create RosettaFold model from preloaded checkpoint."""
@@ -266,7 +267,7 @@ def sample_init(self, return_forward_trajectory=False):
266267
### Parse input pdb ###
267268
#######################
268269

269-
self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, parse_na=True, center=False)
270+
self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, parse_na=True, center=True)
270271

271272
################################
272273
### Generate specific contig ###
@@ -428,7 +429,9 @@ def sample_init(self, return_forward_trajectory=False):
428429
het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']])
429430
xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate]
430431
xyz_het = torch.from_numpy(xyz_het)
431-
info_het={x: self.target_feats['info_het'][x][het_names == self._conf.potentials.substrate] for x in self.target_feats['info_het']}
432+
info_het={x: np.array([y[x] for y in self.target_feats['info_het']
433+
if y['name'] == self._conf.potentials.substrate])
434+
for x in self.target_feats['info_het'][0]}
432435
assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}'
433436
xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()]
434437
motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0)

rfdiffusion/inference/utils.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,21 +123,77 @@ def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
123123

124124

125125
def rigid_rotation_from_grads(Cas, Ca_grads, eps=1e-8):
126+
127+
"""
128+
Estimate best-fit infinitesimal rigid motion (translation + rotation)
129+
from per-residue gradients Ca_grads at positions Cas.
130+
131+
Intuition
132+
---------
133+
We decompose the gradient field on Ca atoms into:
134+
- a global translation (mean over residues), and
135+
- a small rigid rotation around the geometric center.
136+
137+
The rotation is solved in the least-squares sense via the 3x3 linear
138+
system (I + eps*I) · omega = tau, where
139+
- I = Σ(||r_i||^2)·I3 - Σ(r_i r_i^T) is an inertia-like matrix of the point set,
140+
- tau = Σ(r_i x d_i) is a torque-like vector of centered gradients,
141+
- r_i = Cas_i - center, d_i = Ca_grads_i - trans.
142+
This returns the angular-velocity vector omega that best explains the
143+
rotational component of the gradients.
144+
145+
Returns:
146+
trans (1,3): mean translation component applied to all residues
147+
omega (3,): angular velocity vector defining small rotation
148+
center (3,): geometric center of Cas
149+
rot (L,3): per-residue rotational component omega x (Cas - center)
150+
"""
151+
device, dtype = Cas.device, Cas.dtype
152+
L = Cas.shape[0]
153+
154+
# Guard empty input
155+
if L == 0:
156+
return (
157+
torch.zeros(1, 3, device=device, dtype=dtype),
158+
torch.zeros(3, device=device, dtype=dtype),
159+
torch.zeros(3, device=device, dtype=dtype),
160+
torch.zeros(0, 3, device=device, dtype=dtype),
161+
)
162+
163+
# Geometric center and centered positions r_i
126164
center=Cas.mean(dim=0) # (3,)
127165
r=Cas-center # (L,3)
128166

129-
trans=Ca_grads.mean(dim=0, keepdim=True) # (L, 3)
167+
# Mean translation across residues (global shift suggested by gradients)
168+
trans=Ca_grads.mean(dim=0, keepdim=True) # (1, 3)
169+
# Centered gradients remove the pure-translation component
130170
d=Ca_grads-trans # (L,3)
131-
eye=torch.eye(3, device=Cas.device, dtype=Cas.dtype)
171+
172+
eye=torch.eye(3, device=device, dtype=dtype)
173+
# Inertia-like matrix of centered points (well-known identity for Σ [r]_x^T [r]_x)
174+
# I = Σ(||r_i||^2)·I3 − Σ(r_i r_i^T) ∈ R^{3x3}
175+
'''
132176
r2=(r**2).sum(dim=1) # (L,)
133177
rrT=r[:,:,None]*r[:,None,:] # (L,3,3)
134178
I=(r2[:, None, None] * eye[None, :, :] - rrT).sum(dim=0) # (3,3)
179+
'''
180+
r2_sum = (r * r).sum() # scalar Σ ||r_i||^2
181+
rr_sum = r.T @ r # (3,3) Σ r_i r_i^T
182+
I = r2_sum * eye - rr_sum # (3,3)
183+
184+
# Torque-like vector: τ = Σ (r_i × d_i)
185+
# Captures the net tendency of gradients to induce rotation about the center.
135186
tau=torch.cross(r,d, dim=1).sum(dim=0) # (3,)
187+
188+
# Solve for small-rotation vector ω from (I + eps·I3)·ω = τ.
189+
# eps stabilizes near-singular geometries (e.g., collinear or coplanar points).
136190
try:
137191
omega = torch.linalg.solve(I + eps*eye, tau) # (3,)
138192
except RuntimeError:
193+
# Fallback in case of numerical issues (rare): least-squares solution.
139194
omega = torch.linalg.lstsq(I + eps*eye, tau.unsqueeze(-1)).solution.squeeze(-1)
140195

196+
# Per-residue rotational component: ω × r_i
141197
rot = torch.cross(omega.unsqueeze(0).expand_as(r), r, dim=1) #(L,3)
142198

143199
return trans, omega, center, rot
@@ -652,7 +708,7 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
652708
)
653709
xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])])
654710

655-
out["xyz_het"] = np.array(xyz_het)
711+
out["xyz_het"] = np.array(xyz_het).reshape((len(xyz_het),3))
656712
out["info_het"] = info_het
657713

658714
# nucleic acids
@@ -671,7 +727,7 @@ def parse_pdb_lines(lines, parse_hetatom=False, parse_na=False, ignore_het_h=Tru
671727
] # chain letter, res num
672728

673729
# 3 BB + up to 20 SC atoms
674-
xyz = np.full((len(res), 23, 3), np.nan, dtype=np.float32)
730+
xyz = np.full((len(res), 23, 3), np.nan, dtype=np.float64)
675731
atom_id = np.full((len(res), 23), np.nan, dtype=np.object)
676732
atom_type = np.full((len(res), 23), np.nan, dtype=np.object)
677733
for l in lines:
@@ -751,7 +807,7 @@ def process_target(pdb_path, parse_hetatom=False, parse_na=False, center=True):
751807
"pdb_idx": target_struct["pdb_idx"],
752808
}
753809
if parse_hetatom:
754-
out["xyz_het"] = target_struct["xyz_het"]
810+
out["xyz_het"] = target_struct["xyz_het"] - ca_center
755811
out["info_het"] = target_struct["info_het"]
756812

757813
if parse_na:
@@ -760,7 +816,7 @@ def process_target(pdb_path, parse_hetatom=False, parse_na=False, center=True):
760816
'atom_type':target_struct['na_atom_type'],
761817
'seq':target_struct["na_seq"],
762818
'pdb_idx':target_struct["na_pdb_idx"]}
763-
out["na_xyz"]= target_struct["na_xyz"]
819+
out["na_xyz"]= target_struct["na_xyz"] - ca_center
764820

765821
return out
766822

rfdiffusion/potentials/potentials.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,17 @@ class substrate_contacts(Potential):
378378
Implicitly models a ligand with an attractive-repulsive potential.
379379
'''
380380

381-
def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1, sidechain=False):
381+
def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1,
382+
sidechain=False, smooth=0, predicted=False):
382383

383384
super().__init__()
384385
self.r_0 = r_0
385386
self.weight = weight
386387
self.d_0 = d_0
387388
self.eps = eps
388389
self.sidechain=sidechain
390+
self.predicted=predicted
391+
self.smooth=smooth
389392

390393
# motif frame coordinates
391394
# NOTE: these probably need to be set after sample_init() call, because the motif sequence position in design must be known
@@ -451,7 +454,7 @@ def compute(self, xyz):
451454
self.current_substrate_atoms = substrate_atoms.clone().detach()
452455
energy = sum(all_energies)
453456
print('SUBSTRATE CONTACT LOSS:',energy.item())
454-
return self.weight * energy
457+
return - self.weight * energy
455458

456459
#Potential value is the average of both radii of gyration (is avg. the best way to do this?)
457460
return self.weight * ncontacts.sum()
@@ -544,22 +547,26 @@ def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, re
544547
def compute(self, xyz):
545548

546549
if self.xyz_motif==None or self.xyz_motif.shape[0]<3:
547-
substrate_atoms=(self.na_atoms-self.na_atoms[:,:11,:].mean(dim=(0,1))[None,None,:]).detach()
550+
substrate_atoms=self.na_atoms.clone().detach()
551+
#substrate_atoms=(self.na_atoms-self.na_atoms[:,:11,:].mean(dim=(0,1))[None,None,:]).detach()
552+
self.current_na_atoms = substrate_atoms.clone().detach()
548553

549554
else:
550555
self._grab_motif_residues(self.xyz_motif)
551556

552-
first_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(self.motif_substrate_atoms[0] - self.motif_frame[0]), dim=-1)))
557+
L, D, _ = self.na_atoms.shape
558+
idx=torch.argmin(torch.sum(torch.square(self.na_atoms.view(-1,3) - self.motif_frame[0]), dim=-1))
559+
first_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(self.na_atoms.view(-1,3)[idx] - self.motif_frame[0]), dim=-1)))
553560

554561
res = torch.tensor([k[0] for k in self.motif_mapping])
555562
atoms = torch.tensor([k[1] for k in self.motif_mapping])
556563
new_frame = xyz[self.diffusion_mask][res,atoms,:]
557564
A, t = self._recover_affine(self.motif_frame, new_frame)
558-
substrate_atoms = torch.mm(A, self.motif_substrate_atoms.transpose(0,1)).transpose(0,1) + t
559-
second_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(new_frame[0] - substrate_atoms[0]), dim=-1)))
565+
substrate_atoms = torch.mm(A, self.na_atoms.view(-1,3).transpose(0,1)).transpose(0,1) + t
566+
second_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(new_frame[0] - substrate_atoms[idx]), dim=-1)))
560567
assert abs(first_distance - second_distance) < 0.01, "Alignment seems to be bad"
568+
self.current_na_atoms = substrate_atoms.view(L, D, 3).clone().detach()
561569

562-
self.current_na_atoms = substrate_atoms.clone().detach()
563570
substrate_atoms=substrate_atoms.view(-1,3)
564571
mask=torch.from_numpy(self.na_info['mask']).view(-1)
565572
substrate_atoms=substrate_atoms[mask,:]
@@ -580,7 +587,7 @@ def compute(self, xyz):
580587
all_energies.append(energy.sum())
581588
energy = sum(all_energies)
582589
print('NA CONTACT LOSS:',energy.item())
583-
return self.weight * energy
590+
return - self.weight * energy
584591

585592

586593
class dmasif_interactions(Potential):
@@ -628,7 +635,7 @@ def compute(self, xyz):
628635
potential.requires_grad_()
629636
xyz.grad=torch.zeros_like(xyz)
630637

631-
return potential
638+
return - potential
632639

633640
# Dictionary of types of potentials indexed by name of potential. Used by PotentialManager.
634641
# If you implement a new potential you must add it to this dictionary for it to be used by

0 commit comments

Comments
 (0)