Skip to content

Commit d963bfc

Browse files
fix: EAGLE mix_hidden_states in-place op crash (#1088)
Clone eagle_input_hiddens before indexed assignment to avoid in-place modification of a tensor in the autograd graph, which causes a RuntimeError during the backward pass. Mirrors the existing fix in the Megatron backend (megatron_eagle.py:1201-1202). Adds a regression test parametrized over eagle_ttt_steps [1, 2].

Signed-off-by: javierdejesusda <javier.dejesusj9@gmail.com>
1 parent aad14d1 commit d963bfc

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,8 @@ def forward(
988988
batch_size, seq_len_s, device=eagle_input_hiddens.device
989989
).argsort(dim=1)[:, :num_to_replace]
990990

991+
# Clone to avoid inplace modification that breaks autograd
992+
eagle_input_hiddens = eagle_input_hiddens.clone()
991993
batch_indices = torch.arange(batch_size)[:, None]
992994
eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
993995
batch_indices, rand_indices

tests/unit/torch/speculative/plugins/test_hf_speculative.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from copy import deepcopy
1818

1919
import pytest
20+
import torch
2021
from _test_utils.torch.transformers_models import (
2122
get_tiny_llama,
2223
tf_modelopt_state_and_output_tester,
@@ -48,3 +49,39 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config):
4849
model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model")
4950
assert isinstance(model_test, mtsp.plugins.HFEagleModel)
5051
tf_modelopt_state_and_output_tester(model_ref, model_test)
52+
53+
54+
@pytest.mark.parametrize("eagle_config", [EAGLE3_DEFAULT_CFG])
@pytest.mark.parametrize("eagle_ttt_steps", [1, 2])
def test_eagle_mix_hidden_states_backward(eagle_config, eagle_ttt_steps):
    """Regression test for GitHub issue #1088.

    Verifies that the EAGLE training forward+backward pass does not crash with
    ``eagle_mix_hidden_states=True`` due to an in-place tensor modification
    breaking autograd.
    """
    # Tiny 8-layer model so the EAGLE3 config's intermediate-layer hooks have
    # layers to attach to while keeping the test fast.
    model = get_tiny_llama(num_hidden_layers=8)

    # Deep-copy so the shared EAGLE3_DEFAULT_CFG fixture is not mutated
    # across parametrized runs.
    config = deepcopy(eagle_config["config"])
    config["eagle_architecture_config"].update(
        {
            "draft_vocab_size": model.config.vocab_size,
            "hidden_size": model.config.hidden_size,
        }
    )
    # eagle_mix_hidden_states=True is the code path that performed the
    # in-place indexed assignment on a graph tensor (the bug under test).
    config["eagle_mix_hidden_states"] = True
    config["eagle_ttt_steps"] = eagle_ttt_steps
    config["eagle_use_torch_compile"] = False

    mtsp.convert(model, mode=[("eagle", config)])
    model.train()

    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    labels = input_ids.clone()

    # Before the fix, backward() raised RuntimeError here due to the
    # in-place modification of eagle_input_hiddens.
    outputs = model(input_ids=input_ids, labels=labels)
    assert outputs.loss is not None
    outputs.loss.backward()

    # Sanity check: gradients actually reach the draft (eagle) module.
    eagle_grads = [p.grad for p in model.eagle_module.parameters() if p.grad is not None]
    assert len(eagle_grads) > 0, "Expected gradients to flow to eagle_module"

0 commit comments

Comments
 (0)