Skip to content

Commit 5d0e012

Browse files
authored
inplement mix hidden_states for eagle3; deprecate eagle1 (#946)
## What does this PR do? new feature **Overview:** Enable mix hidden_states in eagle3 training. Deprecate eagle1 ## Usage Add --mix_hidden_states True to launch_train.sh ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added --mix_hidden_states option to enable optional hidden-state mixing during training. * Added eagle_ttt_steps setting to control speculative multi-step iterations. * **Chores** * Consolidated speculative decoding to EAGLE3 only; legacy Medusa/EAGLE1 paths removed. * Unified configuration handling so models and plugins accept a single config object. * **Tests** * Updated and expanded tests for hidden-state mixing and EAGLE3-only scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 0ad287c commit 5d0e012

10 files changed

Lines changed: 208 additions & 256 deletions

File tree

examples/speculative_decoding/launch_train.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ while [ $# -gt 0 ]; do
110110
if [[ "$1" != *=* ]]; then shift; fi
111111
HEAD_NODE_IP="${1#*=}"
112112
;;
113+
--mix_hidden_states*)
114+
if [[ "$1" != *=* ]]; then shift; fi
115+
MIX_HIDDEN_STATES="${1#*=}"
116+
;;
113117
*)
114118
>&2 printf "Error: Invalid argument ${1#*=}\n"
115119
exit 1
@@ -149,6 +153,7 @@ CP_SIZE=${CP_SIZE:-1}
149153
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
150154
LOG_STEPS=${LOG_STEPS:-100}
151155
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
156+
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
152157

153158

154159
if [[ "$MODE" == "eagle3" ]]; then
@@ -234,6 +239,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
234239
--disable_tqdm $DISABLE_TQDM \
235240
--estimate_ar $ESTIMATE_AR \
236241
--ar_validate_steps $AR_VALIDATE_STEPS \
242+
--mix_hidden_states $MIX_HIDDEN_STATES \
237243
$DRAFT_VOCAB_CACHE_ARGS \
238244
$VLM_ARGS \
239245
$OFFLINE_TRAINING_ARGS \

examples/speculative_decoding/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
make_eagle_supervised_data_module,
4444
patch_ring_attention_for_ttt,
4545
)
46-
from medusa_utils import make_medusa_supervised_data_module
4746
from transformers.trainer_utils import get_last_checkpoint
4847

4948
import modelopt.torch.opt as mto
@@ -127,6 +126,10 @@ class EagleArguments:
127126
default="llama",
128127
metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
129128
)
129+
mix_hidden_states: bool = field(
130+
default=False,
131+
metadata={"help": "Whether to mix hidden states from previous TTT step."},
132+
)
130133

131134

132135
def train():
@@ -204,6 +207,7 @@ def train():
204207
config = {
205208
"eagle_decoder_type": eagle_args.eagle_decoder_type,
206209
"eagle_offline": use_offline_training,
210+
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
207211
"eagle_architecture_config": custom_config,
208212
}
209213

@@ -221,9 +225,7 @@ def train():
221225
raise Exception(f"{training_args.mode} is not supported!")
222226

223227
print_rank_0("Loading dataset...")
224-
if training_args.mode == "medusa":
225-
data_module = make_medusa_supervised_data_module(tokenizer, data_args)
226-
elif training_args.mode == "eagle3":
228+
if training_args.mode == "eagle3":
227229
data_module = make_eagle_supervised_data_module(
228230
tokenizer, data_args, train_len=training_args.training_seq_len
229231
)

modelopt/torch/speculative/config.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@
2929
eagle3_default_config.update({"use_aux_hidden_state": True, "use_last_layernorm": True})
3030
eagle_mtp_default_config.update({"use_last_layernorm": True, "use_mtp_layernorm": True})
3131

32-
EAGLE1_DEFAULT_CFG = {
33-
"algorithm": "eagle",
34-
"config": {
35-
"eagle_architecture_config": deepcopy(default_eagle_config),
36-
},
37-
}
3832

3933
EAGLE3_DEFAULT_CFG = {
4034
"algorithm": "eagle",
@@ -105,3 +99,14 @@ class EagleConfig(ModeloptBaseConfig):
10599
default="llama",
106100
description=("The class of eagle decoder to use. Available options: llama, kimik2"),
107101
)
102+
103+
eagle_ttt_steps: int = ModeloptField(
104+
default=4, description=("The number of train-time-test steps in training.")
105+
)
106+
107+
eagle_mix_hidden_states: bool = ModeloptField(
108+
default=False,
109+
description=(
110+
"Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost."
111+
),
112+
)

modelopt/torch/speculative/eagle/conversion.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
4848
config.eagle_architecture_config = {**default_arch_config, **custom_config}
4949

5050
eagle_model = EagleDMRegistry.convert(model)
51-
eagle_model.modify(
52-
eagle_offline=config.eagle_offline,
53-
eagle_hidden_state_distillation=config.eagle_hidden_state_distillation,
54-
eagle_self_logit_distillation=config.eagle_self_logit_distillation,
55-
eagle_freeze_base_model=config.eagle_freeze_base_model,
56-
eagle_report_acc=config.eagle_report_acc,
57-
eagle_reuse_base_decoder=config.eagle_reuse_base_decoder,
58-
eagle_loss_decay_factor=config.eagle_loss_decay_factor,
59-
eagle_architecture_config=config.eagle_architecture_config,
60-
eagle_decoder_type=config.eagle_decoder_type,
61-
)
51+
eagle_model.modify(config)
6252

6353
# no metadata, all specified via config.
6454
metadata = {}

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,16 @@ def _setup(self):
2626

2727
def modify(
2828
self,
29-
eagle_offline,
30-
eagle_hidden_state_distillation,
31-
eagle_self_logit_distillation,
32-
eagle_freeze_base_model,
33-
eagle_report_acc,
34-
eagle_reuse_base_decoder,
35-
eagle_loss_decay_factor,
36-
eagle_architecture_config,
37-
eagle_decoder_type,
29+
config,
3830
):
3931
"""Base Eagle Model modify function. Child class should implement the details."""
40-
self.eagle_offline = eagle_offline
41-
self.eagle_hidden_state_distillation = eagle_hidden_state_distillation
42-
self.eagle_self_logit_distillation = eagle_self_logit_distillation
43-
self.eagle_freeze_base_model = eagle_freeze_base_model
44-
self.eagle_report_acc = eagle_report_acc
45-
self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
46-
self.eagle_loss_decay_factor = eagle_loss_decay_factor
47-
self.eagle_decoder_type = eagle_decoder_type
32+
self.eagle_offline = config.eagle_offline
33+
self.eagle_hidden_state_distillation = config.eagle_hidden_state_distillation
34+
self.eagle_self_logit_distillation = config.eagle_self_logit_distillation
35+
self.eagle_freeze_base_model = config.eagle_freeze_base_model
36+
self.eagle_report_acc = config.eagle_report_acc
37+
self.eagle_reuse_base_decoder = config.eagle_reuse_base_decoder
38+
self.eagle_loss_decay_factor = config.eagle_loss_decay_factor
39+
self.eagle_decoder_type = config.eagle_decoder_type
40+
self.eagle_ttt_steps = config.eagle_ttt_steps
41+
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states

0 commit comments

Comments
 (0)