Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
HEAD_NODE_IP="${1#*=}"
;;
--mix_hidden_states*)
if [[ "$1" != *=* ]]; then shift; fi
MIX_HIDDEN_STATES="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand Down Expand Up @@ -149,6 +153,7 @@ CP_SIZE=${CP_SIZE:-1}
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
LOG_STEPS=${LOG_STEPS:-100}
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}


if [[ "$MODE" == "eagle3" ]]; then
Expand Down Expand Up @@ -234,6 +239,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--disable_tqdm $DISABLE_TQDM \
--estimate_ar $ESTIMATE_AR \
--ar_validate_steps $AR_VALIDATE_STEPS \
--mix_hidden_states $MIX_HIDDEN_STATES \
$DRAFT_VOCAB_CACHE_ARGS \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
Expand Down
10 changes: 6 additions & 4 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
make_eagle_supervised_data_module,
patch_ring_attention_for_ttt,
)
from medusa_utils import make_medusa_supervised_data_module
from transformers.trainer_utils import get_last_checkpoint

import modelopt.torch.opt as mto
Expand Down Expand Up @@ -127,6 +126,10 @@ class EagleArguments:
default="llama",
metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
)
mix_hidden_states: bool = field(
default=False,
metadata={"help": "Whether to mix hidden states from previous TTT step."},
)


def train():
Expand Down Expand Up @@ -204,6 +207,7 @@ def train():
config = {
"eagle_decoder_type": eagle_args.eagle_decoder_type,
"eagle_offline": use_offline_training,
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
"eagle_architecture_config": custom_config,
}

Expand All @@ -221,9 +225,7 @@ def train():
raise Exception(f"{training_args.mode} is not supported!")

print_rank_0("Loading dataset...")
if training_args.mode == "medusa":
data_module = make_medusa_supervised_data_module(tokenizer, data_args)
elif training_args.mode == "eagle3":
if training_args.mode == "eagle3":
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
)
Comment on lines +228 to 231
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

data_module can be undefined for non-eagle3 modes.

TrainingArguments.mode still allows "medusa", but only the eagle3 branch initializes data_module, which can crash at trainer construction.

💡 Proposed fix
-    if training_args.mode == "eagle3":
-        data_module = make_eagle_supervised_data_module(
-            tokenizer, data_args, train_len=training_args.training_seq_len
-        )
+    if training_args.mode != "eagle3":
+        raise ValueError(f"{training_args.mode} is not supported!")
+    data_module = make_eagle_supervised_data_module(
+        tokenizer, data_args, train_len=training_args.training_seq_len
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if training_args.mode == "eagle3":
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
)
if training_args.mode != "eagle3":
raise ValueError(f"{training_args.mode} is not supported!")
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 228 - 231, The code only
initializes data_module inside the if training_args.mode == "eagle3" branch (via
make_eagle_supervised_data_module), leaving data_module undefined for other
modes and causing a crash at Trainer construction; initialize data_module before
the branch (e.g., data_module = None) and either (A) add an else branch that
constructs the appropriate data module for other modes (e.g., a
make_medusa_supervised_data_module call) or (B) ensure the Trainer is only
passed data_module when it is not None (guard the trainer construction or pass a
fallback) so training_args.mode and
make_eagle_supervised_data_module/data_module usage are consistent and never
leave data_module undefined.

Expand Down
17 changes: 11 additions & 6 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@
eagle3_default_config.update({"use_aux_hidden_state": True, "use_last_layernorm": True})
eagle_mtp_default_config.update({"use_last_layernorm": True, "use_mtp_layernorm": True})

EAGLE1_DEFAULT_CFG = {
"algorithm": "eagle",
"config": {
"eagle_architecture_config": deepcopy(default_eagle_config),
},
}

EAGLE3_DEFAULT_CFG = {
"algorithm": "eagle",
Expand Down Expand Up @@ -105,3 +99,14 @@ class EagleConfig(ModeloptBaseConfig):
default="llama",
description=("The class of eagle decoder to use. Available options: llama, kimik2"),
)

