
Commit 48cc15f

refactor conversion API
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent ad1d17c

4 files changed

Lines changed: 24 additions & 90 deletions
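
Across all four files the change is the same: modify() loses its eleven eagle_* keyword parameters and receives the config object directly. A schematic before/after of the signature, reconstructed from the hunks below (bodies elided to "..."):

    # Before: every override repeats the same eleven parameters
    def modify(
        self,
        eagle_offline,
        eagle_hidden_state_distillation,
        eagle_self_logit_distillation,
        eagle_freeze_base_model,
        eagle_report_acc,
        eagle_reuse_base_decoder,
        eagle_loss_decay_factor,
        eagle_architecture_config,
        eagle_decoder_type,
        eagle_ttt_steps,
        eagle_mix_hidden_states,
    ):
        ...

    # After: a single config object carries the same fields
    def modify(self, config):
        ...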


modelopt/torch/speculative/eagle/conversion.py

Lines changed: 1 addition & 13 deletions
@@ -48,19 +48,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
     config.eagle_architecture_config = {**default_arch_config, **custom_config}

     eagle_model = EagleDMRegistry.convert(model)
-    eagle_model.modify(
-        eagle_offline=config.eagle_offline,
-        eagle_hidden_state_distillation=config.eagle_hidden_state_distillation,
-        eagle_self_logit_distillation=config.eagle_self_logit_distillation,
-        eagle_freeze_base_model=config.eagle_freeze_base_model,
-        eagle_report_acc=config.eagle_report_acc,
-        eagle_reuse_base_decoder=config.eagle_reuse_base_decoder,
-        eagle_loss_decay_factor=config.eagle_loss_decay_factor,
-        eagle_architecture_config=config.eagle_architecture_config,
-        eagle_decoder_type=config.eagle_decoder_type,
-        eagle_ttt_steps=config.eagle_ttt_steps,
-        eagle_mix_hidden_states=config.eagle_mix_hidden_states,
-    )
+    eagle_model.modify(config)

     # no metadata, all specified via config.
     metadata = {}
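
The practical payoff in this file: the converter no longer enumerates config fields, so a future EagleConfig field reaches modify() without touching conversion.py. A minimal sketch of that shape (the registry stub and the returned tuple are assumptions for illustration; only the two forwarding lines come from the hunk):

    import torch.nn as nn


    class EagleDMRegistry:
        """Stand-in for the real dynamic-module registry, stubbed for this sketch."""

        @staticmethod
        def convert(model: nn.Module) -> nn.Module:
            return model  # the real registry wraps the model with Eagle machinery


    def convert_to_eagle_model(model: nn.Module, config) -> tuple:
        eagle_model = EagleDMRegistry.convert(model)
        # One call forwards every present and future config field.
        eagle_model.modify(config)
        metadata = {}  # no metadata, all specified via config.
        return eagle_model, metadata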

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 11 additions & 21 deletions
@@ -26,26 +26,16 @@ def _setup(self):

     def modify(
         self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-        eagle_decoder_type,
-        eagle_ttt_steps,
-        eagle_mix_hidden_states,
+        config,
     ):
         """Base Eagle Model modify function. Child class should implement the details."""
-        self.eagle_offline = eagle_offline
-        self.eagle_hidden_state_distillation = eagle_hidden_state_distillation
-        self.eagle_self_logit_distillation = eagle_self_logit_distillation
-        self.eagle_freeze_base_model = eagle_freeze_base_model
-        self.eagle_report_acc = eagle_report_acc
-        self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
-        self.eagle_loss_decay_factor = eagle_loss_decay_factor
-        self.eagle_decoder_type = eagle_decoder_type
-        self.eagle_ttt_steps = eagle_ttt_steps
-        self.eagle_mix_hidden_states = eagle_mix_hidden_states
+        self.eagle_offline = config.eagle_offline
+        self.eagle_hidden_state_distillation = config.eagle_hidden_state_distillation
+        self.eagle_self_logit_distillation = config.eagle_self_logit_distillation
+        self.eagle_freeze_base_model = config.eagle_freeze_base_model
+        self.eagle_report_acc = config.eagle_report_acc
+        self.eagle_reuse_base_decoder = config.eagle_reuse_base_decoder
+        self.eagle_loss_decay_factor = config.eagle_loss_decay_factor
+        self.eagle_decoder_type = config.eagle_decoder_type
+        self.eagle_ttt_steps = config.eagle_ttt_steps
+        self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
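
The same contract from the base class's side, as a self-contained, runnable sketch (EagleConfig is trimmed to three of the eleven fields; defaults are illustrative, the copy pattern is from the hunk):

    from dataclasses import dataclass

    import torch.nn as nn


    @dataclass
    class EagleConfig:
        # Three of the eleven fields from the hunk; defaults are illustrative.
        eagle_offline: bool = False
        eagle_self_logit_distillation: bool = True
        eagle_decoder_type: str = "llama"


    class EagleModel(nn.Module):
        def modify(self, config: EagleConfig):
            """Base modify: copy each config field onto the model, as above."""
            self.eagle_offline = config.eagle_offline
            self.eagle_self_logit_distillation = config.eagle_self_logit_distillation
            self.eagle_decoder_type = config.eagle_decoder_type


    model = EagleModel()
    model.modify(EagleConfig(eagle_offline=True))
    assert model.eagle_offline and model.eagle_decoder_type == "llama"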

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 4 additions & 26 deletions
@@ -682,17 +682,7 @@ def _setup(self):

     def modify(
         self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-        eagle_decoder_type,
-        eagle_ttt_steps,
-        eagle_mix_hidden_states,
+        config,
     ):
         if self.config.pipeline_model_parallel_size > 1:
             warnings.warn(

@@ -705,26 +695,14 @@ def modify(
         if hasattr(self.config, "hetereogenous_dist_checkpoint"):
             self.config.hetereogenous_dist_checkpoint = True

-        super().modify(
-            eagle_offline=eagle_offline,
-            eagle_hidden_state_distillation=eagle_hidden_state_distillation,
-            eagle_self_logit_distillation=eagle_self_logit_distillation,
-            eagle_freeze_base_model=eagle_freeze_base_model,
-            eagle_report_acc=eagle_report_acc,
-            eagle_reuse_base_decoder=eagle_reuse_base_decoder,
-            eagle_loss_decay_factor=eagle_loss_decay_factor,
-            eagle_architecture_config=eagle_architecture_config,
-            eagle_decoder_type=eagle_decoder_type,
-            eagle_ttt_steps=eagle_ttt_steps,
-            eagle_mix_hidden_states=eagle_mix_hidden_states,
-        )
+        super().modify(config)

         # sequence_parallel is not used in offline eagle
         if self.eagle_offline:
             self.config.sequence_parallel = False

         self.eagle_config = dict_to_config(
-            eagle_architecture_config,
+            config.eagle_architecture_config,
             self.config.use_cpu_initialization,
             self.config.fp16,
             self.config.bf16,

@@ -740,7 +718,7 @@ def modify(
         )

         if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-            assert eagle_self_logit_distillation, (
+            assert self.eagle_self_logit_distillation, (
                 "Only logit distillation is supported when draft_vocab_size != vocab_size!"
             )
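
Worth noting in the last hunk: once super().modify(config) has run, overrides read the self.eagle_* attributes instead of the old keyword locals, which is why the assert now checks self.eagle_self_logit_distillation. A sketch of the override pattern (PluginEagleModel is hypothetical and reuses the EagleModel/EagleConfig stubs from the previous example):

    class PluginEagleModel(EagleModel):  # hypothetical; reuses the stubs above
        def modify(self, config: EagleConfig):
            super().modify(config)  # populates the self.eagle_* attributes first
            # The old keyword locals no longer exist here, so follow-up logic
            # reads the attributes set by the base class, as the rewritten
            # assert on self.eagle_self_logit_distillation does above.
            if self.eagle_offline:
                pass  # e.g. the megatron plugin disables sequence_parallel here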

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 8 additions & 30 deletions
@@ -549,45 +549,23 @@ def _get_eagle_device(self):

     def modify(
         self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-        eagle_decoder_type,
-        eagle_ttt_steps,
-        eagle_mix_hidden_states,
+        config,
     ):
         """Constructor.

         Args:
             config: The config for eagle decoder layers.
         """
-        super().modify(
-            eagle_offline=eagle_offline,
-            eagle_hidden_state_distillation=eagle_hidden_state_distillation,
-            eagle_self_logit_distillation=eagle_self_logit_distillation,
-            eagle_freeze_base_model=eagle_freeze_base_model,
-            eagle_report_acc=eagle_report_acc,
-            eagle_reuse_base_decoder=eagle_reuse_base_decoder,
-            eagle_loss_decay_factor=eagle_loss_decay_factor,
-            eagle_architecture_config=eagle_architecture_config,
-            eagle_decoder_type=eagle_decoder_type,
-            eagle_ttt_steps=eagle_ttt_steps,
-            eagle_mix_hidden_states=eagle_mix_hidden_states,
-        )
+        super().modify(config)

-        if eagle_decoder_type == "llama":
+        if self.eagle_decoder_type == "llama":
             # Use default eagle config
             decoder_cls = LlamaDecoderLayer
-        elif eagle_decoder_type == "kimik2":
+        elif self.eagle_decoder_type == "kimik2":
             decoder_cls = _setup_kimi_k2_decoder()

-        self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
-        self.eagle_config.eagle_decoder_type = eagle_decoder_type
+        self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config)
+        self.eagle_config.eagle_decoder_type = self.eagle_decoder_type
         # Hidden size and vocab size must match base model
         self.eagle_config.hidden_size = self._base_llm_config.hidden_size
         self.eagle_config.vocab_size = self._base_llm_config.vocab_size

@@ -626,14 +604,14 @@ def modify(
         self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device())

         # EAGLE-3 auxiliary hidden_states
-        if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
+        if (not self.eagle_offline) and self.eagle_config.use_aux_hidden_state:
             self._aux_hidden_states = []
             for layer_idx, layer in enumerate(self._base_model.layers):
                 if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                     layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

         # delete base model layers for offline training
-        if eagle_offline:
+        if self.eagle_offline:
             self._base_model._modules.pop("layers")

         # NOTE: this is a temporary hack to bypass hf trainer check:
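
For context on the hook registration above: a generic, runnable sketch of collecting auxiliary hidden states via PyTorch forward hooks (toy layers and layer indices; the real _collect_aux_hidden_states_forward_hook is not shown in this diff):

    import torch
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
    aux_hidden_states = []  # stand-in for self._aux_hidden_states

    def collect_hook(module, inputs, output):
        # Called after the layer's forward; stash its output for the draft head.
        aux_hidden_states.append(output)

    aux_layer_ids = {1, 3}  # stand-in for eagle_aux_hidden_state_layer_ids
    for layer_idx, layer in enumerate(layers):
        if layer_idx in aux_layer_ids:
            layer.register_forward_hook(collect_hook)

    x = torch.randn(2, 16)
    for layer in layers:
        x = layer(x)
    assert len(aux_hidden_states) == 2  # one tensor per hooked layer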
