
Commit bcb127b

Add gradient accumulation to Llama3 recipe (#1386)
### Description

Implements gradient accumulation for the Llama3 Native TE recipe, following the pattern from ESM2 PR #1254. This enables training with larger effective batch sizes without increasing GPU memory usage by accumulating gradients across multiple microbatches before performing an optimizer step.

**Key Changes:**

- **perf_logger.py**: Added `log_micro_step()` method to track metrics across microbatches, updated `log_step()` signature to use accumulated metrics, added configurable `pad_token_id` parameter (defaults to 1)
- **train_ddp.py**: Implemented gradient accumulation loop with `model.no_sync()` for efficiency, added validation for `grad_acc_steps >= 1`
- **train_fsdp2.py**: Implemented gradient accumulation loop (without `model.no_sync()` as FSDP2 handles synchronization internally), added validation
- **defaults.yaml**: Added `grad_acc_steps` parameter (default: 1 for backward compatibility)
- **test_gradient_accumulation.py**: Added golden value test that validates mathematical correctness of gradient accumulation

**Validation:**

Lingua1B DCLM Benchmark trained with Grad_Acc=4, 2 Nodes, MBS=4 -> GBS=256
https://api.wandb.ai/links/clara-discovery/5laqf4gm

Has matching loss curve:

<img width="3268" height="1454" alt="image" src="https://github.com/user-attachments/assets/55e5505f-5527-4a11-9a0d-9958eea046f0" />

DDP Results: https://api.wandb.ai/links/clara-discovery/6ncxn9n4

- DDP Training Loss curves for single node & 4 node training runs are similar with varying levels of gradient accumulation (grad acc=1, grad acc=2, grad acc=4) for a mbs=4:

<img width="1260" height="641" alt="image" src="https://github.com/user-attachments/assets/02e610a7-704a-469b-97c0-fd6615c35cea" />

FSDP2 Results: https://api.wandb.ai/links/clara-discovery/lcvrsgm8

- FSDP2 Training Loss Curves for single node and 4 node training runs are similar with and without gradient accumulation:

<img width="1265" height="627" alt="image" src="https://github.com/user-attachments/assets/0576bb6f-de0b-47b6-b305-9437366dd451" />

Golden value test confirms that `micro_batch=1, grad_acc=2` produces mathematically identical gradients to `micro_batch=2, grad_acc=1`.

**References:** Adapts the gradient accumulation implementation from ESM2: #1254

#### Usage

##### Without gradient accumulation (default, backward compatible)

    python train_fsdp2.py --config-name L2_lingua_1b

##### With gradient accumulation (reduce memory usage)

    python train_fsdp2.py \
      --config-name L2_lingua_1b \
      dataset.micro_batch_size=2 \
      grad_acc_steps=2

##### Effective batch size formula: effective_batch = micro_batch_size × num_gpus × grad_acc_steps

##### Example: 2 × 16 × 2 = 64 samples per optimizer step

**Benefits:**

- Enables larger effective batch sizes on memory-constrained GPUs
- Allows training larger models by reducing micro batch size
- Maintains identical training dynamics to larger microbatches
- Backward compatible: `grad_acc_steps=1` behaves as before

#### Type of changes

- [x] New feature (non-breaking change which adds functionality)

### CI Pipeline Configuration

- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes to validate gradient accumulation doesn't break existing tests

### Pre-submit Checklist

- [x] I have tested these changes locally (single-GPU validation on SLURM)
- [x] I have updated the documentation accordingly (inline comments, test docstrings)
- [x] I have added/updated tests as needed (test_gradient_accumulation.py with golden value test)
- [x] All existing tests pass successfully (pre-commit hooks pass, golden value test passes)

### Testing Notes

**Golden Value Test:**

    pytest bionemo-recipes/recipes/llama3_native_te/tests/test_gradient_accumulation.py -v

Validates that gradient accumulation produces mathematically equivalent gradients by comparing:

- Loss values (within 1% tolerance)
- Gradient norms (within 1% tolerance)
- Individual parameter gradients (within 0.1% tolerance)

**Integration Testing:** Testing with Lingua-1B benchmark on DCLM dataset - loss curves match

---------

Signed-off-by: savitha-eng <savithas@nvidia.com>
Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
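
The golden-value claim in the description can be illustrated without a GPU. The following is a minimal sketch, not the recipe's test: it assumes a hypothetical one-parameter model `y = w * x` with a mean-squared-error loss, and the helper names `mse_grad` and `accumulated_grad` are invented for illustration. It shows that dividing each microbatch loss by `grad_acc_steps` before summing gradients reproduces the full-batch gradient when the microbatches are equally sized.

```python
import math


def mse_grad(w, xs, ys):
    """Gradient w.r.t. w of the batch-mean squared error for the model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size, grad_acc_steps):
    """Sum per-microbatch gradients, each scaled by 1 / grad_acc_steps."""
    grad = 0.0
    for i in range(grad_acc_steps):
        lo = i * micro_batch_size
        hi = lo + micro_batch_size
        # Scaling the microbatch loss by 1 / grad_acc_steps scales its gradient the same way.
        grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / grad_acc_steps
    return grad


# Hypothetical toy data: two samples in total.
xs, ys, w = [1.0, 3.0], [2.0, 5.0], 0.5

full_batch = accumulated_grad(w, xs, ys, micro_batch_size=2, grad_acc_steps=1)
two_micro = accumulated_grad(w, xs, ys, micro_batch_size=1, grad_acc_steps=2)
print(math.isclose(full_batch, two_micro))
```

The equality relies on every microbatch contributing the same number of elements to its mean. With padded or packed sequences whose token counts vary between microbatches, the match may be only approximate.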
1 parent 4873914 commit bcb127b

5 files changed

Lines changed: 196 additions & 88 deletions


bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ config_name_or_path: ??? # E.g., meta-llama/Llama-3.2-1B or ./model_configs/meta
 config_kwargs: {}
 
 num_train_steps: ???
+grad_acc_steps: 1 # Gradient accumulation steps - effective batch = micro_batch_size * num_gpus * grad_acc_steps
 
 use_meta_device: true
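
As a quick check of the formula in the `grad_acc_steps` comment, the sketch below computes the effective batch size. The function name `effective_batch_size` is hypothetical and not part of the recipe. The second call assumes 8 GPUs per node for the 2-node validation run, which the commit message does not state.

```python
def effective_batch_size(micro_batch_size, num_gpus, grad_acc_steps=1):
    """Samples consumed per optimizer step across all data-parallel ranks."""
    if grad_acc_steps < 1:
        raise ValueError(f"grad_acc_steps must be >= 1, got {grad_acc_steps}")
    return micro_batch_size * num_gpus * grad_acc_steps


# Example from the commit message: 2 x 16 x 2.
example = effective_batch_size(micro_batch_size=2, num_gpus=16, grad_acc_steps=2)

# Validation run: MBS=4, Grad_Acc=4, 2 nodes (assumed 8 GPUs each, so 16 GPUs).
validation = effective_batch_size(micro_batch_size=4, num_gpus=2 * 8, grad_acc_steps=4)

print(example, validation)
```

With the default `grad_acc_steps=1` the result reduces to `micro_batch_size * num_gpus`, which is the behavior before this change.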

bionemo-recipes/recipes/llama3_native_te/perf_logger.py

Lines changed: 43 additions & 18 deletions
@@ -51,7 +51,6 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
         self.min_loss = float("inf")
 
         self.logging_frequency = args.logger.frequency
-        # Track whether to collect memory stats (disabled by default for max performance)
 
         metrics_dict = {
             "train/loss": torchmetrics.MeanMetric(),
@@ -80,44 +79,64 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
             self._profiler = setup_profiler(args, self._wandb_run)
             self._profiler.__enter__()
 
+        # Gradient accumulation tracking
+        self.num_tokens = 0
+        self.num_unpadded_tokens = 0
+        self.running_loss = 0.0
+        self.grad_acc_step_count = 0
+
+    def log_micro_step(self, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast):
+        """Store data on micro step for gradient accumulation metrics.
+
+        Args:
+            batch: The batch of data for the micro step.
+            outputs: The outputs of the micro step.
+        """
+        self.grad_acc_step_count += 1
+        self.num_tokens += batch["input_ids"].numel()
+        # Use attention_mask to count unpadded tokens (works for both BSHD and THD)
+        if "attention_mask" in batch:
+            self.num_unpadded_tokens += batch["attention_mask"].sum().item()
+        else:
+            # Fallback for pure sequence packing with no padding: all tokens are unpadded
+            self.num_unpadded_tokens += batch["input_ids"].numel()
+        self.running_loss += outputs.loss.item()
+
     def log_step(
         self,
         step: int,
-        batch: dict[str, torch.Tensor],
-        outputs: CausalLMOutputWithPast,
         grad_norm: float,
         lr: float,
     ):
         """Log a step to the logger and wandb.
 
         Args:
             step: The step number.
-            batch: The batch of data for the step.
-            outputs: The outputs of the step.
             grad_norm: The gradient norm of the step.
             lr: The learning rate of the step.
         """
-        num_tokens = batch["input_ids"].numel()
-        if "attention_mask" in batch:
-            num_unpadded_tokens = batch["attention_mask"].sum().item()
-        else:
-            num_unpadded_tokens = num_tokens
-
-        self.min_loss = min(self.min_loss, outputs.loss.item())
+        # Use accumulated metrics from gradient accumulation
+        assert self.grad_acc_step_count > 0, (
+            f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, "
+            f"and can be incremented by log_micro_step()."
+        )
+
+        avg_loss = self.running_loss / self.grad_acc_step_count
+        self.min_loss = min(self.min_loss, avg_loss)
         step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter()
 
-        self.metrics["train/loss"].update(outputs.loss)
+        self.metrics["train/loss"].update(avg_loss)
         self.metrics["train/learning_rate"].update(lr)
         self.metrics["train/grad_norm"].update(grad_norm)
         self.metrics["train/step_time"].update(step_time)
-        self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time)
-        self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time)
-        self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency)
+        self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time)
+        self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time)
+        self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens / self.logging_frequency)
 
         if self._profiler is not None:
             self._profiler.step()
 
-        if (step + 1) % self.logging_frequency == 0:
+        if step % self.logging_frequency == 0 and step > 0:
             memory_allocated = torch.cuda.memory_allocated() / (1024**3)
             self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated)
             self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated)
@@ -129,11 +148,17 @@ def log_step(
             if self._dist_config.is_main_process():
                 wandb.log(metrics, step=step)
                 self._progress_bar.update(self.logging_frequency)
-                self._progress_bar.set_postfix({"loss": outputs.loss.item()})
+                self._progress_bar.set_postfix({"loss": avg_loss})
 
             if self._dist_config.local_rank == 0:
                 logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))
 
+        # Reset gradient accumulation tracking for next step
+        self.num_tokens = 0
+        self.num_unpadded_tokens = 0
+        self.running_loss = 0.0
+        self.grad_acc_step_count = 0
+
     def finish(self):
         """Finish the logger and close the progress bar."""
         if self._profiler is not None:
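
The accumulate-then-reset bookkeeping in this file can be sketched without torch or wandb. The class below is a hypothetical, framework-free stand-in (the name `MicroStepAccumulator`, the list-of-lists inputs, and the `step_time` argument are assumptions made for illustration); it mirrors the counters added in the diff and the averaging of the loss over microbatches. The helper `logged_steps` illustrates the new logging condition `step % logging_frequency == 0 and step > 0`.

```python
class MicroStepAccumulator:
    """Hypothetical stand-in for the accumulation fields added to the logger."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_tokens = 0
        self.num_unpadded_tokens = 0
        self.running_loss = 0.0
        self.grad_acc_step_count = 0

    def log_micro_step(self, input_ids, loss, attention_mask=None):
        """Accumulate token counts and loss for one microbatch (lists of lists)."""
        tokens = sum(len(row) for row in input_ids)
        self.grad_acc_step_count += 1
        self.num_tokens += tokens
        if attention_mask is not None:
            self.num_unpadded_tokens += sum(sum(row) for row in attention_mask)
        else:
            # No mask available: count every token as unpadded.
            self.num_unpadded_tokens += tokens
        self.running_loss += loss

    def log_step(self, step_time):
        """Return metrics for one optimizer step, then reset the counters."""
        if self.grad_acc_step_count == 0:
            raise RuntimeError("log_micro_step() must be called before log_step()")
        summary = {
            "loss": self.running_loss / self.grad_acc_step_count,
            "tokens_per_second": self.num_tokens / step_time,
            "unpadded_tokens_per_second": self.num_unpadded_tokens / step_time,
        }
        self.reset()
        return summary


def logged_steps(num_steps, logging_frequency):
    """Steps that satisfy: step % logging_frequency == 0 and step > 0."""
    return [s for s in range(num_steps) if s % logging_frequency == 0 and s > 0]


acc = MicroStepAccumulator()
acc.log_micro_step([[5, 6, 7, 0]], loss=2.0, attention_mask=[[1, 1, 1, 0]])
acc.log_micro_step([[8, 9, 0, 0]], loss=4.0, attention_mask=[[1, 1, 0, 0]])
summary = acc.log_step(step_time=2.0)
print(summary, logged_steps(25, 10))
```

Under the previous condition, `(step + 1) % logging_frequency == 0`, a frequency of 10 would have logged at steps 9 and 19 rather than 10 and 20.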

bionemo-recipes/recipes/llama3_native_te/tests/test_train.py

Lines changed: 64 additions & 0 deletions
@@ -62,6 +62,26 @@ def test_sanity_convergence_ddp_te(tmp_path, recipe_path):
     assert final_loss < 2.0, f"Final loss {final_loss} is too high, expected < 2.0"
 
 
+def test_sanity_convergence_ddp_te_grad_acc(tmp_path, recipe_path):
+    """Test DDP training with gradient accumulation."""
+    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+        sanity_config = compose(
+            config_name="L0_sanity",
+            overrides=[
+                f"+wandb.dir={tmp_path}",
+                f"checkpoint.ckpt_dir={tmp_path}",
+                "checkpoint.resume_from_checkpoint=false",
+                "grad_acc_steps=2",
+            ],
+        )
+
+    final_loss = main_ddp(sanity_config)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    assert final_loss < 2.0, f"Final loss {final_loss} is too high, expected < 2.0"
+
+
 def test_sanity_convergence_ddp_hf(tmp_path, recipe_path):
     """Test that DDP training converges on mock genomic data.
 
@@ -146,6 +166,50 @@ def test_sanity_convergence_fsdp2_te_thd(tmp_path, recipe_path):
     assert final_loss < 2.0, f"Final loss {final_loss} is too high, expected < 2.0"
 
 
+def test_sanity_convergence_fsdp2_te_bshd_grad_acc(tmp_path, recipe_path):
+    """Test FSDP2 training with BSHD format and gradient accumulation."""
+    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+        sanity_config = compose(
+            config_name="L0_sanity",
+            overrides=[
+                f"+wandb.dir={tmp_path}",
+                f"checkpoint.ckpt_dir={tmp_path}",
+                "checkpoint.resume_from_checkpoint=false",
+                "config_kwargs.attn_input_format=bshd",
+                "grad_acc_steps=2",
+            ],
+        )
+
+    final_loss = main_fsdp2(sanity_config)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    assert final_loss < 2.0, f"Final loss {final_loss} is too high, expected < 2.0"
+
+
+def test_sanity_convergence_fsdp2_te_thd_grad_acc(tmp_path, recipe_path):
+    """Test FSDP2 training with THD format and gradient accumulation."""
+    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+        sanity_config = compose(
+            config_name="L0_sanity",
+            overrides=[
+                f"+wandb.dir={tmp_path}",
+                f"checkpoint.ckpt_dir={tmp_path}",
+                "checkpoint.resume_from_checkpoint=false",
+                "use_sequence_packing=true",
+                "config_kwargs.attn_input_format=thd",
+                "dataset.max_seq_length=1024",
+                "grad_acc_steps=2",
+            ],
+        )
+
+    final_loss = main_fsdp2(sanity_config)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    assert final_loss < 2.0, f"Final loss {final_loss} is too high, expected < 2.0"
+
+
 def test_sanity_convergence_fsdp2_hf(tmp_path, recipe_path):
     """Test that FSDP2 training converges on mock genomic data.

bionemo-recipes/recipes/llama3_native_te/train_ddp.py

Lines changed: 47 additions & 37 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from contextlib import nullcontext
 from pathlib import Path
 
 import hydra
@@ -119,50 +120,59 @@ def main(args: DictConfig) -> float | None:
 
     # Training loop
     step = start_step
+    micro_step = 0
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901
 
-            # Forward pass with mixed precision.
-            with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
-                outputs = model(**batch)
-
-            # Backward pass.
-            loss = outputs.loss
-            loss.backward()
-
-            # Compute and clip gradient norms.
-            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
-
-            # Step optimizer.
-            optimizer.step()
-            scheduler.step()
-            optimizer.zero_grad()
-
-            perf_logger.log_step(
-                step=step,
-                batch=batch,
-                outputs=outputs,
-                grad_norm=total_norm,
-                lr=optimizer.param_groups[0]["lr"],
-            )
-
-            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
-                save_checkpoint_ddp(
-                    model=model,
-                    optimizer=optimizer,
-                    scheduler=scheduler,
-                    ckpt_path=ckpt_path,
+            micro_step += 1
+            # Use no_sync to prevent gradient synchronization until the last microbatch
+            with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
+                # Forward pass with mixed precision.
+                with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
+                    outputs = model(**batch)
+
+                # Backward pass - scale loss by grad_acc_steps for proper gradient averaging
+                loss = outputs.loss / args.grad_acc_steps
+                loss.backward()
+
+                # Log microbatch step data for accumulation metrics
+                perf_logger.log_micro_step(batch=batch, outputs=outputs)
+
+            # Gradient accumulation - only step optimizer after accumulating gradients
+            if micro_step % args.grad_acc_steps == 0:
+                micro_step = 0
+
+                # Compute and clip gradient norms.
+                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+
+                # Step optimizer.
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+                perf_logger.log_step(
                     step=step,
-                    epoch=epoch,
-                    dist_config=dist_config,
-                    dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
-                    max_checkpoints=args.checkpoint.max_checkpoints,
+                    grad_norm=total_norm,
+                    lr=optimizer.param_groups[0]["lr"],
                 )
 
-            step += 1
-            if step >= args.num_train_steps:
-                break
+                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
+                    save_checkpoint_ddp(
+                        model=model,
+                        optimizer=optimizer,
+                        scheduler=scheduler,
+                        ckpt_path=ckpt_path,
+                        step=step,
+                        epoch=epoch,
+                        dist_config=dist_config,
+                        dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
+                        max_checkpoints=args.checkpoint.max_checkpoints,
+                    )
+
+                step += 1
+                if step >= args.num_train_steps:
+                    break
 
         # Dataloader exhausted, incrementing epoch
         epoch += 1
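
The `no_sync()` / `nullcontext()` selection in the loop above can be exercised in isolation. The sketch below is an assumption-laden illustration: `FakeDDPModel` is a hypothetical stand-in that only counts how often a gradient synchronization would happen, and `run_accumulation` is an invented helper, not recipe code. It shows that synchronization is skipped on all but the last microbatch of each accumulation window.

```python
from contextlib import contextmanager, nullcontext


class FakeDDPModel:
    """Hypothetical stand-in for a DDP-wrapped model; it only counts gradient syncs."""

    def __init__(self):
        self.sync_enabled = True
        self.num_syncs = 0

    @contextmanager
    def no_sync(self):
        # Temporarily disable synchronization, restoring the previous state on exit.
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous

    def backward(self):
        # A real DDP backward would all-reduce gradients here when sync is enabled.
        if self.sync_enabled:
            self.num_syncs += 1


def run_accumulation(num_micro_batches, grad_acc_steps):
    """Return (number of gradient syncs, number of optimizer steps)."""
    if grad_acc_steps < 1:
        raise ValueError(f"grad_acc_steps must be >= 1, got {grad_acc_steps}")
    model = FakeDDPModel()
    optimizer_steps = 0
    micro_step = 0
    for _ in range(num_micro_batches):
        micro_step += 1
        is_last = micro_step % grad_acc_steps == 0
        # Same selection as in the diff: sync only on the last microbatch.
        with nullcontext() if is_last else model.no_sync():
            model.backward()
        if is_last:
            micro_step = 0
            optimizer_steps += 1
    return model.num_syncs, optimizer_steps


print(run_accumulation(8, 4))
```

With `grad_acc_steps=1` every microbatch is the last one in its window, so the loop synchronizes on each backward pass as it did before the change.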

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 41 additions & 33 deletions
@@ -134,52 +134,60 @@ def main(args: DictConfig) -> float | None:
     # Training loop
     logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
     step = start_step
+    micro_step = 0
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901
 
+            micro_step += 1
+
             # Forward pass with mixed precision.
             with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
                 outputs = model(**batch)
 
-            # Backward pass.
-            loss = outputs.loss
+            # Backward pass - scale loss by grad_acc_steps for proper gradient averaging
+            loss = outputs.loss / args.grad_acc_steps
             loss.backward()
 
-            # Compute and clip gradient norms.
-            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
-
-            # Step optimizer.
-            optimizer.step()
-            scheduler.step()
-            optimizer.zero_grad()
-
-            perf_logger.log_step(
-                step=step,
-                batch=batch,
-                outputs=outputs,
-                grad_norm=total_norm,
-                lr=optimizer.param_groups[0]["lr"],
-            )
-
-            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
-                save_checkpoint_fsdp2(
-                    model=model,
-                    optimizer=optimizer,
-                    scheduler=scheduler,
-                    ckpt_path=ckpt_path,
+            # Log microbatch step data for accumulation metrics
+            perf_logger.log_micro_step(batch=batch, outputs=outputs)
+
+            # Gradient accumulation - only step optimizer after accumulating gradients
+            if micro_step % args.grad_acc_steps == 0:
+                micro_step = 0
+
+                # Compute and clip gradient norms.
+                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+
+                # Step optimizer.
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+                perf_logger.log_step(
                     step=step,
-                    epoch=epoch,
-                    dist_config=dist_config,
-                    dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
-                    process_group=device_mesh.get_group("dp"),
-                    max_checkpoints=args.checkpoint.max_checkpoints,
-                    async_save=args.checkpoint.async_save,
+                    grad_norm=total_norm,
+                    lr=optimizer.param_groups[0]["lr"],
                 )
 
-            step += 1
-            if step >= args.num_train_steps:
-                break
+                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
+                    save_checkpoint_fsdp2(
+                        model=model,
+                        optimizer=optimizer,
+                        scheduler=scheduler,
+                        ckpt_path=ckpt_path,
+                        step=step,
+                        epoch=epoch,
+                        dist_config=dist_config,
+                        dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
+                        process_group=device_mesh.get_group("dp"),
+                        max_checkpoints=args.checkpoint.max_checkpoints,
+                        async_save=args.checkpoint.async_save,
+                    )
+
+                step += 1
+                if step >= args.num_train_steps:
+                    break
 
         # Dataloader exhausted, incrementing epoch
         epoch += 1
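
Because `micro_step` is initialized outside the `while` loop and `step` now advances only on optimizer steps, an accumulation window can span a dataloader epoch boundary. The simulation below is a minimal sketch of that control flow only; the function `count_microbatches` and its `batches_per_epoch` parameter are hypothetical and stand in for the real dataloader.

```python
def count_microbatches(num_train_steps, grad_acc_steps, batches_per_epoch, start_step=0):
    """Return (microbatches consumed, epochs counted) for the loop structure above."""
    step = start_step
    micro_step = 0
    consumed = 0
    epoch = 0
    while step < num_train_steps:
        for _ in range(batches_per_epoch):
            consumed += 1
            micro_step += 1
            # Optimizer step only after grad_acc_steps microbatches have accumulated.
            if micro_step % grad_acc_steps == 0:
                micro_step = 0
                step += 1
                if step >= num_train_steps:
                    break
        # Dataloader exhausted (or training finished): count the epoch.
        epoch += 1
    return consumed, epoch


# 3 optimizer steps with 2 microbatches each; the third window straddles two epochs.
result = count_microbatches(num_train_steps=3, grad_acc_steps=2, batches_per_epoch=5)
print(result)
```

In this sketch the number of microbatches consumed equals `(num_train_steps - start_step) * grad_acc_steps`, so `num_train_steps` counts optimizer steps rather than microbatches.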
