|
25 | 25 | by stage 2 is correctly quantized in stage 3. |
26 | 26 | """ |
27 | 27 |
|
28 | | -import json |
29 | | - |
30 | 28 | import pytest |
31 | 29 | import safetensors.torch |
32 | 30 | import torch |
33 | | -from _test_utils.examples.run_command import run_example_command |
| 31 | +from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command |
| 32 | + |
| 33 | +EAGLE3_YAML = str( |
| 34 | + MODELOPT_ROOT / "modelopt_recipes" / "general" / "speculative_decoding" / "eagle3.yaml" |
| 35 | +) |
| 36 | + |
| 37 | +# Tiny EAGLE architecture overrides: key=value dotlist entries appended to the training command.
| 38 | +_TINY_EAGLE_ARCH = [ |
| 39 | + "eagle.eagle_architecture_config.max_position_embeddings=128", |
| 40 | + "eagle.eagle_architecture_config.num_hidden_layers=1", |
| 41 | + "eagle.eagle_architecture_config.intermediate_size=64", |
| 42 | + "eagle.eagle_architecture_config.num_attention_heads=2", |
| 43 | + "eagle.eagle_architecture_config.num_key_value_heads=2", |
| 44 | + "eagle.eagle_architecture_config.head_dim=64", |
| 45 | +] |
34 | 46 |
|
35 | 47 |
|
36 | 48 | @pytest.fixture(scope="module") |
@@ -72,49 +84,27 @@ def test_collect_hidden_states(tiny_llama_path, tiny_daring_anteater_path, offli |
72 | 84 |
|
73 | 85 | def test_offline_eagle_training(tiny_llama_path, tiny_daring_anteater_path, offline_ptq_dirs): |
74 | 86 | """Stage 2: train an EAGLE3 draft model using the offline hidden states.""" |
75 | | - tiny_eagle_config = { |
76 | | - "max_position_embeddings": 128, |
77 | | - "num_hidden_layers": 1, |
78 | | - "intermediate_size": 64, |
79 | | - "num_attention_heads": 2, |
80 | | - "num_key_value_heads": 2, |
81 | | - "head_dim": 64, |
82 | | - } |
83 | | - config_file = offline_ptq_dirs["eagle_ckpt"] / "tiny_eagle_config.json" |
84 | | - with open(config_file, "w") as f: |
85 | | - json.dump(tiny_eagle_config, f) |
| 87 | + output_dir = offline_ptq_dirs["eagle_ckpt"] / "trained" |
| 88 | + |
| 89 | + overrides = [ |
| 90 | + f"model.model_name_or_path={tiny_llama_path}", |
| 91 | + f"data.data_path={tiny_daring_anteater_path}", |
| 92 | + f"data.offline_data_path={offline_ptq_dirs['hidden_states']}", |
| 93 | + f"training.output_dir={output_dir}", |
| 94 | + "training.num_train_epochs=1", |
| 95 | + "training.learning_rate=1e-5", |
| 96 | + "training.training_seq_len=64", |
| 97 | + "training.save_steps=1", |
| 98 | + *_TINY_EAGLE_ARCH, |
| 99 | + ] |
86 | 100 |
|
87 | 101 | run_example_command( |
88 | | - [ |
89 | | - "./launch_train.sh", |
90 | | - "--model", |
91 | | - tiny_llama_path, |
92 | | - "--data", |
93 | | - str(tiny_daring_anteater_path), |
94 | | - "--offline-data", |
95 | | - str(offline_ptq_dirs["hidden_states"]), |
96 | | - "--num_epochs", |
97 | | - "1", |
98 | | - "--lr", |
99 | | - "1e-5", |
100 | | - "--mode", |
101 | | - "eagle3", |
102 | | - "--eagle_config", |
103 | | - str(config_file), |
104 | | - "--output_dir", |
105 | | - str(offline_ptq_dirs["eagle_ckpt"] / "trained"), |
106 | | - "--training_seq_len", |
107 | | - "64", |
108 | | - "--save_steps", |
109 | | - "1", |
110 | | - ], |
| 102 | + ["./launch_train.sh", "--config", EAGLE3_YAML, *overrides], |
111 | 103 | "speculative_decoding", |
112 | 104 | setup_free_port=True, |
113 | 105 | ) |
114 | 106 |
|
115 | | - assert (offline_ptq_dirs["eagle_ckpt"] / "trained").exists(), ( |
116 | | - "EAGLE training did not produce an output directory" |
117 | | - ) |
| 107 | + assert output_dir.exists(), "EAGLE training did not produce an output directory" |
118 | 108 |
|
119 | 109 |
|
120 | 110 | def test_offline_ptq(offline_ptq_dirs): |
|
0 commit comments