Skip to content

Commit c76633a

Browse files
authored
[EAGLE] Configurable number of TTT steps (#1042)
### What does this PR do? Type of change: new CLI option for existing option <!-- Details about the change. --> - Added num_ttt_steps CLI flag - Changed num_ttt_steps default from 4 to 3 for consistency. Num_spec_tokens == 3 or == 7 are most common in practice, so rounding down to 3 and allowing users to increment higher on-demand. Will also improve training efficiency for the OOTB experience. ### Usage Users can now pass `--num_ttt_steps 7` to `launch_train.sh` when training an EAGLE3 model for extended speculation lengths. ### Testing N/A ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: N/A - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added ability to configure train-time-test steps for speculative decoding training via command-line argument. * Updated default train-time-test steps value from 4 to 3. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent 4292505 commit c76633a

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

examples/speculative_decoding/launch_train.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ while [ $# -gt 0 ]; do
8686
if [[ "$1" != *=* ]]; then shift; fi
8787
AR_VALIDATE_STEPS="${1#*=}"
8888
;;
89+
--num_ttt_steps*)
90+
if [[ "$1" != *=* ]]; then shift; fi
91+
NUM_TTT_STEPS="${1#*=}"
92+
;;
8993
--cp_size*)
9094
if [[ "$1" != *=* ]]; then shift; fi
9195
CP_SIZE="${1#*=}"
@@ -154,6 +158,7 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
154158
LOG_STEPS=${LOG_STEPS:-100}
155159
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
156160
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
161+
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}
157162

158163

159164
if [[ "$MODE" == "eagle3" ]]; then
@@ -247,6 +252,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
247252
$FSDP_ARGS \
248253
--cp_size $CP_SIZE \
249254
--dp_shard_size $DP_SHARD_SIZE \
255+
--num_ttt_steps $NUM_TTT_STEPS \
250256
"
251257

252258
start_time=$(date +%s)

examples/speculative_decoding/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ class EagleArguments:
130130
default=False,
131131
metadata={"help": "Whether to mix hidden states from previous TTT step."},
132132
)
133+
num_ttt_steps: int = field(
134+
default=3,
135+
metadata={"help": "Number of train-time-test steps to use during training."},
136+
)
133137

134138

135139
def train():
@@ -208,6 +212,7 @@ def train():
208212
"eagle_decoder_type": eagle_args.eagle_decoder_type,
209213
"eagle_offline": use_offline_training,
210214
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
215+
"eagle_ttt_steps": eagle_args.num_ttt_steps,
211216
"eagle_architecture_config": custom_config,
212217
}
213218

modelopt/torch/speculative/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class EagleConfig(ModeloptBaseConfig):
101101
)
102102

103103
eagle_ttt_steps: int = ModeloptField(
104-
default=4, description=("The number of train-time-test steps in training.")
104+
default=3, description=("The number of train-time-test steps in training.")
105105
)
106106

107107
eagle_mix_hidden_states: bool = ModeloptField(

0 commit comments

Comments
 (0)