eagle_ttt_steps: int = ModeloptField(
default=4, description=("The number of train-time-test steps in training.")
)
Comment on lines +103 to +105
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add a lower-bound validation for eagle_ttt_steps.

At Line 103, eagle_ttt_steps accepts non-positive values. That can bypass the TTT loop and trigger downstream training-time loss assertions. Constrain it at config level.

💡 Proposed fix
     eagle_ttt_steps: int = ModeloptField(
-        default=4, description=("The number of train-time-test steps in training.")
+        default=4,
+        ge=1,
+        description=("The number of train-time-test steps in training."),
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
eagle_ttt_steps: int = ModeloptField(
default=4, description=("The number of train-time-test steps in training.")
)
eagle_ttt_steps: int = ModeloptField(
default=4,
ge=1,
description=("The number of train-time-test steps in training."),
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/config.py` around lines 103 - 105, eagle_ttt_steps
currently allows non-positive integers which can skip the TTT loop; update its
ModeloptField declaration to enforce a lower bound of 1 (e.g., set a validation
constraint such as ge=1 or min=1 depending on ModeloptField API) so only
positive integers are accepted; if ModeloptField doesn't support a direct
constraint, add a validation step for the eagle_ttt_steps attribute (a
pydantic/field validator in the enclosing config class) that raises a clear
error when eagle_ttt_steps < 1.


eagle_mix_hidden_states: bool = ModeloptField(
default=False,
description=(
"Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost."
),
)
12 changes: 1 addition & 11 deletions modelopt/torch/speculative/eagle/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
config.eagle_architecture_config = {**default_arch_config, **custom_config}

eagle_model = EagleDMRegistry.convert(model)
eagle_model.modify(
eagle_offline=config.eagle_offline,
eagle_hidden_state_distillation=config.eagle_hidden_state_distillation,
eagle_self_logit_distillation=config.eagle_self_logit_distillation,
eagle_freeze_base_model=config.eagle_freeze_base_model,
eagle_report_acc=config.eagle_report_acc,
eagle_reuse_base_decoder=config.eagle_reuse_base_decoder,
eagle_loss_decay_factor=config.eagle_loss_decay_factor,
eagle_architecture_config=config.eagle_architecture_config,
eagle_decoder_type=config.eagle_decoder_type,
)
eagle_model.modify(config)

# no metadata, all specified via config.
metadata = {}
Expand Down
28 changes: 11 additions & 17 deletions modelopt/torch/speculative/eagle/eagle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,16 @@ def _setup(self):

def modify(
self,
eagle_offline,
eagle_hidden_state_distillation,
eagle_self_logit_distillation,
eagle_freeze_base_model,
eagle_report_acc,
eagle_reuse_base_decoder,
eagle_loss_decay_factor,
eagle_architecture_config,
eagle_decoder_type,
config,
):
"""Base Eagle Model modify function. Child class should implement the details."""
self.eagle_offline = eagle_offline
self.eagle_hidden_state_distillation = eagle_hidden_state_distillation
self.eagle_self_logit_distillation = eagle_self_logit_distillation
self.eagle_freeze_base_model = eagle_freeze_base_model
self.eagle_report_acc = eagle_report_acc
self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
self.eagle_loss_decay_factor = eagle_loss_decay_factor
self.eagle_decoder_type = eagle_decoder_type
self.eagle_offline = config.eagle_offline
self.eagle_hidden_state_distillation = config.eagle_hidden_state_distillation
self.eagle_self_logit_distillation = config.eagle_self_logit_distillation
self.eagle_freeze_base_model = config.eagle_freeze_base_model
self.eagle_report_acc = config.eagle_report_acc
self.eagle_reuse_base_decoder = config.eagle_reuse_base_decoder
self.eagle_loss_decay_factor = config.eagle_loss_decay_factor
self.eagle_decoder_type = config.eagle_decoder_type
self.eagle_ttt_steps = config.eagle_ttt_steps
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
Loading
Loading