Skip entropy gradient computation when entropy_coef == 0 #2130
Open
CSUN1997 wants to merge 1 commit into
`policy_loss_function` always requests entropy with `with_entropy=True`, and `calculate_log_probs_and_entropy` computed it with autograd enabled unconditionally. But entropy enters the loss as `loss = pg_loss - args.entropy_coef * entropy_loss`, so when `entropy_coef == 0` the entropy term contributes no gradient, yet the full `[num_tokens, vocab]` entropy autograd graph (plus a defensive `logits.clone()` per chunk) was still retained. For long multi-turn rollouts this activation memory dominates and causes OOMs.

Add a `need_entropy_grad` flag to `calculate_log_probs_and_entropy`. The caller sets it to `with_entropy and args.entropy_coef != 0`. When false, entropy is computed under `torch.no_grad()` and the clone is skipped (the clone only exists to keep the backward's in-place ops off the shared logits tensor; there is no backward under `no_grad`). Entropy values are unchanged and only the graph is dropped, so the logged `entropy_loss` metric is identical.

This makes the code match the existing `get_log_probs_and_entropy` docstring, which already claimed this behavior even though it was never implemented (the prior attempt, THUDM#1185, inverted the `no_grad` condition and was reverted in THUDM#1189).

Add `tests/test_entropy_grad_gating.py` covering: entropy values match with/without grad; `need_entropy_grad=False` detaches the graph; `need_entropy_grad=True` remains differentiable; log_probs are unaffected; empty-input handling.
Skip entropy gradient computation when `entropy_coef == 0`

Summary
When `entropy_coef == 0`, slime still computes the policy entropy with autograd enabled, retaining the full `[num_tokens, vocab]` entropy computation graph (and a per-chunk `logits.clone()`) even though the entropy term is multiplied out of the loss and contributes no gradient. This wasted activation memory dominates for long multi-turn / agentic rollouts and is a frequent OOM source.

This PR gates the entropy autograd graph on whether it is actually needed.

Root cause
`policy_loss_function` requests entropy unconditionally (`with_entropy=True`), and `calculate_log_probs_and_entropy` then computes entropy with grad tracking on. `compute_entropy_from_logits` is a `torch.autograd.Function` over the vocab-parallel logits, so with grad enabled it saves `[num_tokens, vocab]` activations for backward. When `entropy_coef == 0` that graph is built and held for nothing.

Notably, the `get_log_probs_and_entropy` docstring already claims this is handled, but the implementation never did it. A prior attempt (#1185) added the gate with an inverted condition, `with torch.no_grad() if args.entropy_coef else nullcontext()`, which disabled the gradient exactly when `coef != 0` (breaking the entropy-bonus path) and kept it when `coef == 0`. It was reverted the next day in #1189. This PR implements the gate with the correct condition and locks the behavior with tests.

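The difference between the inverted gate and the intended one can be sketched as a small truth table. This is an illustration only: the two helper names below are hypothetical and do not exist in slime; only the two boolean expressions are taken from the description above.

```python
def grad_enabled_inverted(entropy_coef: float) -> bool:
    """Hypothetical helper modelling the inverted gate quoted above:
    `torch.no_grad() if args.entropy_coef else nullcontext()`.

    no_grad is entered when the coefficient is truthy (non-zero), so grad
    tracking stays on only when the coefficient is zero.
    """
    return not bool(entropy_coef)


def need_entropy_grad(with_entropy: bool, entropy_coef: float) -> bool:
    """Hypothetical helper wrapping the intended condition:
    `with_entropy and args.entropy_coef != 0`."""
    return with_entropy and entropy_coef != 0


# One row per coefficient: (coef, grad on with inverted gate, grad on with intended gate)
rows = [
    (coef, grad_enabled_inverted(coef), need_entropy_grad(True, coef))
    for coef in (0.0, 0.01)
]

for coef, inverted, intended in rows:
    print(f"entropy_coef={coef}: inverted gate grad={inverted}, intended gate grad={intended}")
```

The two columns are exact opposites: the inverted gate keeps the graph only in the case where it is useless and drops it in the case where the entropy bonus needs it.
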
Change
Add a `need_entropy_grad` parameter to `calculate_log_probs_and_entropy` (default `True`, so external callers are unaffected). The caller computes it as `with_entropy and args.entropy_coef != 0`.

When `need_entropy_grad` is `False`, entropy is computed under `torch.no_grad()` and the defensive `logits.clone()` is skipped. The clone only exists to keep the backward's in-place ops (`_VocabParallelEntropy.backward` mutates the saved logits via `sub_`/`mul_`) off the shared logits tensor; with no backward there is nothing to protect.

The entropy values are identical; only the autograd graph is dropped. The logged `entropy_loss` metric is unchanged.

Files
- `slime/utils/ppo_utils.py`, `calculate_log_probs_and_entropy`: add `need_entropy_grad`, wrap entropy compute in `no_grad` and skip the clone when grad isn't needed.
- `slime/backends/megatron_utils/loss.py`, `get_log_probs_and_entropy`: derive `need_entropy_grad` from `args.entropy_coef` and pass it through; fix the stale docstring.
- `tests/test_entropy_grad_gating.py`: new unit tests.

Why this is safe
- `need_entropy_grad` defaults to `True`; the only in-repo caller opts into gating based on `entropy_coef`. Any external caller that does not pass the flag gets the previous (grad-on) behavior.
- `no_grad` leaves the entropy values untouched; it only changes whether the graph is retained (test 1).
- When `entropy_coef != 0`, `need_entropy_grad` is `True`, so entropy stays differentiable and still backpropagates to the logits (test 3). This is the exact correctness bug that sank #1185 ("Don't calculate entropy grad when coef is 0").

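The safety claims above can be checked on a toy example. The sketch below is a stand-in, not slime's code: `entropy_from_logits` is a plain single-process PyTorch entropy (assumed here in place of the vocab-parallel `compute_entropy_from_logits`), `gated_entropy` is a hypothetical wrapper, and the per-chunk clone is not modelled.

```python
import torch


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Per-token entropy of softmax(logits); plain PyTorch stand-in."""
    log_p = torch.log_softmax(logits, dim=-1)
    return -(log_p.exp() * log_p).sum(dim=-1)


def gated_entropy(logits: torch.Tensor, need_entropy_grad: bool = True) -> torch.Tensor:
    """Hypothetical wrapper: build the autograd graph only when it is needed."""
    if need_entropy_grad:
        return entropy_from_logits(logits)
    with torch.no_grad():
        return entropy_from_logits(logits)


torch.manual_seed(0)
logits = torch.randn(4, 8, requires_grad=True)  # [num_tokens, vocab]

with_grad = gated_entropy(logits, need_entropy_grad=True)
without_grad = gated_entropy(logits, need_entropy_grad=False)

# Same values either way; only the graph differs.
values_match = torch.allclose(with_grad.detach(), without_grad)

# The grad-on path still backpropagates to the logits.
with_grad.sum().backward()
grad_is_finite = bool(torch.isfinite(logits.grad).all())
```

Here `without_grad.grad_fn` is `None`, so nothing from the entropy computation is kept alive for backward, while `with_grad` remains attached to `logits`.
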
Testing
`tests/test_entropy_grad_gating.py` (CPU, `world_size=1` gloo group; `compute_log_probs` stubbed to avoid the Megatron fused-CE dependency), parametrized over chunked and non-chunked paths:

- `test_entropy_values_match_regardless_of_grad`: entropy values equal with/without grad.
- `test_need_entropy_grad_false_detaches_graph`: `grad_fn is None`, `requires_grad False`.
- `test_need_entropy_grad_true_is_differentiable`: entropy backprops to logits, finite grad.
- `test_log_probs_unaffected_by_entropy_flag`: log-prob output independent of the flag.
- `test_empty_input_returns_empty_entropy`: zero-length response edge case.

`ruff check` passes on all changed files; existing `tests/test_chunked_gae.py` still passes.

Notes / alternatives
This is the minimal, targeted fix (drop the graph when the gradient is provably unused). A complementary, more general optimization is to filter logits to `loss_mask == 1` positions before the vocab-parallel softmax (cf. #1905), which shrinks both the log-prob and entropy compute for agentic rollouts where most response tokens are masked tool-result tokens. The two are orthogonal and can be combined; this PR addresses only the `entropy_coef == 0` wasted-gradient case and restores the behavior the docstring already promises.
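For illustration, the mask-filtering idea can be sketched on a single process. This is not the code from #1905 and it ignores vocab parallelism entirely; `token_entropy` and the toy tensors are assumptions made for the example.

```python
import torch


def token_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Per-token entropy of softmax(logits); plain PyTorch stand-in."""
    log_p = torch.log_softmax(logits, dim=-1)
    return -(log_p.exp() * log_p).sum(dim=-1)


torch.manual_seed(0)
num_tokens, vocab = 6, 10
logits = torch.randn(num_tokens, vocab)
loss_mask = torch.tensor([0, 0, 1, 0, 1, 1])  # 1 = token counts toward the loss

keep = loss_mask.bool()

# Filter first: the softmax only runs over the 3 unmasked rows.
entropy_filtered = token_entropy(logits[keep])

# Reference: softmax over all 6 rows, then select the unmasked ones.
entropy_full_then_masked = token_entropy(logits)[keep]
```

Both paths give the same per-token values for the unmasked positions, but the filtered path materializes a `[num_kept, vocab]` softmax instead of `[num_tokens, vocab]`, which is where the saving would come from when most tokens are masked.
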