Commit 0a71b92

tohtana and stas00 authored and committed
Add document section explaining autocast nesting (#7883)
Add a document section clarifying the behavior of nesting autocast and why/when we need it.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
1 parent a8d2fc6 commit 0a71b92

1 file changed

Lines changed: 29 additions & 0 deletions

docs/code-docs/source/training.rst
@@ -119,6 +119,35 @@ If you call ``loss.backward()`` directly without using ``engine.scale()`` or ``e
will raise a ``RuntimeError`` to prevent training with unscaled gradients, which can lead to incorrect results
or gradient underflow.

Using torch.autocast Outside the Engine
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DeepSpeed applies ``torch.autocast`` internally during ``engine.forward()``.
However, you may also want autocast to cover code that runs **outside** the engine,
such as a loss function or post-processing logic. In that case, wrap the entire
forward-plus-loss block in your own ``torch.autocast`` context:

.. code-block:: python

    # Autocast covers both the engine forward AND the loss computation
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model_engine(input_ids)
        loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))

Without the outer ``torch.autocast``, only the model's forward pass benefits from
autocast; the loss function would run in full precision.
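
To make the scoping concrete, here is a minimal sketch in plain PyTorch, without DeepSpeed. It assumes a CPU device so that it runs without a GPU, and the tensors ``a`` and ``b`` are hypothetical stand-ins for inputs to post-processing code. It shows only that the same operation on float32 inputs changes dtype depending on whether it runs inside an autocast context:

```python
import torch

# Hypothetical stand-ins for post-processing inputs; not part of DeepSpeed.
a = torch.randn(4, 8)
b = torch.randn(8, 3)

# Outside any autocast context, a matmul on float32 inputs stays in float32.
out_plain = torch.mm(a, b)

# Inside an autocast context, the same matmul runs in the lower-precision dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_autocast = torch.mm(a, b)

print(out_plain.dtype)
print(out_autocast.dtype)
```

Which operations autocast lowers, and which it keeps in float32, depends on the device and the PyTorch version, so the dtype of a particular loss function should be checked rather than assumed.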

When DeepSpeed detects a nested autocast context, it handles it as follows:

* If ``torch_autocast`` is **enabled** in the DeepSpeed config, the engine overrides the
  outer context with the dtype from the config. An info message is logged once.
* If ``torch_autocast`` is **disabled** in the config (i.e., you are using DeepSpeed's
  built-in bf16/fp16 support instead), the engine disables autocast inside
  ``engine.forward()`` and a warning is logged once.
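
The two cases above correspond to two ways of configuring mixed precision. The following is a rough sketch, written as Python dicts of the kind that could be passed as a DeepSpeed config. The ``torch_autocast`` section name comes from the text above; the ``enabled`` and ``dtype`` keys and the ``bf16`` section are assumptions here and should be checked against the DeepSpeed configuration reference:

```python
# Case 1 (assumed key names): torch_autocast enabled in the config.
# The engine would apply autocast with this dtype inside engine.forward().
autocast_config = {
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16",
    },
}

# Case 2 (assumed key names): torch_autocast left disabled, with
# DeepSpeed's built-in bf16 support used instead.
builtin_bf16_config = {
    "bf16": {"enabled": True},
}
```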

In both cases, PyTorch's ``torch.autocast`` is idempotent when nested with the same
dtype, so there is no performance or correctness penalty from the nesting.
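
The nesting behavior can be checked directly in plain PyTorch. The sketch below is illustrative only: it assumes a CPU device and uses hypothetical tensors ``x`` and ``w``, and it does not involve the DeepSpeed engine. It compares a single autocast context with two nested contexts that use the same dtype:

```python
import torch

# Hypothetical inputs; not part of DeepSpeed.
x = torch.randn(4, 8)
w = torch.randn(8, 3)

# A single autocast context.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    single = torch.mm(x, w)

# Two nested autocast contexts with the same dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        nested = torch.mm(x, w)
    # The outer context is still active after the inner one exits.
    after_inner = torch.mm(x, w)

print(single.dtype, nested.dtype, after_inner.dtype)
```

All three results share the autocast dtype, which matches the claim that same-dtype nesting does not change how operations are cast.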

.. autofunction:: deepspeed.runtime.torch_autocast.init_autocast_params
.. autofunction:: deepspeed.runtime.torch_autocast.is_autocast_initialized
.. autofunction:: deepspeed.runtime.torch_autocast.get_default_autocast_lower_precision_modules
