Skip to content

Commit e342300

Browse files
committed
fix: update offline PTQ test to use YAML config CLI
After rebase, launch_train.sh requires --config <yaml> with OmegaConf dotlist overrides instead of individual --flag arguments. Updated test_eagle_offline_ptq.py to match the new interface. Signed-off-by: Ye Yu <yey@nvidia.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent a0f2507 commit e342300

1 file changed

Lines changed: 30 additions & 40 deletions

File tree

tests/examples/speculative_decoding/test_eagle_offline_ptq.py

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,24 @@
2525
by stage 2 is correctly quantized in stage 3.
2626
"""
2727

28-
import json
29-
3028
import pytest
3129
import safetensors.torch
3230
import torch
33-
from _test_utils.examples.run_command import run_example_command
31+
from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command
32+
33+
EAGLE3_YAML = str(
34+
MODELOPT_ROOT / "modelopt_recipes" / "general" / "speculative_decoding" / "eagle3.yaml"
35+
)
36+
37+
# Tiny EAGLE architecture overrides (dotlist entries)
38+
_TINY_EAGLE_ARCH = [
39+
"eagle.eagle_architecture_config.max_position_embeddings=128",
40+
"eagle.eagle_architecture_config.num_hidden_layers=1",
41+
"eagle.eagle_architecture_config.intermediate_size=64",
42+
"eagle.eagle_architecture_config.num_attention_heads=2",
43+
"eagle.eagle_architecture_config.num_key_value_heads=2",
44+
"eagle.eagle_architecture_config.head_dim=64",
45+
]
3446

3547

3648
@pytest.fixture(scope="module")
@@ -72,49 +84,27 @@ def test_collect_hidden_states(tiny_llama_path, tiny_daring_anteater_path, offli
7284

7385
def test_offline_eagle_training(tiny_llama_path, tiny_daring_anteater_path, offline_ptq_dirs):
7486
"""Stage 2: train an EAGLE3 draft model using the offline hidden states."""
75-
tiny_eagle_config = {
76-
"max_position_embeddings": 128,
77-
"num_hidden_layers": 1,
78-
"intermediate_size": 64,
79-
"num_attention_heads": 2,
80-
"num_key_value_heads": 2,
81-
"head_dim": 64,
82-
}
83-
config_file = offline_ptq_dirs["eagle_ckpt"] / "tiny_eagle_config.json"
84-
with open(config_file, "w") as f:
85-
json.dump(tiny_eagle_config, f)
87+
output_dir = offline_ptq_dirs["eagle_ckpt"] / "trained"
88+
89+
overrides = [
90+
f"model.model_name_or_path={tiny_llama_path}",
91+
f"data.data_path={tiny_daring_anteater_path}",
92+
f"data.offline_data_path={offline_ptq_dirs['hidden_states']}",
93+
f"training.output_dir={output_dir}",
94+
"training.num_train_epochs=1",
95+
"training.learning_rate=1e-5",
96+
"training.training_seq_len=64",
97+
"training.save_steps=1",
98+
*_TINY_EAGLE_ARCH,
99+
]
86100

87101
run_example_command(
88-
[
89-
"./launch_train.sh",
90-
"--model",
91-
tiny_llama_path,
92-
"--data",
93-
str(tiny_daring_anteater_path),
94-
"--offline-data",
95-
str(offline_ptq_dirs["hidden_states"]),
96-
"--num_epochs",
97-
"1",
98-
"--lr",
99-
"1e-5",
100-
"--mode",
101-
"eagle3",
102-
"--eagle_config",
103-
str(config_file),
104-
"--output_dir",
105-
str(offline_ptq_dirs["eagle_ckpt"] / "trained"),
106-
"--training_seq_len",
107-
"64",
108-
"--save_steps",
109-
"1",
110-
],
102+
["./launch_train.sh", "--config", EAGLE3_YAML, *overrides],
111103
"speculative_decoding",
112104
setup_free_port=True,
113105
)
114106

115-
assert (offline_ptq_dirs["eagle_ckpt"] / "trained").exists(), (
116-
"EAGLE training did not produce an output directory"
117-
)
107+
assert output_dir.exists(), "EAGLE training did not produce an output directory"
118108

119109

120110
def test_offline_ptq(offline_ptq_dirs):

0 commit comments

Comments (0)