Skip to content

Commit 027ee36

Browse files
committed
fix
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent adb9927 commit 027ee36

File tree

5 files changed

+23
-20
lines changed

5 files changed

+23
-20
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -167,21 +167,6 @@ def make_eagle_supervised_data_module(
167167
}
168168

169169

170-
def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs):
171-
"""Load a VLM or LLM with kwargs. Returns the model and model config."""
172-
model_config = transformers.AutoConfig.from_pretrained(
173-
model_name_or_path, trust_remote_code=True
174-
)
175-
if "vl" in model_config.model_type.lower():
176-
model_cls = transformers.AutoModelForVision2Seq
177-
else:
178-
model_cls = transformers.AutoModelForCausalLM
179-
180-
return model_config, model_cls.from_pretrained(
181-
model_name_or_path, trust_remote_code=True, **kwargs
182-
)
183-
184-
185170
class EagleTrainerWithAccLog(Trainer):
186171
"""Wrapper around Trainer that logs training accuracy."""
187172

examples/speculative_decoding/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from eagle_utils import (
4141
EagleTrainerWithAccLog,
4242
EagleTrainingPlot,
43-
load_vlm_or_llm_with_kwargs,
4443
make_eagle_supervised_data_module,
4544
patch_ring_attention_for_ttt,
4645
)
@@ -49,6 +48,7 @@
4948

5049
import modelopt.torch.opt as mto
5150
import modelopt.torch.speculative as mtsp
51+
from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
5252
from modelopt.torch.utils import print_rank_0
5353

5454
torch.manual_seed(0)

examples/speculative_decoding/scripts/export_hf_checkpoint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121

2222
import modelopt.torch.opt as mto
2323
from modelopt.torch.export import export_hf_checkpoint
24-
25-
from ..eagle_utils import load_vlm_or_llm_with_kwargs
24+
from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
2625

2726

2827
def parse_args():

modelopt/torch/speculative/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import torch
2727
import torch.distributed
28+
import transformers
2829
from huggingface_hub import snapshot_download
2930
from torch import nn
3031
from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -471,3 +472,16 @@ def enable_cp_ttt_patch():
471472
yield
472473
finally:
473474
modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
475+
476+
477+
def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs):
478+
"""Load a VLM or LLM with kwargs. Returns the model and model config."""
479+
model_config = transformers.AutoConfig.from_pretrained(
480+
model_name_or_path, trust_remote_code=True
481+
)
482+
if "vl" in model_config.model_type.lower():
483+
model_cls = transformers.AutoModelForVision2Seq
484+
else:
485+
model_cls = transformers.AutoModelForCausalLM
486+
487+
return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs)

tests/examples/speculative_decoding/test_eagle.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import json
17+
import os
1718

1819
import pytest
1920
import safetensors.torch
@@ -54,6 +55,10 @@ def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft
5455
"speculative_decoding",
5556
)
5657

58+
model_name = os.path.basename(os.path.normpath(tiny_llama_path))
59+
d2t = torch.load(os.path.join(draft_vocab_cache_dir, model_name, "d2t.pt"))
60+
assert d2t.shape[0] == 100, f"Expected draft vocab size 100, got {d2t.shape[0]}"
61+
5762

5863
# fmt: off
5964
@pytest.mark.parametrize("cp_size", [1, 2])
@@ -102,8 +107,8 @@ def test_ar_validate(eagle_output_dir):
102107
[
103108
"python", "./scripts/ar_validate.py",
104109
"--model_path", eagle_output_dir / "eagle-tinyllama-cp1",
105-
"--osl", "20",
106-
"--num_samples", "10",
110+
"--osl", "10",
111+
"--num_samples", "5",
107112
"--steps", "3"
108113
],
109114
"speculative_decoding",

0 commit comments

Comments (0)