Skip to content

Commit f488231

Browse files
committed
reorg files
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent: 26ae8da · commit: f488231

13 files changed

Lines changed: 777 additions & 750 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ repos:
9999
modelopt/torch/quantization/plugins/attention.py|
100100
modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
101101
modelopt/torch/speculative/eagle/utils.py|
102-
modelopt/torch/speculative/plugins/transformers.py|
102+
modelopt/torch/speculative/plugins/hf_medusa.py|
103103
modelopt/torch/utils/plugins/megatron_mmlu.py|
104104
examples/chained_optimizations/bert_prune_distill_quantize.py|
105105
examples/deepseek/quantize_to_nvfp4.py|

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def patched_templated_attn(*args, **kwargs):
358358
original_op = args[2]
359359

360360
# This patch is only enabled for eagle model by context manager, not base model.
361-
patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
361+
patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH
362362

363363
if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
364364
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

examples/speculative_decoding/scripts/ar_validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from transformers import AutoTokenizer
2828

2929
import modelopt.torch.opt as mto
30-
from modelopt.torch.speculative.plugins.transformers import HFARValidation
30+
from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
3131
from modelopt.torch.speculative.utils import load_vlm_or_llm
3232

3333
mto.enable_huggingface_checkpointing()

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def _export_config(self):
171171
template_config = deepcopy(template_config)
172172

173173
def _get_config_from_draft_or_base(key: str, model: nn.Module):
174-
if getattr(model._draft_model_config, key, None) is not None:
175-
return getattr(model._draft_model_config, key)
174+
if getattr(model.eagle_config, key, None) is not None:
175+
return getattr(model.eagle_config, key)
176176
elif getattr(model.config, key, None) is not None:
177177
return getattr(model.config, key)
178178
else:

modelopt/torch/speculative/eagle/default_config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
"use_aux_hidden_state": False,
3838
"eagle_aux_hidden_state_layer_ids": [],
3939
"use_mtp_layernorm": False,
40-
"parallel_draft_step": 1,
41-
"parallel_draft_heads_num_layers": 1,
4240
"has_lm_head": False,
4341
"head_dim": 128,
4442
}
@@ -107,7 +105,5 @@
107105
"use_aux_hidden_state": True,
108106
"eagle_aux_hidden_state_layer_ids": [],
109107
"use_mtp_layernorm": False,
110-
"parallel_draft_step": 1,
111-
"parallel_draft_heads_num_layers": 1,
112108
"has_lm_head": False,
113109
}

modelopt/torch/speculative/plugins/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Please check out the source code of this module for examples of how plugins work and how you can
1919
write your own one. Currently, we support plugins for
2020
21-
- :meth:`transformers<modelopt.torch.speculative.plugins.transformers>`
21+
- :meth:`hf_eagle<modelopt.torch.speculative.plugins.hf_eagle>`
2222
"""
2323

2424
from modelopt.torch.utils import import_plugin
@@ -31,4 +31,5 @@
3131

3232
with import_plugin("transformers"):
3333
from .hf_dflash import *
34-
from .transformers import *
34+
from .hf_eagle import *
35+
from .hf_medusa import *

0 commit comments

Comments (0)