Skip to content

Commit 2bc8e4b

Browse files
committed
add test
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 719c6d5 commit 2bc8e4b

1 file changed

Lines changed: 114 additions & 0 deletions

File tree

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Parity test for compute_hidden_states_vllm.py vs compute_hidden_states_hf.py.
17+
18+
Runs both example scripts on the shared tiny dataset and asserts the per-conversation
19+
.pt outputs agree. input_ids / loss_mask must match exactly; hidden_states and
20+
aux_hidden_states are compared with cosine similarity because vLLM and HF use
21+
different attention kernels and bf16 rounding will differ.
22+
"""
23+
24+
import pytest
25+
import torch
26+
from _test_utils.examples.run_command import run_example_command
27+
28+
pytest.importorskip("vllm")
29+
30+
COS_THRESHOLD = 0.95
31+
32+
33+
def _cos_sim(a: torch.Tensor, b: torch.Tensor) -> float:
34+
a = a.flatten().float()
35+
b = b.flatten().float()
36+
return ((a @ b) / (a.norm() * b.norm() + 1e-12)).item()
37+
38+
39+
@pytest.fixture(scope="module")
40+
def hidden_states_dirs(tmp_path_factory):
41+
return {
42+
"hf": tmp_path_factory.mktemp("hf_hidden_states"),
43+
"vllm": tmp_path_factory.mktemp("vllm_hidden_states"),
44+
}
45+
46+
47+
def test_vllm_hf_parity(tiny_llama_path, tiny_conversations_path, hidden_states_dirs):
48+
common_args = [
49+
"--model",
50+
tiny_llama_path,
51+
"--input-data",
52+
str(tiny_conversations_path),
53+
"--max-seq-len",
54+
"32",
55+
"--debug-max-num-conversations",
56+
"2",
57+
]
58+
59+
run_example_command(
60+
[
61+
"python",
62+
"collect_hidden_states/compute_hidden_states_hf.py",
63+
*common_args,
64+
"--output-dir",
65+
str(hidden_states_dirs["hf"]),
66+
],
67+
"speculative_decoding",
68+
)
69+
70+
run_example_command(
71+
[
72+
"python",
73+
"collect_hidden_states/compute_hidden_states_vllm.py",
74+
*common_args,
75+
"--output-dir",
76+
str(hidden_states_dirs["vllm"]),
77+
"--min-seq-len",
78+
"1",
79+
"--gpu-memory-util",
80+
"0.3",
81+
"--enforce-eager",
82+
],
83+
"speculative_decoding",
84+
)
85+
86+
hf_files = sorted(hidden_states_dirs["hf"].glob("*.pt"))
87+
vllm_files = sorted(hidden_states_dirs["vllm"].glob("*.pt"))
88+
assert hf_files, "HF stage produced no .pt files"
89+
assert vllm_files, "vLLM stage produced no .pt files"
90+
assert {f.name for f in hf_files} == {f.name for f in vllm_files}, (
91+
"HF and vLLM produced different conversation IDs"
92+
)
93+
94+
for f_hf in hf_files:
95+
f_vl = hidden_states_dirs["vllm"] / f_hf.name
96+
pt_hf = torch.load(f_hf, map_location="cpu", weights_only=False)
97+
pt_vl = torch.load(f_vl, map_location="cpu", weights_only=False)
98+
99+
assert torch.equal(pt_hf["input_ids"], pt_vl["input_ids"]), (
100+
f"input_ids mismatch in {f_hf.name}"
101+
)
102+
assert torch.equal(pt_hf["loss_mask"], pt_vl["loss_mask"]), (
103+
f"loss_mask mismatch in {f_hf.name}"
104+
)
105+
106+
for key in ("hidden_states", "aux_hidden_states"):
107+
h_hf, h_vl = pt_hf[key], pt_vl[key]
108+
assert h_hf.shape == h_vl.shape, (
109+
f"{key} shape mismatch in {f_hf.name}: {tuple(h_hf.shape)} vs {tuple(h_vl.shape)}"
110+
)
111+
cs = _cos_sim(h_hf, h_vl)
112+
assert cs >= COS_THRESHOLD, (
113+
f"{key} cosine similarity {cs:.4f} below {COS_THRESHOLD} in {f_hf.name}"
114+
)

0 commit comments

Comments
 (0)