Skip to content

Commit 65250e1

Browse files
timodonnellclaude
andcommitted
Add gh#9 memory probe: scan n_diffusion_samples on H100
Pre-flight for the full gh#9 fine-tune. Loads a default-size Helico with diffusion_pair_source="distogram_logits" + freeze_trunk, runs one forward+backward at crop_size=384 for n_d in {8, 16, 32, 64}, reports peak GPU memory per attempt. Catches OOM as a status rather than crashing the whole run, so the table always finishes. Used to set HELICO_TRAIN_N_DIFFUSION_SAMPLES for the production run. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 08307ab commit 65250e1

1 file changed

Lines changed: 109 additions & 0 deletions

File tree

modal/probe_diffusion_samples.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Probe peak GPU memory at different ``n_diffusion_samples`` values (gh#9).
2+
3+
With the trunk frozen + diffusion_pair_source="distogram_logits", we should
4+
be able to crank ``n_diffusion_samples`` well past 8 (the gh#6 default).
5+
This probe instantiates a default-size Helico via ``Helico(cfg)`` (this script does not itself load the protenix-v1 seed checkpoint), runs one
6+
forward+backward at a representative crop_size for each candidate
7+
``n_diffusion_samples`` value, and reports peak GPU memory.
8+
9+
Run:
10+
modal run modal/probe_diffusion_samples.py
11+
12+
Reports a small table: n_d → peak GB / status. The largest value that
13+
stays under ~70 GB on H100 (80 GB) becomes the production knob for the
14+
full gh#9 fine-tune.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
from pathlib import Path
20+
21+
import modal
22+
23+
ROOT = Path(__file__).parent.parent
24+
25+
# Python requirements for the probe container. The pins are meant to mirror
# the train image so the cuDNN / torch / cuequivariance versions match the
# eventual training run.
_PIP_REQUIREMENTS = (
    "torch>=2.10,<2.11",  # cuDNN 9.x — torch 2.11's cuDNN 13 broke val (gh#3)
    "cuequivariance-torch>=0.8,<0.9",
    "cuequivariance-ops-torch-cu12>=0.8,<0.9",
    "biopython>=1.80",
    "numpy",
    "scipy",
    "pyyaml>=6.0",
    "huggingface_hub>=0.20",
    "tqdm",
)

# Container image: Debian slim + system tools + the pinned Python packages,
# with the local package sources copied to /root/helico.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("wget", "curl", "git")
    .pip_install(*_PIP_REQUIREMENTS)
    .add_local_dir(str(ROOT / "src"), remote_path="/root/helico/src")
    .add_local_file(
        str(ROOT / "pyproject.toml"),
        remote_path="/root/helico/pyproject.toml",
    )
    .add_local_file(
        str(ROOT / "README.md"),
        remote_path="/root/helico/README.md",
    )
)

# Modal application that owns the probe function and its entrypoint.
app = modal.App("helico-probe-diffusion-samples", image=image)

# Named volume mounted at /ckpts by the probe function; created on first use.
ckpt_volume = modal.Volume.from_name(
    "helico-checkpoints",
    create_if_missing=True,
)
48+
49+
50+
@app.function(gpu="H100:1", timeout=1800, volumes={"/ckpts": ckpt_volume})
51+
def probe(crop_size: int = 384) -> list:
52+
import os, subprocess, gc, sys
53+
subprocess.run(
54+
"cd /root/helico && uv venv --python 3.11 && uv pip install -e .",
55+
check=True, shell=True,
56+
)
57+
sys.path.insert(0, "/root/helico/.venv/lib/python3.11/site-packages")
58+
sys.path.insert(0, "/root/helico/src")
59+
60+
import torch
61+
from helico.model import Helico, HelicoConfig
62+
from helico.data import make_synthetic_batch
63+
64+
# Match the production fine-tune knobs — full-size model with the
65+
# gh#9 swap + trunk frozen.
66+
cfg = HelicoConfig(diffusion_pair_source="distogram_logits", n_diffusion_samples=8)
67+
model = Helico(cfg).cuda()
68+
# Freeze trunk via the same helper used by training.
69+
from helico.train import _freeze_trunk
70+
_freeze_trunk(model)
71+
72+
results = []
73+
for n_d in (8, 16, 32, 64):
74+
# Override n_d on the model config — read by Helico.forward.
75+
cfg.n_diffusion_samples = n_d
76+
torch.cuda.empty_cache()
77+
torch.cuda.reset_peak_memory_stats()
78+
try:
79+
batch = make_synthetic_batch(n_tokens=crop_size, device="cuda")
80+
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
81+
out = model(batch, compute_confidence=False)
82+
loss = out["diffusion_loss"]
83+
loss.backward()
84+
peak_gb = torch.cuda.max_memory_allocated() / 1e9
85+
status = f"OK loss={loss.item():.3g}"
86+
except torch.cuda.OutOfMemoryError as e:
87+
peak_gb = float("nan")
88+
status = f"OOM {str(e)[:80]}"
89+
except Exception as e:
90+
peak_gb = float("nan")
91+
status = f"FAIL {type(e).__name__}: {str(e)[:80]}"
92+
# Drop grads + cache so the next iteration starts clean.
93+
for p in model.parameters():
94+
p.grad = None
95+
gc.collect()
96+
torch.cuda.empty_cache()
97+
results.append((n_d, peak_gb, status))
98+
print(f"n_d={n_d:3d}: peak={peak_gb:6.2f} GB {status}", flush=True)
99+
return results
100+
101+
102+
@app.local_entrypoint()
def main(crop_size: int = 384):
    """Run the remote probe and print a summary table of its results.

    Args:
        crop_size: Forwarded to ``probe.remote`` as ``crop_size``.

    Each result row is ``(n_d, peak, status)``; a ``nan`` peak (the probe's
    marker for a failed attempt) is rendered as a dash instead of a number.
    """
    import math

    res = probe.remote(crop_size=crop_size)
    print("\n=== summary ===")
    print(f"{'n_d':>5} {'peak_GB':>8} status")
    for n_d, peak, status in res:
        # Explicit NaN test instead of the `peak == peak` self-comparison.
        peak_str = " —" if math.isnan(peak) else f"{peak:.2f}"
        print(f"{n_d:>5} {peak_str:>8} {status}")

0 commit comments

Comments
 (0)