|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Parity test for compute_hidden_states_vllm.py vs compute_hidden_states_hf.py. |
| 17 | +
|
| 18 | +Runs both example scripts on the shared tiny dataset and asserts the per-conversation |
| 19 | +.pt outputs agree. input_ids / loss_mask must match exactly; hidden_states and |
| 20 | +aux_hidden_states are compared with cosine similarity because vLLM and HF use |
| 21 | +different attention kernels and bf16 rounding will differ. |
| 22 | +""" |
| 23 | + |
| 24 | +import pytest |
| 25 | +import torch |
| 26 | +from _test_utils.examples.run_command import run_example_command |
| 27 | + |
| 28 | +pytest.importorskip("vllm") |
| 29 | + |
| 30 | +COS_THRESHOLD = 0.95 |
| 31 | + |
| 32 | + |
| 33 | +def _cos_sim(a: torch.Tensor, b: torch.Tensor) -> float: |
| 34 | + a = a.flatten().float() |
| 35 | + b = b.flatten().float() |
| 36 | + return ((a @ b) / (a.norm() * b.norm() + 1e-12)).item() |
| 37 | + |
| 38 | + |
| 39 | +@pytest.fixture(scope="module") |
| 40 | +def hidden_states_dirs(tmp_path_factory): |
| 41 | + return { |
| 42 | + "hf": tmp_path_factory.mktemp("hf_hidden_states"), |
| 43 | + "vllm": tmp_path_factory.mktemp("vllm_hidden_states"), |
| 44 | + } |
| 45 | + |
| 46 | + |
| 47 | +def test_vllm_hf_parity(tiny_llama_path, tiny_conversations_path, hidden_states_dirs): |
| 48 | + common_args = [ |
| 49 | + "--model", |
| 50 | + tiny_llama_path, |
| 51 | + "--input-data", |
| 52 | + str(tiny_conversations_path), |
| 53 | + "--max-seq-len", |
| 54 | + "32", |
| 55 | + "--debug-max-num-conversations", |
| 56 | + "2", |
| 57 | + ] |
| 58 | + |
| 59 | + run_example_command( |
| 60 | + [ |
| 61 | + "python", |
| 62 | + "collect_hidden_states/compute_hidden_states_hf.py", |
| 63 | + *common_args, |
| 64 | + "--output-dir", |
| 65 | + str(hidden_states_dirs["hf"]), |
| 66 | + ], |
| 67 | + "speculative_decoding", |
| 68 | + ) |
| 69 | + |
| 70 | + run_example_command( |
| 71 | + [ |
| 72 | + "python", |
| 73 | + "collect_hidden_states/compute_hidden_states_vllm.py", |
| 74 | + *common_args, |
| 75 | + "--output-dir", |
| 76 | + str(hidden_states_dirs["vllm"]), |
| 77 | + "--min-seq-len", |
| 78 | + "1", |
| 79 | + "--gpu-memory-util", |
| 80 | + "0.3", |
| 81 | + "--enforce-eager", |
| 82 | + ], |
| 83 | + "speculative_decoding", |
| 84 | + ) |
| 85 | + |
| 86 | + hf_files = sorted(hidden_states_dirs["hf"].glob("*.pt")) |
| 87 | + vllm_files = sorted(hidden_states_dirs["vllm"].glob("*.pt")) |
| 88 | + assert hf_files, "HF stage produced no .pt files" |
| 89 | + assert vllm_files, "vLLM stage produced no .pt files" |
| 90 | + assert {f.name for f in hf_files} == {f.name for f in vllm_files}, ( |
| 91 | + "HF and vLLM produced different conversation IDs" |
| 92 | + ) |
| 93 | + |
| 94 | + for f_hf in hf_files: |
| 95 | + f_vl = hidden_states_dirs["vllm"] / f_hf.name |
| 96 | + pt_hf = torch.load(f_hf, map_location="cpu", weights_only=False) |
| 97 | + pt_vl = torch.load(f_vl, map_location="cpu", weights_only=False) |
| 98 | + |
| 99 | + assert torch.equal(pt_hf["input_ids"], pt_vl["input_ids"]), ( |
| 100 | + f"input_ids mismatch in {f_hf.name}" |
| 101 | + ) |
| 102 | + assert torch.equal(pt_hf["loss_mask"], pt_vl["loss_mask"]), ( |
| 103 | + f"loss_mask mismatch in {f_hf.name}" |
| 104 | + ) |
| 105 | + |
| 106 | + for key in ("hidden_states", "aux_hidden_states"): |
| 107 | + h_hf, h_vl = pt_hf[key], pt_vl[key] |
| 108 | + assert h_hf.shape == h_vl.shape, ( |
| 109 | + f"{key} shape mismatch in {f_hf.name}: {tuple(h_hf.shape)} vs {tuple(h_vl.shape)}" |
| 110 | + ) |
| 111 | + cs = _cos_sim(h_hf, h_vl) |
| 112 | + assert cs >= COS_THRESHOLD, ( |
| 113 | + f"{key} cosine similarity {cs:.4f} below {COS_THRESHOLD} in {f_hf.name}" |
| 114 | + ) |
0 commit comments