Commit bcb127b
authored
Add gradient accumulation to Llama3 recipe (#1386)
### Description
Implements gradient accumulation for the Llama3 Native TE recipe,
following the pattern from ESM2 PR #1254.
This enables training with larger effective batch sizes without
increasing GPU memory usage by accumulating gradients across multiple
microbatches before performing an optimizer step.
**Key Changes:**
- **perf_logger.py**: Added `log_micro_step()` method to track metrics
across microbatches, updated `log_step()` signature to use accumulated
metrics, added configurable `pad_token_id` parameter (defaults to 1)
- **train_ddp.py**: Implemented gradient accumulation loop with
`model.no_sync()` for efficiency, added validation for `grad_acc_steps
>= 1`
- **train_fsdp2.py**: Implemented gradient accumulation loop (without
`model.no_sync()` as FSDP2 handles synchronization internally), added
validation
- **defaults.yaml**: Added `grad_acc_steps` parameter (default: 1 for
backward compatibility)
- **test_gradient_accumulation.py**: Added golden value test that
validates mathematical correctness of gradient accumulation
**Validation:**
Lingua1B DCLM Benchmark trained with Grad_Acc=4, 2 Nodes, MBS=4 ->
GBS=256 https://api.wandb.ai/links/clara-discovery/5laqf4gm
Has matching loss curve:
<img width="3268" height="1454" alt="image"
src="https://github.com/user-attachments/assets/55e5505f-5527-4a11-9a0d-9958eea046f0"
/>
DDP Results: https://api.wandb.ai/links/clara-discovery/6ncxn9n4
- DDP Training Loss curves for single node & 4 node training runs are
similar with varying levels of gradient accumulation (grad acc=1, grad
acc=2, grad acc=4) for a mbs=4:
<img width="1260" height="641" alt="image"
src="https://github.com/user-attachments/assets/02e610a7-704a-469b-97c0-fd6615c35cea"
/>
FSDP2 Results: https://api.wandb.ai/links/clara-discovery/lcvrsgm8
- FSDP2 Training Loss Curves for single node and 4 node training runs
are similar with and without gradient accumulation:
<img width="1265" height="627" alt="image"
src="https://github.com/user-attachments/assets/0576bb6f-de0b-47b6-b305-9437366dd451"
/>
Golden value test confirms that `micro_batch=1, grad_acc=2` produces
mathematically identical gradients to `micro_batch=2, grad_acc=1`.
**References:**
Adapts the gradient accumulation implementation from ESM2: #1254
#### Usage
##### Without gradient accumulation (default, backward compatible)
python train_fsdp2.py --config-name L2_lingua_1b
##### With gradient accumulation (reduce memory usage)
python train_fsdp2.py \
--config-name L2_lingua_1b \
dataset.micro_batch_size=2 \
grad_acc_steps=2
##### Effective batch size formula:
effective_batch = micro_batch_size × num_gpus × grad_acc_steps
##### Example: 2 × 16 × 2 = 64 samples per optimizer step**Benefits:**
- Enables larger effective batch sizes on memory-constrained GPUs
- Allows training larger models by reducing micro batch size
- Maintains identical training dynamics to larger microbatches
- Backward compatible: `grad_acc_steps=1` behaves as before
#### Type of changes
- [x] New feature (non-breaking change which adds functionality)
### CI Pipeline Configuration
-
[ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes)
- Run tests for all recipes to validate gradient accumulation doesn't
break existing tests
### Pre-submit Checklist
- [x] I have tested these changes locally (single-GPU validation on
SLURM)
- [x] I have updated the documentation accordingly (inline comments,
test docstrings)
- [x] I have added/updated tests as needed
(test_gradient_accumulation.py with golden value test)
- [x] All existing tests pass successfully (pre-commit hooks pass,
golden value test passes)
### Testing Notes
**Golden Value Test:**
pytest
bionemo-recipes/recipes/llama3_native_te/tests/test_gradient_accumulation.py
-vValidates that gradient accumulation produces mathematically
equivalent gradients by comparing:
- Loss values (within 1% tolerance)
- Gradient norms (within 1% tolerance)
- Individual parameter gradients (within 0.1% tolerance)
**Integration Testing:**
Testing with Lingua-1B benchmark on DCLM dataset - loss curves match
---------
Signed-off-by: savitha-eng <savithas@nvidia.com>
Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>1 parent 4873914 commit bcb127b
5 files changed
Lines changed: 196 additions & 88 deletions
File tree
- bionemo-recipes/recipes/llama3_native_te
- hydra_config
- tests
Lines changed: 1 addition & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
51 | 51 | | |
52 | 52 | | |
53 | 53 | | |
54 | | - | |
55 | 54 | | |
56 | 55 | | |
57 | 56 | | |
| |||
80 | 79 | | |
81 | 80 | | |
82 | 81 | | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
83 | 105 | | |
84 | 106 | | |
85 | 107 | | |
86 | | - | |
87 | | - | |
88 | 108 | | |
89 | 109 | | |
90 | 110 | | |
91 | 111 | | |
92 | 112 | | |
93 | 113 | | |
94 | 114 | | |
95 | | - | |
96 | | - | |
97 | 115 | | |
98 | 116 | | |
99 | 117 | | |
100 | | - | |
101 | | - | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | | - | |
106 | | - | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
107 | 126 | | |
108 | 127 | | |
109 | | - | |
| 128 | + | |
110 | 129 | | |
111 | 130 | | |
112 | 131 | | |
113 | | - | |
114 | | - | |
115 | | - | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
116 | 135 | | |
117 | 136 | | |
118 | 137 | | |
119 | 138 | | |
120 | | - | |
| 139 | + | |
121 | 140 | | |
122 | 141 | | |
123 | 142 | | |
| |||
129 | 148 | | |
130 | 149 | | |
131 | 150 | | |
132 | | - | |
| 151 | + | |
133 | 152 | | |
134 | 153 | | |
135 | 154 | | |
136 | 155 | | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
137 | 162 | | |
138 | 163 | | |
139 | 164 | | |
| |||
Lines changed: 64 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
62 | 62 | | |
63 | 63 | | |
64 | 64 | | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
65 | 85 | | |
66 | 86 | | |
67 | 87 | | |
| |||
146 | 166 | | |
147 | 167 | | |
148 | 168 | | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
149 | 213 | | |
150 | 214 | | |
151 | 215 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| 17 | + | |
17 | 18 | | |
18 | 19 | | |
19 | 20 | | |
| |||
119 | 120 | | |
120 | 121 | | |
121 | 122 | | |
| 123 | + | |
122 | 124 | | |
123 | 125 | | |
124 | 126 | | |
125 | 127 | | |
126 | | - | |
127 | | - | |
128 | | - | |
129 | | - | |
130 | | - | |
131 | | - | |
132 | | - | |
133 | | - | |
134 | | - | |
135 | | - | |
136 | | - | |
137 | | - | |
138 | | - | |
139 | | - | |
140 | | - | |
141 | | - | |
142 | | - | |
143 | | - | |
144 | | - | |
145 | | - | |
146 | | - | |
147 | | - | |
148 | | - | |
149 | | - | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
156 | 155 | | |
157 | | - | |
158 | | - | |
159 | | - | |
160 | | - | |
| 156 | + | |
| 157 | + | |
161 | 158 | | |
162 | 159 | | |
163 | | - | |
164 | | - | |
165 | | - | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
166 | 176 | | |
167 | 177 | | |
168 | 178 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
134 | 134 | | |
135 | 135 | | |
136 | 136 | | |
| 137 | + | |
137 | 138 | | |
138 | 139 | | |
139 | 140 | | |
140 | 141 | | |
| 142 | + | |
| 143 | + | |
141 | 144 | | |
142 | 145 | | |
143 | 146 | | |
144 | 147 | | |
145 | | - | |
146 | | - | |
| 148 | + | |
| 149 | + | |
147 | 150 | | |
148 | 151 | | |
149 | | - | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
158 | | - | |
159 | | - | |
160 | | - | |
161 | | - | |
162 | | - | |
163 | | - | |
164 | | - | |
165 | | - | |
166 | | - | |
167 | | - | |
168 | | - | |
169 | | - | |
170 | | - | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
171 | 168 | | |
172 | | - | |
173 | | - | |
174 | | - | |
175 | | - | |
176 | | - | |
177 | | - | |
| 169 | + | |
| 170 | + | |
178 | 171 | | |
179 | 172 | | |
180 | | - | |
181 | | - | |
182 | | - | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
183 | 191 | | |
184 | 192 | | |
185 | 193 | | |
| |||
0 commit comments