Skip to content

Commit 5735aff

Browse files
committed
perf: reduce GPU memory usage during inference
- Move trajectory data (px0, x_t, seq, plddt) to the CPU immediately after each denoising step instead of accumulating it on the GPU. This frees GPU memory for the next forward pass, reducing peak memory usage by an amount that grows with the number of diffusion steps.
- Remove two unused ComputeAllAtomCoords() instantiations in get_next_ca() and get_next_pose() that were created on every timestep but never referenced, wasting memory and compute.
1 parent 9535f19 commit 5735aff

2 files changed

Lines changed: 8 additions & 7 deletions

File tree

rfdiffusion/inference/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from rfdiffusion.diffusion import get_beta_schedule
77
from scipy.spatial.transform import Rotation as scipy_R
88
from rfdiffusion.util import rigid_from_3_points
9-
from rfdiffusion.util_module import ComputeAllAtomCoords
9+
1010
from rfdiffusion import util
1111
import random
1212
import logging
@@ -152,7 +152,6 @@ def get_next_ca(
152152
noise_scale: scale factor for the noise being added
153153
154154
"""
155-
get_allatom = ComputeAllAtomCoords().to(device=xt.device)
156155
L = len(xt)
157156

158157
# bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale
@@ -435,7 +434,6 @@ def get_next_pose(
435434
include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model
436435
"""
437436

438-
get_allatom = ComputeAllAtomCoords().to(device=xt.device)
439437
L, n_atom = xt.shape[:2]
440438
assert (xt.shape[1] == 14) or (xt.shape[1] == 27)
441439
assert (px0.shape[1] == 14) or (px0.shape[1] == 27)

scripts/run_inference.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,13 @@ def main(conf: HydraConfig) -> None:
9494
px0, x_t, seq_t, plddt = sampler.sample_step(
9595
t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
9696
)
97-
px0_xyz_stack.append(px0)
98-
denoised_xyz_stack.append(x_t)
99-
seq_stack.append(seq_t)
100-
plddt_stack.append(plddt[0]) # remove singleton leading dimension
97+
# Move trajectory data to CPU immediately to free GPU memory
98+
# for the next denoising step. x_t/seq_t are cloned inside
99+
# sample_step, so moving previous outputs to CPU is safe.
100+
px0_xyz_stack.append(px0.cpu())
101+
denoised_xyz_stack.append(x_t.cpu())
102+
seq_stack.append(seq_t.cpu())
103+
plddt_stack.append(plddt[0].cpu()) # remove singleton leading dimension
101104

102105
# Flip order for better visualization in pymol
103106
denoised_xyz_stack = torch.stack(denoised_xyz_stack)

0 commit comments

Comments
 (0)