
Commit 0c19663

refactor conversion API

Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent: fb7c8a4

4 files changed

Lines changed: 24 additions & 90 deletions


modelopt/torch/speculative/eagle/conversion.py

Lines changed: 1 addition & 13 deletions
@@ -48,19 +48,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
     config.eagle_architecture_config = {**default_arch_config, **custom_config}
 
     eagle_model = EagleDMRegistry.convert(model)
-    eagle_model.modify(
-        eagle_offline=config.eagle_offline,
-        eagle_hidden_state_distillation=config.eagle_hidden_state_distillation,
-        eagle_self_logit_distillation=config.eagle_self_logit_distillation,
-        eagle_freeze_base_model=config.eagle_freeze_base_model,
-        eagle_report_acc=config.eagle_report_acc,
-        eagle_reuse_base_decoder=config.eagle_reuse_base_decoder,
-        eagle_loss_decay_factor=config.eagle_loss_decay_factor,
-        eagle_architecture_config=config.eagle_architecture_config,
-        eagle_decoder_type=config.eagle_decoder_type,
-        eagle_ttt_steps=config.eagle_ttt_steps,
-        eagle_mix_hidden_states=config.eagle_mix_hidden_states,
-    )
+    eagle_model.modify(config)
 
     # no metadata, all specified via config.
     metadata = {}
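
This hunk collapses the eleven-keyword call into a single config argument. Below is a minimal, self-contained sketch of the resulting call pattern; the classes are illustrative stand-ins (only a subset of the EagleConfig fields from the diff is shown), not the real ModelOpt implementations:

from dataclasses import dataclass, field


@dataclass
class EagleConfig:
    # Field names taken from the diff; the defaults here are illustrative only.
    eagle_offline: bool = False
    eagle_self_logit_distillation: bool = True
    eagle_architecture_config: dict = field(default_factory=dict)


class EagleModel:
    def modify(self, config: EagleConfig) -> None:
        # After the refactor: one config object instead of eleven kwargs.
        self.eagle_offline = config.eagle_offline


def convert_to_eagle_model(model: EagleModel, config: EagleConfig):
    model.modify(config)  # was: model.modify(eagle_offline=..., eagle_ttt_steps=..., ...)
    return model, {}      # no metadata, all specified via config


eagle_model, metadata = convert_to_eagle_model(EagleModel(), EagleConfig())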

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 11 additions & 21 deletions
@@ -26,26 +26,16 @@ def _setup(self):
 
     def modify(
         self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-        eagle_decoder_type,
-        eagle_ttt_steps,
-        eagle_mix_hidden_states,
+        config,
     ):
         """Base Eagle Model modify function. Child class should implement the details."""
-        self.eagle_offline = eagle_offline
-        self.eagle_hidden_state_distillation = eagle_hidden_state_distillation
-        self.eagle_self_logit_distillation = eagle_self_logit_distillation
-        self.eagle_freeze_base_model = eagle_freeze_base_model
-        self.eagle_report_acc = eagle_report_acc
-        self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
-        self.eagle_loss_decay_factor = eagle_loss_decay_factor
-        self.eagle_decoder_type = eagle_decoder_type
-        self.eagle_ttt_steps = eagle_ttt_steps
-        self.eagle_mix_hidden_states = eagle_mix_hidden_states
+        self.eagle_offline = config.eagle_offline
+        self.eagle_hidden_state_distillation = config.eagle_hidden_state_distillation
+        self.eagle_self_logit_distillation = config.eagle_self_logit_distillation
+        self.eagle_freeze_base_model = config.eagle_freeze_base_model
+        self.eagle_report_acc = config.eagle_report_acc
+        self.eagle_reuse_base_decoder = config.eagle_reuse_base_decoder
+        self.eagle_loss_decay_factor = config.eagle_loss_decay_factor
+        self.eagle_decoder_type = config.eagle_decoder_type
+        self.eagle_ttt_steps = config.eagle_ttt_steps
+        self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
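
The base class now copies each field off the config onto the instance, so downstream code and subclasses read `self.eagle_*` attributes rather than threading individual arguments. Note that `eagle_architecture_config` is not copied; the plugins read it from the config directly. For contrast, here is a compact but equivalent way to express the same copy (a sketch only; the repo keeps the explicit assignments, and the field list comes from the diff):

_EAGLE_FIELDS = (
    "eagle_offline",
    "eagle_hidden_state_distillation",
    "eagle_self_logit_distillation",
    "eagle_freeze_base_model",
    "eagle_report_acc",
    "eagle_reuse_base_decoder",
    "eagle_loss_decay_factor",
    "eagle_decoder_type",
    "eagle_ttt_steps",
    "eagle_mix_hidden_states",
)


class EagleModel:
    def modify(self, config):
        """Copy config fields onto the instance; child classes add their own setup."""
        for name in _EAGLE_FIELDS:
            setattr(self, name, getattr(config, name))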

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 4 additions & 26 deletions
@@ -682,17 +682,7 @@ def _setup(self):
 
     def modify(
         self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-        eagle_decoder_type,
-        eagle_ttt_steps,
-        eagle_mix_hidden_states,
+        config,
     ):
         if self.config.pipeline_model_parallel_size > 1:
             warnings.warn(
@@ -705,26 +695,14 @@ def modify(
         if hasattr(self.config, "hetereogenous_dist_checkpoint"):
            self.config.hetereogenous_dist_checkpoint = True
 
-        super().modify(
-            eagle_offline=eagle_offline,
-            eagle_hidden_state_distillation=eagle_hidden_state_distillation,
-            eagle_self_logit_distillation=eagle_self_logit_distillation,
-            eagle_freeze_base_model=eagle_freeze_base_model,
-            eagle_report_acc=eagle_report_acc,
-            eagle_reuse_base_decoder=eagle_reuse_base_decoder,
-            eagle_loss_decay_factor=eagle_loss_decay_factor,
-            eagle_architecture_config=eagle_architecture_config,
-            eagle_decoder_type=eagle_decoder_type,
-            eagle_ttt_steps=eagle_ttt_steps,
-            eagle_mix_hidden_states=eagle_mix_hidden_states,
-        )
+        super().modify(config)
 
         # sequence_parallel is not used in offline eagle
         if self.eagle_offline:
             self.config.sequence_parallel = False
 
         self.eagle_config = dict_to_config(
-            eagle_architecture_config,
+            config.eagle_architecture_config,
             self.config.use_cpu_initialization,
             self.config.fp16,
             self.config.bf16,
@@ -740,7 +718,7 @@ def modify(
         )
 
         if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-            assert eagle_self_logit_distillation, (
+            assert self.eagle_self_logit_distillation, (
                 "Only logit distillation is supported when draft_vocab_size != vocab_size!"
             )
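
The subclass side of the same contract, as a stand-in sketch: the override forwards the config to the base class and then reads the copied attributes back, which is why the assert changes from the local `eagle_self_logit_distillation` to `self.eagle_self_logit_distillation`. These classes are illustrative, not the Megatron plugin itself:

class EagleModelBase:
    def modify(self, config):
        self.eagle_offline = config.eagle_offline
        self.eagle_self_logit_distillation = config.eagle_self_logit_distillation


class MegatronEagleModel(EagleModelBase):
    def modify(self, config):
        super().modify(config)  # was: super().modify(eagle_offline=..., ...)
        # Attributes set by the base class are now the single source of truth.
        if self.eagle_offline:
            self.sequence_parallel = False  # stand-in for self.config.sequence_parallel
        # Stand-in for the draft-vocab check; the real code reads self.eagle_config.
        if getattr(config, "draft_vocab_size", 0) != getattr(config, "vocab_size", 0):
            assert self.eagle_self_logit_distillation, (
                "Only logit distillation is supported when draft_vocab_size != vocab_size!"
            )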

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 8 additions & 30 deletions
@@ -532,45 +532,23 @@ def _get_eagle_device(self):
 
     def modify(
         self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-        eagle_decoder_type,
-        eagle_ttt_steps,
-        eagle_mix_hidden_states,
+        config,
     ):
         """Constructor.
 
         Args:
             config: The config for eagle decoder layers.
         """
-        super().modify(
-            eagle_offline=eagle_offline,
-            eagle_hidden_state_distillation=eagle_hidden_state_distillation,
-            eagle_self_logit_distillation=eagle_self_logit_distillation,
-            eagle_freeze_base_model=eagle_freeze_base_model,
-            eagle_report_acc=eagle_report_acc,
-            eagle_reuse_base_decoder=eagle_reuse_base_decoder,
-            eagle_loss_decay_factor=eagle_loss_decay_factor,
-            eagle_architecture_config=eagle_architecture_config,
-            eagle_decoder_type=eagle_decoder_type,
-            eagle_ttt_steps=eagle_ttt_steps,
-            eagle_mix_hidden_states=eagle_mix_hidden_states,
-        )
+        super().modify(config)
 
-        if eagle_decoder_type == "llama":
+        if self.eagle_decoder_type == "llama":
             # Use default eagle config
             decoder_cls = LlamaDecoderLayer
-        elif eagle_decoder_type == "kimik2":
+        elif self.eagle_decoder_type == "kimik2":
             decoder_cls = _setup_kimi_k2_decoder()
 
-        self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
-        self.eagle_config.eagle_decoder_type = eagle_decoder_type
+        self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config)
+        self.eagle_config.eagle_decoder_type = self.eagle_decoder_type
         # Hidden size and vocab size must match base model
         self.eagle_config.hidden_size = self._base_llm_config.hidden_size
         self.eagle_config.vocab_size = self._base_llm_config.vocab_size
@@ -609,14 +587,14 @@ def modify(
         self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device())
 
         # EAGLE-3 auxiliary hidden_states
-        if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
+        if (not self.eagle_offline) and self.eagle_config.use_aux_hidden_state:
             self._aux_hidden_states = []
             for layer_idx, layer in enumerate(self._base_model.layers):
                 if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                     layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)
 
         # delete base model layers for offline training
-        if eagle_offline:
+        if self.eagle_offline:
             self._base_model._modules.pop("layers")
 
         # NOTE: this is a temporary hack to bypass hf trainer check:
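
One consequence visible in this hunk: the decoder dispatch now keys off the attribute populated by `super().modify(config)`. A small stand-in sketch of that dispatch follows (the real code selects `LlamaDecoderLayer` or the Kimi-K2 decoder; the `else` branch here is added for the sketch and is not in the diff):

class LlamaDecoderStandIn: ...
class KimiK2DecoderStandIn: ...


class HFEagleModelSketch:
    def modify(self, config):
        # In the real code this assignment happens inside super().modify(config).
        self.eagle_decoder_type = config.eagle_decoder_type
        if self.eagle_decoder_type == "llama":
            decoder_cls = LlamaDecoderStandIn
        elif self.eagle_decoder_type == "kimik2":
            decoder_cls = KimiK2DecoderStandIn
        else:
            raise ValueError(f"unknown eagle_decoder_type: {self.eagle_decoder_type}")
        self.decoder_cls = decoder_cls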
