-
Notifications
You must be signed in to change notification settings - Fork 378
Expand file tree
/
Copy pathtest_eagle_offline_ptq.py
More file actions
143 lines (121 loc) · 5.35 KB
/
test_eagle_offline_ptq.py
File metadata and controls
143 lines (121 loc) · 5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end CI test for the offline speculative decoding PTQ workflow.
Covers the three-stage pipeline:
1. Collect hidden states from the base model → .pt files
2. Train an offline EAGLE draft model → ModelOpt checkpoint
3. PTQ the offline checkpoint → quantized export
Running all three stages in sequence validates that the data format produced
by stage 1 is correctly consumed by stage 2 and that the checkpoint produced
by stage 2 is correctly quantized in stage 3.
"""
import pytest
import safetensors.torch
import torch
from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command
# Path (as a plain string) to the EAGLE3 training recipe under the ModelOpt
# repo; the training stage passes it to launch_train.sh via --config.
EAGLE3_YAML = str(
    MODELOPT_ROOT.joinpath(
        "modelopt_recipes", "general", "speculative_decoding", "eagle3.yaml"
    )
)
# Tiny EAGLE architecture overrides, expressed as OmegaConf-style dotlist
# "key=value" entries layered on top of the EAGLE3 recipe: a single small
# decoder layer for the draft model.
_EAGLE_ARCH_PREFIX = "eagle.eagle_architecture_config"
_TINY_EAGLE_ARCH = [
    f"{_EAGLE_ARCH_PREFIX}.{field}={value}"
    for field, value in (
        ("max_position_embeddings", 128),
        ("num_hidden_layers", 1),
        ("intermediate_size", 64),
        ("num_attention_heads", 2),
        ("num_key_value_heads", 2),
        ("head_dim", 64),
    )
]
@pytest.fixture(scope="module")
def offline_ptq_dirs(tmp_path_factory):
    """Shared output directories for all stages.

    Module-scoped so the three stage tests below hand data to each other:
    one fresh temporary directory per key, created in the order listed.
    """
    stage_names = ("hidden_states", "eagle_ckpt", "ptq_export")
    return {name: tmp_path_factory.mktemp(name) for name in stage_names}
def test_collect_hidden_states(tiny_llama_path, tiny_conversations_path, offline_ptq_dirs):
    """Stage 1: generate .pt hidden state files from the base model.

    Runs ``collect_hidden_states/compute_hidden_states_hf.py`` on the tiny
    conversation set (``--debug-max-num-conversations 2``, ``--max-seq-len 32``),
    writing into the shared ``hidden_states`` directory. It then checks that at
    least one ``.pt`` file was produced and that it carries the keys read by
    ``OfflineSupervisedDataset``.
    """
    hidden_states_dir = offline_ptq_dirs["hidden_states"]
    run_example_command(
        [
            "python",
            "collect_hidden_states/compute_hidden_states_hf.py",
            "--model",
            tiny_llama_path,
            "--input-data",
            str(tiny_conversations_path),
            "--output-dir",
            str(hidden_states_dir),
            "--debug-max-num-conversations",
            "2",
            "--max-seq-len",
            "32",
        ],
        "speculative_decoding",
    )

    # Sort so the file inspected below is the same on every run; raw glob
    # order depends on the filesystem.
    pt_files = sorted(hidden_states_dir.glob("*.pt"))
    assert len(pt_files) > 0, "No .pt files generated by compute_hidden_states_hf.py"

    # Validate the format expected by OfflineSupervisedDataset.
    # map_location="cpu" makes the check independent of the device the
    # tensors were saved from.
    sample = torch.load(str(pt_files[0]), map_location="cpu")
    assert "input_ids" in sample, "Missing 'input_ids' in .pt file"
    assert "hidden_states" in sample, "Missing 'hidden_states' in .pt file"
def test_offline_eagle_training(tiny_llama_path, tiny_daring_anteater_path, offline_ptq_dirs):
    """Stage 2: train an EAGLE3 draft model using the offline hidden states.

    Launches ``./launch_train.sh`` with the EAGLE3 recipe plus dotlist
    overrides: the tiny base model, the stage-1 hidden states as
    ``data.offline_data_path``, one epoch, and the tiny draft architecture.
    The checkpoint is written to ``<eagle_ckpt>/trained``, which is the path
    the PTQ stage loads from.
    """
    output_dir = offline_ptq_dirs["eagle_ckpt"] / "trained"
    overrides = [
        f"model.model_name_or_path={tiny_llama_path}",
        f"data.data_path={tiny_daring_anteater_path}",
        f"data.offline_data_path={offline_ptq_dirs['hidden_states']}",
        f"training.output_dir={output_dir}",
        "training.num_train_epochs=1",
        "training.learning_rate=1e-5",
        "training.training_seq_len=64",
        "training.save_steps=1",
        *_TINY_EAGLE_ARCH,
    ]
    run_example_command(
        ["./launch_train.sh", "--config", EAGLE3_YAML, *overrides],
        "speculative_decoding",
        setup_free_port=True,
    )

    assert output_dir.is_dir(), "EAGLE training did not produce an output directory"
    # An existing-but-empty directory is not a usable checkpoint for the PTQ
    # stage, so require that training actually wrote something into it.
    assert any(output_dir.iterdir()), f"EAGLE training output directory {output_dir} is empty"
def test_offline_ptq(offline_ptq_dirs):
    """Stage 3: run PTQ on the offline EAGLE checkpoint using the hidden state dataset.

    Runs ``hf_ptq.py`` with ``--qformat fp8`` on the checkpoint in
    ``<eagle_ckpt>/trained``, passing the stage-1 directory via
    ``--specdec_offline_dataset``. It then verifies that the export contains
    ``model.safetensors`` and ``config.json``, and that the state dict contains
    the expected single-layer EAGLE weights.
    """
    export_dir = offline_ptq_dirs["ptq_export"]
    run_example_command(
        [
            "python",
            "hf_ptq.py",
            "--pyt_ckpt_path",
            str(offline_ptq_dirs["eagle_ckpt"] / "trained"),
            "--qformat",
            "fp8",
            "--calib_size",
            "2",
            "--batch_size",
            "1",
            "--specdec_offline_dataset",
            str(offline_ptq_dirs["hidden_states"]),
            "--export_path",
            str(export_dir),
        ],
        "llm_ptq",
    )

    # Verify the exported checkpoint exists and has the expected EAGLE keys
    weights_path = export_dir / "model.safetensors"
    assert weights_path.exists(), "PTQ export missing model.safetensors"
    assert (export_dir / "config.json").exists(), "PTQ export missing config.json"

    from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER

    state_dict = safetensors.torch.load_file(str(weights_path))

    # set() makes the subtraction work whether "required" is a set, list or
    # tuple. "fc" and "layers.0.hidden_norm" are deliberately not required.
    # NOTE(review): presumably they are absent or renamed in this export — confirm.
    expected_keys = set(LLAMA_EAGLE_SINGLE_LAYER["required"]) - {"fc", "layers.0.hidden_norm"}

    # Collect every missing key in a stable order, rather than stopping at
    # whichever key unordered set iteration happens to hit first.
    missing = sorted(
        f"{key}.weight" for key in expected_keys if f"{key}.weight" not in state_dict
    )
    assert not missing, f"Missing keys in exported state dict: {missing}"