
Commit 17c4461

Fix the incompatibility caused by top_p=0 when using vLLM for inference (#1265) (#1277)

* Fix vLLM top_p=0 handling
* Remove test file

1 parent 52c5620 · commit 17c4461
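
For context: vLLM's `SamplingParams` validates `top_p` against the interval (0, 1], so a request that passes `top_p=0` fails before generation starts, even under greedy decoding where `top_p` has no effect. A minimal reproduction sketch, assuming a recent vLLM release:

```python
# Minimal reproduction sketch, assuming a recent vLLM release.
# SamplingParams rejects top_p values outside (0, 1] at construction time.
from vllm import SamplingParams

SamplingParams(temperature=0, top_p=0.95)  # OK: greedy decode, valid top_p
SamplingParams(temperature=0, top_p=0)     # ValueError: top_p must be in (0, 1]
```

The commit routes all sampling-parameter construction through a single helper that maps `top_p=0` to `1.0` (no nucleus truncation), which is a no-op when decoding greedily with `temperature=0`.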

3 files changed: 22 additions & 19 deletions


lmms_eval/models/chat/vllm.py (1 addition, 6 deletions)
```diff
@@ -69,12 +69,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]:
         _gen["max_new_tokens"] = self._select_max_new_tokens(_gen.get("max_new_tokens"))
         _gen.setdefault("temperature", 0)
         _gen.setdefault("top_p", 0.95)
-
-        params = {
-            "temperature": _gen["temperature"],
-            "max_tokens": _gen["max_new_tokens"],
-            "top_p": _gen["top_p"],
-        }
+        params = self._build_sampling_params_dict(_gen)
 
         video_kwargs = {
             "max_pixels": self.max_pixels,
```

lmms_eval/models/chat/vllm_generate.py (1 addition, 6 deletions)
```diff
@@ -90,12 +90,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]:
         _gen["max_new_tokens"] = self._select_max_new_tokens(_gen.get("max_new_tokens"))
         _gen.setdefault("temperature", 0)
         _gen.setdefault("top_p", 0.95)
-
-        params = {
-            "temperature": _gen["temperature"],
-            "max_tokens": _gen["max_new_tokens"],
-            "top_p": _gen["top_p"],
-        }
+        params = self._build_sampling_params_dict(_gen)
 
         video_kwargs = {
             "max_pixels": self.max_pixels,
```

lmms_eval/models/simple/vllm.py (20 additions, 7 deletions)
```diff
@@ -337,6 +337,25 @@ def _select_max_new_tokens(self, request_max_new_tokens: Any) -> int:
             return self.max_new_tokens
         return max(request_max_new_tokens, self.max_new_tokens)
 
+    @staticmethod
+    def _normalize_top_p_for_vllm(top_p: Any) -> Any:
+        if isinstance(top_p, bool):
+            return top_p
+        try:
+            numeric_top_p = float(top_p)
+        except (TypeError, ValueError):
+            return top_p
+        if numeric_top_p == 0.0:
+            return 1.0
+        return top_p
+
+    def _build_sampling_params_dict(self, gen_kwargs: dict[str, Any]) -> dict[str, Any]:
+        return {
+            "max_tokens": gen_kwargs["max_new_tokens"],
+            "temperature": gen_kwargs["temperature"],
+            "top_p": self._normalize_top_p_for_vllm(gen_kwargs["top_p"]),
+        }
+
     def _run_tp_synced(
         self,
         local_inputs: list[Any],
```
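
A quick illustration of how the two helpers behave; `model` is a hypothetical instance of the class that defines them:

```python
# Illustrative calls; `model` is a hypothetical instance of the vLLM
# model class that defines the helpers above.
model._normalize_top_p_for_vllm(0)      # -> 1.0: top_p=0 becomes "no nucleus truncation"
model._normalize_top_p_for_vllm("0")    # -> 1.0: numeric strings are coerced for the check
model._normalize_top_p_for_vllm(0.95)   # -> 0.95: valid values pass through unchanged
model._normalize_top_p_for_vllm(None)   # -> None: non-numeric values are left for vLLM to validate
model._normalize_top_p_for_vllm(False)  # -> False: the bool guard keeps False from being read as 0.0

model._build_sampling_params_dict({"max_new_tokens": 256, "temperature": 0, "top_p": 0})
# -> {"max_tokens": 256, "temperature": 0, "top_p": 1.0}
```

The explicit `isinstance(top_p, bool)` guard matters because `float(False) == 0.0`, so without it a stray boolean would silently be rewritten to `1.0`.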
```diff
@@ -456,13 +475,7 @@ def generate_until(self, requests) -> List[str]:
             gen_kwargs["max_new_tokens"] = self._select_max_new_tokens(gen_kwargs.get("max_new_tokens"))
             gen_kwargs.setdefault("temperature", 0)
             gen_kwargs.setdefault("top_p", 0.95)
-
-            params = {
-                "max_tokens": gen_kwargs["max_new_tokens"],
-                "temperature": gen_kwargs["temperature"],
-                "top_p": gen_kwargs["top_p"],
-            }
-            sampling_params = SamplingParams(**params)
+            sampling_params = SamplingParams(**self._build_sampling_params_dict(gen_kwargs))
 
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             if None in visuals:
```
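
Net effect of the two hunks, sketched for a request that explicitly sets `top_p=0` (values illustrative):

```python
from vllm import SamplingParams

gen_kwargs = {"max_new_tokens": 1024, "temperature": 0, "top_p": 0}

# Before this commit the raw value reached vLLM and construction failed:
#   SamplingParams(max_tokens=1024, temperature=0, top_p=0)  -> ValueError
# Now top_p is normalized to 1.0 first, so construction succeeds:
sampling_params = SamplingParams(max_tokens=1024, temperature=0, top_p=1.0)
```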
