-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathinterface.py
More file actions
230 lines (200 loc) · 9.02 KB
/
Copy pathinterface.py
File metadata and controls
230 lines (200 loc) · 9.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import random
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Union
import numpy as np
import torch
from tqdm import tqdm
import tensorrt_llm.profiler as profiler
from ..llmapi import RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
# Upper bound, in seconds, on how long an evaluator waits for any single response
# before giving up. Without it, a stalled or dead executor worker would block
# `future.result()` forever and hang the whole evaluation.
# This is only a backstop: it is deliberately longer than the executor's stall
# watchdog (`TLLM_EXECUTOR_STALL_TIMEOUT_SECS`, default 300s), so the watchdog's
# more informative `RequestError` normally fires first. No healthy request should
# ever come close to this limit. Override via `TLLM_EVAL_RESULT_TIMEOUT_SECS`.
RESULT_WAIT_TIMEOUT_SECS = float(
    os.getenv("TLLM_EVAL_RESULT_TIMEOUT_SECS", "600"))
def get_chat_template_kwargs(
        template_owner: Any,
        chat_template_kwargs: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
    """Return effective chat template kwargs for evaluation.

    Some chat templates, such as Qwen3-family templates, enable a long-form
    thinking mode by default. For exact-match style benchmarks, that can consume
    the full generation budget before the model reaches its final answer. Keep
    reasoning disabled unless the caller explicitly opts in.

    Args:
        template_owner: A tokenizer, or a wrapper that exposes the underlying
            tokenizer as ``.tokenizer``. Its ``chat_template`` may be a single
            template string or a dict mapping template names to template
            strings.
        chat_template_kwargs: Caller-supplied kwargs. The argument itself is
            never mutated, and an explicit ``enable_thinking`` entry always wins.

    Returns:
        A new dict: a copy of ``chat_template_kwargs`` plus
        ``enable_thinking=False`` when any available template mentions
        ``enable_thinking`` and the caller did not set it.
    """
    effective_kwargs = dict(chat_template_kwargs or {})
    owner = getattr(template_owner, "tokenizer", template_owner)
    chat_template = getattr(owner, "chat_template", None)
    # Tokenizers with several named templates expose them as a dict. We cannot
    # tell here which one apply_chat_template will pick, so inspect them all.
    # Passing an unused kwarg to a template that ignores it is harmless.
    if isinstance(chat_template, dict):
        templates = chat_template.values()
    else:
        templates = (chat_template, )
    if any(
            isinstance(template, str) and "enable_thinking" in template
            for template in templates):
        effective_kwargs.setdefault("enable_thinking", False)
    return effective_kwargs
def get_model_context(llm: Any) -> tuple[str, str]:
    """Resolve an LLM's Hugging Face model directory and its ``model_type``.

    The directory comes from ``llm._hf_model_dir`` when that attribute exists
    and is truthy, otherwise from ``llm.model``. The ``config.json`` inside
    that directory is parsed to read ``model_type``.

    Args:
        llm: An LLM-like object exposing ``_hf_model_dir`` and/or ``model``.

    Returns:
        ``(model_dir, model_type)``, both converted to ``str``.

    Raises:
        ValueError: If neither attribute yields a directory.
        KeyError: If ``config.json`` has no ``model_type`` (or it is null).
        Errors from opening or JSON-parsing ``config.json`` propagate unchanged.
    """
    source = getattr(llm, "_hf_model_dir", None)
    if not source:
        source = getattr(llm, "model", None)
    if source is None:
        raise ValueError("The LLM object does not expose a model directory.")

    model_dir = str(source)
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, encoding="utf-8") as handle:
        hf_config = json.load(handle)

    model_type = hf_config.get("model_type")
    if model_type is None:
        raise KeyError(f"'model_type' is missing from {config_path}.")
    return model_dir, str(model_type)
