Skip to content

Commit 213ef6e

Browse files
committed
removes torch.autocast and uses fsdp2 for autocasting to bf16
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 50bc59f commit 213ef6e

4 files changed

Lines changed: 32 additions & 38 deletions

File tree

bionemo-recipes/recipes/esm2_minifold_te/eval_fsdp2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ def main(args: DictConfig) -> None:
179179
for batch in progress:
180180
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
181181

182-
with torch.autocast("cuda", dtype=torch.bfloat16):
183-
r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0))
182+
r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0))
184183

185184
# Distogram loss
186185
disto_loss = compute_distogram_loss(

bionemo-recipes/recipes/esm2_minifold_te/miniformer_te.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,11 @@ def _gate_ctx():
145145
# Apply mask
146146
x = x * mask.unsqueeze(-1)
147147

148-
# Triangular multiplication (MUST stay in FP32)
149-
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
150-
with torch.autocast(device.type, enabled=False):
151-
a1, b1, a2, b2 = torch.chunk(x.float(), 4, dim=-1)
152-
x1 = torch.einsum("bikd,bjkd->bijd", a1, b1)
153-
x2 = torch.einsum("bkid,bkjd->bijd", a2, b2)
154-
x = torch.cat([x1, x2], dim=-1).to(mask.dtype if mask.is_floating_point() else torch.float32)
148+
# Triangular multiplication (in FP32 via explicit .float() cast)
149+
a1, b1, a2, b2 = torch.chunk(x.float(), 4, dim=-1)
150+
x1 = torch.einsum("bikd,bjkd->bijd", a1, b1)
151+
x2 = torch.einsum("bkid,bkjd->bijd", a2, b2)
152+
x = torch.cat([x1, x2], dim=-1).to(mask.dtype if mask.is_floating_point() else torch.float32)
155153

156154
# Output gating: D/2 -> D
157155
x = te_layernorm_nd(self.output_norm, x)

bionemo-recipes/recipes/esm2_minifold_te/structure_te.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -323,32 +323,30 @@ def forward(self, s, z, aatype, mask):
323323
# Predict angles
324324
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
325325

326-
# Predict positions
327-
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
328-
with torch.autocast(device.type, enabled=False):
329-
n, ca, c = te_linear_nd(self.bb_update, s.float()).chunk(3, dim=-1)
330-
rigids = Rigid.make_transform_from_reference(n, ca, c, eps=1e-7)
331-
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
332-
333-
all_frames_to_global = torsion_angles_to_frames(scaled_rigids, angles, aatype, self.default_frames)
334-
pred_xyz = frames_and_literature_positions_to_atom14_pos(
335-
all_frames_to_global,
336-
aatype,
337-
self.default_frames,
338-
self.group_idx,
339-
self.atom_mask,
340-
self.lit_positions,
341-
)
342-
outputs.append(
343-
{
344-
"angles": angles,
345-
"unnormalized_angles": unnormalized_angles,
346-
"frames": scaled_rigids.to_tensor_4x4(),
347-
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
348-
"positions": pred_xyz,
349-
"states": s,
350-
}
351-
)
326+
# Predict positions (in FP32 via explicit .float() cast)
327+
n, ca, c = te_linear_nd(self.bb_update, s.float()).chunk(3, dim=-1)
328+
rigids = Rigid.make_transform_from_reference(n, ca, c, eps=1e-7)
329+
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
330+
331+
all_frames_to_global = torsion_angles_to_frames(scaled_rigids, angles, aatype, self.default_frames)
332+
pred_xyz = frames_and_literature_positions_to_atom14_pos(
333+
all_frames_to_global,
334+
aatype,
335+
self.default_frames,
336+
self.group_idx,
337+
self.atom_mask,
338+
self.lit_positions,
339+
)
340+
outputs.append(
341+
{
342+
"angles": angles,
343+
"unnormalized_angles": unnormalized_angles,
344+
"frames": scaled_rigids.to_tensor_4x4(),
345+
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
346+
"positions": pred_xyz,
347+
"states": s,
348+
}
349+
)
352350

353351
outputs = dict_multimap(torch.stack, outputs)
354352
outputs["single"] = s

bionemo-recipes/recipes/esm2_minifold_te/train_fsdp2.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,8 @@ def main(args: DictConfig) -> float | None:
317317
for batch in train_dataloader:
318318
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
319319

320-
# Forward pass
321-
with torch.autocast("cuda", dtype=torch.bfloat16):
322-
r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0))
320+
# Forward pass (BF16 handled by FSDP2 MixedPrecisionPolicy)
321+
r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0))
323322

324323
# Compute distogram loss
325324
disto_loss = compute_distogram_loss(

0 commit comments

Comments
 (0)