diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 079c40da71c..0ffe17486d0 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -118,6 +118,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi MIX_HIDDEN_STATES="${1#*=}" ;; + --disable_torch_compile*) + if [[ "$1" != *=* ]]; then shift; fi + DISABLE_TORCH_COMPILE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -158,6 +162,7 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} +DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} @@ -245,6 +250,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ --mix_hidden_states $MIX_HIDDEN_STATES \ + --disable_torch_compile $DISABLE_TORCH_COMPILE \ $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index cd7ef347588..0db3867ccba 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -130,6 +130,10 @@ class EagleArguments: default=False, metadata={"help": "Whether to mix hidden states from previous TTT step."}, ) + disable_torch_compile: bool = field( + default=False, + metadata={"help": "Disable torch.compile on eagle forward/loss methods."}, + ) num_ttt_steps: int = field( default=3, metadata={"help": "Number of train-time-test steps to use during training."}, @@ -149,9 +153,10 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) - training_args.parallelism_config = ParallelismConfig( - cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size - ) + if training_args.cp_size > 1 or training_args.dp_shard_size > 1: + training_args.parallelism_config = ParallelismConfig( + cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size + ) if training_args.cp_size > 1: patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 @@ -212,6 +217,7 @@ def train(): "eagle_decoder_type": eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, "eagle_mix_hidden_states": eagle_args.mix_hidden_states, + "eagle_use_torch_compile": not eagle_args.disable_torch_compile, "eagle_ttt_steps": eagle_args.num_ttt_steps, "eagle_architecture_config": custom_config, } diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 8172bc7bf86..69491c65994 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -110,3 +110,13 @@ class EagleConfig(ModeloptBaseConfig): "Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost." ), ) + + eagle_use_torch_compile: bool = ModeloptField( + default=True, + description="Whether to use torch.compile on eagle forward/loss methods for faster training.", + ) + + eagle_enable_nvtx: bool = ModeloptField( + default=False, + description="Whether to enable NVTX ranges for profiling eagle forward/loss methods.", + ) diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index 85251c86a21..e2a08c5252a 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -39,3 +39,5 @@ def modify( self.eagle_decoder_type = config.eagle_decoder_type self.eagle_ttt_steps = config.eagle_ttt_steps self.eagle_mix_hidden_states = config.eagle_mix_hidden_states + self.eagle_use_torch_compile = config.eagle_use_torch_compile + self.eagle_enable_nvtx = config.eagle_enable_nvtx diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1b85c342e75..37517d768d4 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -292,6 +292,10 @@ def __init__(self, config, decoder_layer_cls, bias=False): num_layers=self.config.parallel_draft_heads_num_layers, ) + def _maybe_init_rope(self): + if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"): + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + def _expand_first_attn_in_dim(self, first_layer_attn): """Modify qkv projection in first layer to accept 2h hidden size.""" # Find Linear modules to expand @@ -372,11 +376,6 @@ def forward( self._input_embeds = self.layers[0].input_layernorm(inputs_embeds) if self.config.eagle_decoder_type == "llama": - # Lazy init rope to avoid save/load meta tensor error - if not hasattr(self, "rotary_emb"): - self.rotary_emb = LlamaRotaryEmbedding( - config=self.config, device=hidden_states.device - ) position_embeddings = self.rotary_emb(hidden_states, position_ids) else: position_embeddings = None @@ -454,6 +453,23 @@ def _draft_model_config(self): """Return the llm config for the draft model.""" return self.eagle_config + def _enable_cp_ttt(self): + if self.training and not self.eagle_mix_hidden_states: + return enable_cp_ttt_patch() + return contextlib.nullcontext() + + def _nvtx_range(self, name): + """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set.""" + if not self.eagle_enable_nvtx: + return contextlib.nullcontext() + try: + import torch.cuda.nvtx as nvtx + + return nvtx.range(name) + except Exception as e: + print(f"Failed to create NVTX range {name}: {e}") + return contextlib.nullcontext() + def get_exporter(self) -> SpeculativeDecodingExporter: """Get the exporter for the draft model.""" exporter_cls = ( @@ -618,8 +634,27 @@ def modify( # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566 self.is_quantized = False + if self.eagle_use_torch_compile: + self._activate_torch_compile() + self._cached_attn_blk_masks = {} + def _activate_torch_compile(self): + import torch._dynamo + + torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode + + compile_targets = [ + ("_prepare_eagle_inputs", {}), + ("_eagle_forward", {"mode": "max-autotune"}), + ("_eagle_loss", {"fullgraph": True}), + ] + for name, kwargs in compile_targets: + try: + setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) + except Exception: # noqa: PERF203 + print(f"Disabling torch.compile for {name} due to compilation error.") + def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): # compile and cached flex attention masks in first call if ttt_step not in self._cached_attn_blk_masks: @@ -716,7 +751,20 @@ def _prepare_eagle_inputs( else: eagle_position_ids = position_ids.view(-1, seq_length).long() - return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids + base_model_logits = base_outputs.logits + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) + base_output_predict_tok = base_model_logits.argmax(dim=-1).detach() + base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach() + + return ( + eagle_input_embeds, + eagle_input_hiddens, + eagle_attention_mask, + eagle_position_ids, + base_output_predict_tok, + base_output_softmax_logits, + ) def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step @@ -872,15 +920,16 @@ def forward( base_outputs.logits = self.lm_head(base_outputs.out_hiddens) past_key_values = None else: - base_outputs, past_key_values = self._base_model_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - self.eagle_freeze_base_model, - labels, - **kwargs, - ) + with self._nvtx_range("base_model_forward"): + base_outputs, past_key_values = self._base_model_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + self.eagle_freeze_base_model, + labels, + **kwargs, + ) if not isinstance(past_key_values, Cache): past_key_values = _get_empty_cache(self._base_llm_config) @@ -890,20 +939,27 @@ def forward( # ====Prepare inputs for the first eagle forward pass==== eagle_loss = None - train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)] + num_parallel = self.eagle_config.parallel_draft_step + num_ttt = self.eagle_ttt_steps + train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device) b, seq_length, _ = base_outputs.out_hiddens.shape - ( - eagle_input_embeds, - eagle_input_hiddens, - eagle_attn_mask_0, - eagle_position_ids, - ) = self._prepare_eagle_inputs( - input_ids, - attention_mask, - position_ids, - eagle_cache, - base_outputs, - ) + with self._nvtx_range("prepare_eagle_inputs"): + ( + eagle_input_embeds, + eagle_input_hiddens, + eagle_attn_mask_0, + eagle_position_ids, + base_output_predict_tok, + base_output_softmax_logits, + ) = self._prepare_eagle_inputs( + input_ids, + attention_mask, + position_ids, + eagle_cache, + base_outputs, + ) + + self.eagle_module._maybe_init_rope() # ====Run eagle forward with extra training-time-test steps==== for ttt_step in range(self.eagle_ttt_steps): @@ -913,11 +969,7 @@ def forward( if self.eagle_mix_hidden_states or ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) - with ( - enable_cp_ttt_patch() - if self.training and not self.eagle_mix_hidden_states - else contextlib.nullcontext() - ): + with self._enable_cp_ttt(), self._nvtx_range("eagle_forward"): _, eagle_output_hiddens, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hiddens, eagle_input_embeds, @@ -945,23 +997,28 @@ def forward( for i in range(self.eagle_config.parallel_draft_step): eagle_logit = eagle_logits[i] - classification_loss, acc = self._eagle_loss( - # base model predict +1 tok, while eagle predict +2 - # so we shift base model outputs compared to eagle outputs - # additionally, we mask the first n tok of eagle outputs at nth TTT step - base_outputs.logits[:, 1 + i + ttt_step :], - eagle_logit[:, ttt_step : -(1 + i)], - loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], - ) + with self._nvtx_range("eagle_loss"): + classification_loss, acc = self._eagle_loss( + # base model predict +1 tok, while eagle predict +2 + # so we shift base model outputs compared to eagle outputs + # additionally, we mask the first n tok of eagle outputs at nth TTT step + base_output_softmax_logits[:, 1 + i + ttt_step :], + base_output_predict_tok[:, 1 + i + ttt_step :], + eagle_logit[:, ttt_step : -(1 + i)], + loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], + ) # Apply loss decay factor to focus on early steps classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i) eagle_loss = ( classification_loss if eagle_loss is None else eagle_loss + classification_loss ) - train_accs[i].append(acc) + train_accs[i, ttt_step] = acc if not self.training: break + # Slice by actual number of steps taken, in case of early return + train_accs = train_accs[:, : ttt_step + 1].tolist() + # Merge base model loss and eagle loss if base_outputs.loss is None and eagle_loss is None: loss = None @@ -979,27 +1036,23 @@ def forward( def _eagle_loss( self, - base_model_logits, + base_output_softmax_logits, + base_output_predict_tok, eagle_logits, loss_mask, ): """Function for EAGLE loss computing.""" - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) loss_mask = loss_mask[:, : eagle_logits.shape[1], None] - classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( - eagle_logits - ) - classification_loss = -torch.sum(torch.sum(loss_mask * classification_loss, 2)) / ( - loss_mask.sum() + 1e-5 - ) - # Compute accuracy - base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1) - eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1) + eagle_logsoft = torch.log_softmax(eagle_logits, dim=2) + classification_loss = -torch.sum( + torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2) + ) / (loss_mask.sum() + 1e-5) + # Compute accuracy (returned as tensor to avoid sync; .item() called after TTT loop) + eagle_predict_tok = eagle_logits.detach().argmax(dim=-1) valid = loss_mask[:, :, 0].bool() - correct = (base_predict_tok == eagle_predict_tok) & valid + correct = (base_output_predict_tok == eagle_predict_tok) & valid denom = valid.sum().clamp_min(1).float() - accuracy = round(correct.sum().float().div(denom).item(), 3) + accuracy = correct.sum().float() / denom return classification_loss, accuracy @@ -1039,6 +1092,7 @@ def pseudo_speculative_generate( else: eagle_input_hidden_states = base_model_hidden_states + self.eagle_module._maybe_init_rope() draft_tokens = [] for step in range(steps): b, seq_length = eagle_ids.shape @@ -1051,7 +1105,10 @@ def pseudo_speculative_generate( ) # Use SDPA attention during generation for both stability and performance - with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"): + with ( + temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"), + self._nvtx_range("eagle_forward"), + ): _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward( eagle_input_hidden_states, self._base_model_embeddings(eagle_ids),