Skip to content

Commit 8c0b5ee

Browse files
aymuos15 and ericspod authored
Add GradientAccumulation utility for SupervisedTrainer (#8763)
## Summary - Adds `GradientAccumulation` callable class in `monai.engines.utils` for use as `iteration_update` in `SupervisedTrainer`, enabling gradient accumulation over multiple mini-batches to simulate larger effective batch sizes on memory-constrained hardware - Follows the callable-class `iteration_update` pattern established by `Interaction` in `monai.apps.deepedit` (as referenced by @wyli in #6101) - All `IterationEvents` fire every mini-batch, so existing handlers are unaffected - Epoch boundary flush ensures no gradients are silently discarded when `epoch_length % accumulation_steps != 0` - Mixed-precision (`GradScaler`) support included Closes #6100 Supersedes #6101 ### Usage ```python from monai.engines import SupervisedTrainer, GradientAccumulation trainer = SupervisedTrainer( ..., iteration_update=GradientAccumulation(accumulation_steps=4), ) ``` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] In-line docstrings updated. 
## Test plan - [x] Input validation (zero, negative, float, string) - [x] Passthrough when `accumulation_steps=1` - [x] `zero_grad` / `optimizer.step` suppression patterns verified across full epochs - [x] Epoch boundary flush when `epoch_length` not divisible by `accumulation_steps` - [x] Iterable dataset (`epoch_length=None`) — no epoch flush - [x] Patching/restoration of all engine methods after each call - [x] Restoration after exception (`try/finally`) - [x] `GradScaler` patching when step suppressed, not patched when stepping - [x] No scaler attribute / `scaler=None` edge cases - [x] Batch data forwarded correctly to `_iteration` - [x] Output loss is unscaled (original value for loggers/metrics) - [x] Integration: gradient equivalence with manual accumulation (requires ignite) - [x] Integration: epoch boundary flush equivalence (requires ignite) - [x] Integration: multi-epoch correctness (requires ignite) --------- Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 199084a commit 8c0b5ee

File tree

2 files changed

+302
-6
lines changed

2 files changed

+302
-6
lines changed

monai/engines/trainer.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ class SupervisedTrainer(Trainer):
131131
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
132132
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
133133
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
134+
accumulation_steps: number of mini-batches over which to accumulate gradients before
135+
calling ``optimizer.step()``, effectively simulating a larger batch size on
136+
memory-constrained hardware. Must be a positive integer. Default: 1 (no accumulation).
137+
When ``epoch_length`` is known and not divisible by ``accumulation_steps``, a flush
138+
(optimizer step) is performed at the end of each epoch so no gradients are silently
139+
discarded. The loss stored in ``engine.state.output`` is always the **unscaled** value.
134140
"""
135141

136142
def __init__(
@@ -160,7 +166,10 @@ def __init__(
160166
amp_kwargs: dict | None = None,
161167
compile: bool = False,
162168
compile_kwargs: dict | None = None,
169+
accumulation_steps: int = 1,
163170
) -> None:
171+
if accumulation_steps < 1:
172+
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
164173
super().__init__(
165174
device=device,
166175
max_epochs=max_epochs,
@@ -190,6 +199,7 @@ def __init__(
190199
self.loss_function = loss_function
191200
self.inferer = SimpleInferer() if inferer is None else inferer
192201
self.optim_set_to_none = optim_set_to_none
202+
self.accumulation_steps = accumulation_steps
193203

194204
def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
195205
"""
@@ -245,21 +255,42 @@ def _compute_pred_loss():
245255
engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
246256
engine.fire_event(IterationEvents.LOSS_COMPLETED)
247257

258+
# Determine gradient accumulation state
259+
acc = engine.accumulation_steps
260+
if acc > 1:
261+
epoch_length = engine.state.epoch_length
262+
if epoch_length is not None:
263+
local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch
264+
should_zero_grad = local_iter % acc == 0
265+
should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
266+
else:
267+
local_iter = engine.state.iteration - 1 # 0-indexed global
268+
should_zero_grad = local_iter % acc == 0
269+
should_step = (local_iter + 1) % acc == 0
270+
else:
271+
should_zero_grad = True
272+
should_step = True
273+
248274
engine.network.train()
249-
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
275+
if should_zero_grad:
276+
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
250277

251278
if engine.amp and engine.scaler is not None:
252279
with torch.autocast("cuda", **engine.amp_kwargs):
253280
_compute_pred_loss()
254-
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
281+
loss = engine.state.output[Keys.LOSS]
282+
engine.scaler.scale(loss / acc if acc > 1 else loss).backward()
255283
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
256-
engine.scaler.step(engine.optimizer)
257-
engine.scaler.update()
284+
if should_step:
285+
engine.scaler.step(engine.optimizer)
286+
engine.scaler.update()
258287
else:
259288
_compute_pred_loss()
260-
engine.state.output[Keys.LOSS].backward()
289+
loss = engine.state.output[Keys.LOSS]
290+
(loss / acc if acc > 1 else loss).backward()
261291
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
262-
engine.optimizer.step()
292+
if should_step:
293+
engine.optimizer.step()
263294
# copy back meta info
264295
if self.compile:
265296
if inputs_meta is not None:
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
import torch.nn as nn
18+
from parameterized import parameterized
19+
20+
from monai.utils import IgniteInfo, min_version, optional_import
21+
from monai.utils.enums import CommonKeys
22+
23+
_, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version)
24+
25+
# `accumulation_steps` values that SupervisedTrainer must reject (must be a positive
# integer); each entry is a one-element tuple for `parameterized.expand`.
INVALID_ACCUMULATION_STEPS = [(0,), (-1,)]
26+
27+
28+
def _make_model_pair(lr):
29+
"""Create a reference and test model pair with identical initial weights."""
30+
ref_model = nn.Linear(4, 1, bias=False)
31+
init_weight = ref_model.weight.data.clone()
32+
ref_opt = torch.optim.SGD(ref_model.parameters(), lr=lr)
33+
ref_model.train()
34+
35+
test_model = nn.Linear(4, 1, bias=False)
36+
test_model.weight.data.copy_(init_weight)
37+
test_opt = torch.optim.SGD(test_model.parameters(), lr=lr)
38+
39+
return ref_model, test_model, ref_opt, test_opt, init_weight
40+
41+
42+
@unittest.skipUnless(has_ignite, "Requires pytorch-ignite")
43+
class TestGradientAccumulation(unittest.TestCase):
44+
"""Test gradient accumulation integrated into SupervisedTrainer."""
45+
46+
# ---- input validation ----
47+
48+
@parameterized.expand(INVALID_ACCUMULATION_STEPS)
49+
def test_invalid_accumulation_steps(self, value) -> None:
50+
from monai.engines import SupervisedTrainer
51+
52+
with self.assertRaises(ValueError) as cm:
53+
SupervisedTrainer(
54+
device=torch.device("cpu"),
55+
max_epochs=1,
56+
train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}],
57+
network=nn.Linear(4, 1),
58+
optimizer=torch.optim.SGD(nn.Linear(4, 1).parameters(), lr=0.1),
59+
loss_function=nn.MSELoss(),
60+
accumulation_steps=value,
61+
)
62+
self.assertIn("positive integer", str(cm.exception))
63+
64+
# ---- passthrough (accumulation_steps=1) ----
65+
66+
def test_passthrough_when_accumulation_steps_1(self) -> None:
67+
"""With accumulation_steps=1, behaviour is identical to default training."""
68+
from monai.engines import SupervisedTrainer
69+
70+
torch.manual_seed(42)
71+
lr = 0.1
72+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)]
73+
74+
ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)
75+
76+
# Reference: standard training loop
77+
for batch in batches:
78+
ref_opt.zero_grad()
79+
loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean()
80+
loss.backward()
81+
ref_opt.step()
82+
83+
trainer = SupervisedTrainer(
84+
device=torch.device("cpu"),
85+
max_epochs=1,
86+
train_data_loader=batches,
87+
network=test_model,
88+
optimizer=test_opt,
89+
loss_function=nn.MSELoss(),
90+
accumulation_steps=1,
91+
)
92+
trainer.run()
93+
94+
for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
95+
torch.testing.assert_close(p_test.data, p_ref.data)
96+
97+
# ---- gradient equivalence ----
98+
99+
def test_gradient_equivalence(self) -> None:
100+
"""Accumulated gradients over N mini-batches equal one large-batch step."""
101+
from monai.engines import SupervisedTrainer
102+
103+
torch.manual_seed(42)
104+
acc_steps, lr = 4, 0.1
105+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]
106+
107+
ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)
108+
109+
# Reference: manual accumulation
110+
ref_opt.zero_grad()
111+
for batch in batches:
112+
loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps
113+
loss.backward()
114+
ref_opt.step()
115+
116+
trainer = SupervisedTrainer(
117+
device=torch.device("cpu"),
118+
max_epochs=1,
119+
train_data_loader=batches,
120+
network=test_model,
121+
optimizer=test_opt,
122+
loss_function=nn.MSELoss(),
123+
accumulation_steps=acc_steps,
124+
)
125+
trainer.run()
126+
127+
for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
128+
torch.testing.assert_close(p_test.data, p_ref.data)
129+
130+
# ---- epoch boundary flush ----
131+
132+
def test_epoch_boundary_flush(self) -> None:
133+
"""When epoch_length is not divisible by acc_steps, flush at epoch end."""
134+
from monai.engines import SupervisedTrainer
135+
136+
torch.manual_seed(123)
137+
acc_steps, lr = 3, 0.1
138+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(5)]
139+
140+
ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)
141+
142+
# Reference: first 3 batches form one cycle, last 2 form a partial cycle flushed at epoch end
143+
for cycle_batches in [batches[:3], batches[3:]]:
144+
ref_opt.zero_grad()
145+
for batch in cycle_batches:
146+
loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps
147+
loss.backward()
148+
ref_opt.step()
149+
150+
trainer = SupervisedTrainer(
151+
device=torch.device("cpu"),
152+
max_epochs=1,
153+
train_data_loader=batches,
154+
network=test_model,
155+
optimizer=test_opt,
156+
loss_function=nn.MSELoss(),
157+
accumulation_steps=acc_steps,
158+
)
159+
trainer.run()
160+
161+
for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
162+
torch.testing.assert_close(p_test.data, p_ref.data)
163+
164+
# ---- multi-epoch ----
165+
166+
def test_multi_epoch(self) -> None:
167+
"""Verify gradient accumulation is correct across multiple epochs."""
168+
from monai.engines import SupervisedTrainer
169+
170+
torch.manual_seed(42)
171+
acc_steps, lr, num_epochs = 2, 0.1, 3
172+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)]
173+
174+
ref_model, test_model, ref_opt, test_opt, _ = _make_model_pair(lr)
175+
176+
# Reference: manual multi-epoch accumulation
177+
for _epoch in range(num_epochs):
178+
for cycle_batches in [batches[:2], batches[2:]]:
179+
ref_opt.zero_grad()
180+
for batch in cycle_batches:
181+
loss = nn.MSELoss()(ref_model(batch[CommonKeys.IMAGE]), batch[CommonKeys.LABEL]).mean() / acc_steps
182+
loss.backward()
183+
ref_opt.step()
184+
185+
trainer = SupervisedTrainer(
186+
device=torch.device("cpu"),
187+
max_epochs=num_epochs,
188+
train_data_loader=batches,
189+
network=test_model,
190+
optimizer=test_opt,
191+
loss_function=nn.MSELoss(),
192+
accumulation_steps=acc_steps,
193+
)
194+
trainer.run()
195+
196+
for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
197+
torch.testing.assert_close(p_test.data, p_ref.data)
198+
199+
# ---- loss output is unscaled ----
200+
201+
def test_loss_output_is_unscaled(self) -> None:
202+
"""engine.state.output[LOSS] should be the unscaled loss, not loss/acc."""
203+
from monai.engines import SupervisedTrainer
204+
205+
torch.manual_seed(42)
206+
acc_steps = 4
207+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]
208+
209+
model = nn.Linear(4, 1, bias=False)
210+
opt = torch.optim.SGD(model.parameters(), lr=0.1)
211+
212+
trainer = SupervisedTrainer(
213+
device=torch.device("cpu"),
214+
max_epochs=1,
215+
train_data_loader=batches,
216+
network=model,
217+
optimizer=opt,
218+
loss_function=nn.MSELoss(),
219+
accumulation_steps=acc_steps,
220+
decollate=False,
221+
)
222+
trainer.run()
223+
224+
# The output loss should be the full (unscaled) loss value, not divided by acc_steps
225+
output_loss = trainer.state.output[CommonKeys.LOSS].item()
226+
self.assertGreater(output_loss, 0.0)
227+
228+
# ---- accumulation_steps attribute ----
229+
230+
def test_accumulation_steps_stored(self) -> None:
231+
"""Verify the accumulation_steps attribute is accessible on the trainer."""
232+
from monai.engines import SupervisedTrainer
233+
234+
model = nn.Linear(4, 1)
235+
trainer = SupervisedTrainer(
236+
device=torch.device("cpu"),
237+
max_epochs=1,
238+
train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}],
239+
network=model,
240+
optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
241+
loss_function=nn.MSELoss(),
242+
accumulation_steps=8,
243+
)
244+
self.assertEqual(trainer.accumulation_steps, 8)
245+
246+
# ---- default is no accumulation ----
247+
248+
def test_default_no_accumulation(self) -> None:
249+
"""Default accumulation_steps=1 means no accumulation."""
250+
from monai.engines import SupervisedTrainer
251+
252+
model = nn.Linear(4, 1)
253+
trainer = SupervisedTrainer(
254+
device=torch.device("cpu"),
255+
max_epochs=1,
256+
train_data_loader=[{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}],
257+
network=model,
258+
optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
259+
loss_function=nn.MSELoss(),
260+
)
261+
self.assertEqual(trainer.accumulation_steps, 1)
262+
263+
264+
if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest/CI invocation).
    unittest.main()

0 commit comments

Comments
 (0)