class Evaluator(ABC):
    """Abstract base for accuracy evaluators run against an LLM.

    Subclasses implement :meth:`generate_samples` (what to ask) and
    :meth:`compute_score` (how to grade). :meth:`evaluate` submits every sample
    with ``llm.generate_async``, waits for all responses, optionally dumps them
    to ``output_dir``, and returns the score.
    """

    def __init__(self,
                 random_seed: int = 0,
                 apply_chat_template: bool = False,
                 fewshot_as_multiturn: bool = False,
                 system_prompt: Optional[str] = None,
                 chat_template_kwargs: Optional[dict[str, Any]] = None,
                 output_dir: Optional[str] = None):
        """Seed the global RNGs and store the evaluation options.

        Args:
            random_seed: Seed for Python ``random``, NumPy's global RNG and
                ``torch.manual_seed``. Seeding these is a process-wide side
                effect.
            apply_chat_template: Render each prompt through the tokenizer's
                chat template before submission.
            fewshot_as_multiturn: Stored for subclasses; this base class never
                reads it (presumably it controls few-shot prompt layout —
                confirm in subclasses).
            system_prompt: Optional system message, prepended when the chat
                template is applied.
            chat_template_kwargs: Extra kwargs forwarded to the tokenizer's
                ``apply_chat_template`` (after ``get_chat_template_kwargs``).
            output_dir: If set, :meth:`evaluate` dumps inference results there.
        """
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        self.apply_chat_template = apply_chat_template
        self.fewshot_as_multiturn = fewshot_as_multiturn
        self.system_prompt = system_prompt
        self.chat_template_kwargs = chat_template_kwargs
        self.output_dir = output_dir

    @abstractmethod
    def generate_samples(self) -> Iterable[tuple]:
        """Yield one tuple per sample: ``(prompt, sampling_args, reference, *aux)``.

        ``prompt`` is a string or a list of chat-message dicts. ``sampling_args``
        is ``None`` or a dict of ``SamplingParams`` attribute overrides for that
        sample only. Any trailing items are auxiliary data passed to
        :meth:`compute_score`.
        """
        raise NotImplementedError()

    @abstractmethod
    def compute_score(self, outputs: List[RequestOutput], references: List[str],
                      *auxiliaries) -> float:
        """Score ``outputs`` against ``references``.

        ``outputs[i]`` and ``references[i]`` belong to the same sample. Each
        extra positional argument is a tuple holding one auxiliary field across
        all samples, because :meth:`evaluate` transposes the per-sample
        auxiliaries with ``zip``.
        """
        raise NotImplementedError()

    def do_apply_chat_template(self, llm: Any,
                               prompt: Union[str, List[dict]]) -> str:
        """Render ``prompt`` to a string with ``llm.tokenizer``'s chat template.

        A plain string becomes a single user message; a list is used as the
        message list. When ``system_prompt`` is set it is prepended as a system
        message. The template is rendered untokenized, with the generation
        prompt appended.
        """
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = prompt
        if self.system_prompt is not None:
            # Build a new list so the caller's message list is not mutated.
            messages = [{
                "role": "system",
                "content": self.system_prompt
            }] + messages
        chat_template_kwargs = get_chat_template_kwargs(
            llm.tokenizer, self.chat_template_kwargs)
        return llm.tokenizer.apply_chat_template(messages,
                                                 tokenize=False,
                                                 add_generation_prompt=True,
                                                 **chat_template_kwargs)

    def _get_sampline_params(self, sampling_params: Optional[SamplingParams],
                             sampling_args: Optional[dict]) -> SamplingParams:
        """Return fresh sampling params with per-sample overrides applied.

        ``sampling_params`` is deep-copied, or a default ``SamplingParams`` is
        built when it is ``None``. Every ``sampling_args`` item is then set as
        an attribute on that copy. The input object is never mutated.
        """
        # NOTE: the "sampline" typo is kept on purpose; subclasses may call or
        # override this private name.
        if sampling_params is None:
            sampling_params = SamplingParams()
        else:
            sampling_params = copy.deepcopy(sampling_params)
        if sampling_args is not None:
            for key, value in sampling_args.items():
                setattr(sampling_params, key, value)
        return sampling_params

    def evaluate(self,
                 llm: Any,
                 sampling_params: Optional[SamplingParams] = None,
                 streaming: bool = False) -> float:
        """Run every sample through ``llm`` and return the computed score.

        All requests are submitted before any result is awaited. Results are
        then collected in submission order, each with a
        ``RESULT_WAIT_TIMEOUT_SECS`` timeout.

        Args:
            llm: Object providing ``generate_async``, plus ``tokenizer`` when
                chat templating is enabled.
            sampling_params: Base sampling parameters shared by all requests;
                never mutated. Each sample's ``sampling_args`` are applied to a
                private copy used for that sample only.
            streaming: Forwarded to ``generate_async``.

        Returns:
            The value returned by :meth:`compute_score`.

        Any exception raised by ``result(timeout=...)`` (for example when a
        response does not arrive in time) propagates to the caller.
        """
        # NOTE(review): if an exception escapes below, the "trtllm exec" timer
        # is never stopped/reset; confirm whether the profiler tolerates that.
        profiler.start("trtllm exec")
        outputs, references, auxiliaries = [], [], []
        for prompt, sampling_args, reference, *aux in tqdm(
                self.generate_samples(), desc="Submitting requests"):
            if self.apply_chat_template:
                prompt = self.do_apply_chat_template(llm, prompt)
            # Always derive from the caller's base params. Rebinding
            # `sampling_params` here would let one sample's overrides leak
            # into every later sample.
            request_params = self._get_sampline_params(sampling_params,
                                                      sampling_args)
            output = llm.generate_async(
                prompt,
                request_params,
                streaming=streaming,
            )
            outputs.append(output)
            references.append(reference)
            auxiliaries.append(aux)

        results = [
            output.result(timeout=RESULT_WAIT_TIMEOUT_SECS)
            for output in tqdm(outputs, desc="Fetching responses")
        ]
        if self.output_dir:
            dump_inference_results(self.output_dir, results,
                                   getattr(llm, 'tokenizer', None))
        profiler.stop("trtllm exec")
        elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
        logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
        profiler.reset("trtllm exec")

        # Transpose per-sample auxiliaries into one tuple per auxiliary field.
        score = self.compute_score(results, references, *zip(*auxiliaries))
        return score

    @staticmethod
    def command(ctx, *args, **kwargs) -> None:
        """Entry point hook for subclasses (presumably a CLI command — confirm).

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()
def _write_jsonl(path: str, records: Iterable[dict], what: str) -> None:
    """Best-effort write of ``records`` to ``path``, one JSON object per line.

    ``records`` may be lazy. It is consumed inside the ``try``, so errors raised
    while building a record are reported the same way as write errors. All
    failures are logged as warnings rather than raised, so a dump problem never
    discards evaluation results that were already computed.
    """
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(record) + "\n" for record in records)
        logger.info(f"Dumped {what} to {path}")
    except Exception as e:  # deliberate best-effort: warn and continue
        logger.warning(f"Failed to dump {what} to {path}: {e}")


def dump_inference_results(output_dir: str, results: List[Any],
                           tokenizer: Any):
    """Dump evaluation inputs and outputs to ``output_dir``.

    Two files are written. Despite the ``.json`` extension, each holds one JSON
    object per line (JSON Lines):

    * ``dumped_ids.json``: ``task_id``, ``input_ids``, ``output_ids``,
      ``input_tokens`` and ``output_tokens``. ``input_ids`` come from
      re-encoding the stripped prompt text with ``tokenizer``, so they may
      differ from the token IDs the model actually received. This file is
      skipped, with a warning, when ``tokenizer`` is ``None``.
    * ``dumped_text.json``: ``task_id``, the stripped ``input_text`` and
      ``output_text``, and their character lengths ``input_len`` and
      ``output_len``.

    Args:
        output_dir: Target directory, created if missing. A falsy value
            disables dumping.
        results: Per-request outputs (as produced by ``generate_async(...)
            .result()`` in ``Evaluator.evaluate``). Each must expose
            ``.prompt`` and ``.outputs[0]`` with ``.token_ids`` and ``.text``.
            ``task_id`` is the index in this list.
        tokenizer: Object with an ``encode(str)`` method, or ``None``.

    Write failures are logged as warnings and not raised.
    """
    if not output_dir:
        return
    os.makedirs(output_dir, exist_ok=True)

    # Collect the per-task fields shared by both dumps.
    records = []
    for task_id, result in enumerate(results):
        completion = result.outputs[0]
        records.append({
            "task_id": task_id,
            # NOTE(review): guards against a None prompt (presumably possible
            # for token-ID inputs; unverified).
            "input_text": (result.prompt or "").strip(),
            "output_ids": completion.token_ids,
            "output_text": completion.text.strip(),
        })

    ids_path = os.path.join(output_dir, "dumped_ids.json")
    if tokenizer is None:
        # Callers may pass `getattr(llm, 'tokenizer', None)`. Without a
        # tokenizer the input IDs cannot be rebuilt, so skip this file instead
        # of crashing after inference has already finished.
        logger.warning(
            f"Failed to dump IDs to {ids_path}: no tokenizer available")
    else:

        def id_records():
            for rec in records:
                input_ids = tokenizer.encode(rec["input_text"])
                yield {
                    "task_id": rec["task_id"],
                    "input_ids": input_ids,
                    "output_ids": rec["output_ids"],
                    "input_tokens": len(input_ids),
                    "output_tokens": len(rec["output_ids"])
                }

        _write_jsonl(ids_path, id_records(), "IDs")

    text_path = os.path.join(output_dir, "dumped_text.json")
    _write_jsonl(text_path, ({
        "task_id": rec["task_id"],
        "input_text": rec["input_text"],
        "output_text": rec["output_text"],
        "input_len": len(rec["input_text"]),
        "output_len": len(rec["output_text"])
    } for rec in records), "text")