diff --git a/lmms_eval/models/chat/vllm.py b/lmms_eval/models/chat/vllm.py
index 24bc1099c..346533eae 100644
--- a/lmms_eval/models/chat/vllm.py
+++ b/lmms_eval/models/chat/vllm.py
@@ -69,12 +69,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]:
         _gen["max_new_tokens"] = self._select_max_new_tokens(_gen.get("max_new_tokens"))
         _gen.setdefault("temperature", 0)
         _gen.setdefault("top_p", 0.95)
-
-        params = {
-            "temperature": _gen["temperature"],
-            "max_tokens": _gen["max_new_tokens"],
-            "top_p": _gen["top_p"],
-        }
+        params = self._build_sampling_params_dict(_gen)
 
         video_kwargs = {
             "max_pixels": self.max_pixels,
diff --git a/lmms_eval/models/chat/vllm_generate.py b/lmms_eval/models/chat/vllm_generate.py
index 2b4e289f0..9b71d23ac 100644
--- a/lmms_eval/models/chat/vllm_generate.py
+++ b/lmms_eval/models/chat/vllm_generate.py
@@ -90,12 +90,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]:
         _gen["max_new_tokens"] = self._select_max_new_tokens(_gen.get("max_new_tokens"))
         _gen.setdefault("temperature", 0)
         _gen.setdefault("top_p", 0.95)
-
-        params = {
-            "temperature": _gen["temperature"],
-            "max_tokens": _gen["max_new_tokens"],
-            "top_p": _gen["top_p"],
-        }
+        params = self._build_sampling_params_dict(_gen)
 
         video_kwargs = {
             "max_pixels": self.max_pixels,
diff --git a/lmms_eval/models/simple/vllm.py b/lmms_eval/models/simple/vllm.py
index 2a83c633c..22946997e 100644
--- a/lmms_eval/models/simple/vllm.py
+++ b/lmms_eval/models/simple/vllm.py
@@ -337,6 +337,25 @@ def _select_max_new_tokens(self, request_max_new_tokens: Any) -> int:
             return self.max_new_tokens
         return max(request_max_new_tokens, self.max_new_tokens)
 
+    @staticmethod
+    def _normalize_top_p_for_vllm(top_p: Any) -> Any:
+        if isinstance(top_p, bool):
+            return top_p
+        try:
+            numeric_top_p = float(top_p)
+        except (TypeError, ValueError):
+            return top_p
+        if numeric_top_p == 0.0:
+            return 1.0
+        return top_p
+
+    def _build_sampling_params_dict(self, gen_kwargs: dict[str, Any]) -> dict[str, Any]:
+        return {
+            "max_tokens": gen_kwargs["max_new_tokens"],
+            "temperature": gen_kwargs["temperature"],
+            "top_p": self._normalize_top_p_for_vllm(gen_kwargs["top_p"]),
+        }
+
     def _run_tp_synced(
         self,
         local_inputs: list[Any],
@@ -456,13 +475,7 @@ def generate_until(self, requests) -> List[str]:
             gen_kwargs["max_new_tokens"] = self._select_max_new_tokens(gen_kwargs.get("max_new_tokens"))
             gen_kwargs.setdefault("temperature", 0)
             gen_kwargs.setdefault("top_p", 0.95)
-
-            params = {
-                "max_tokens": gen_kwargs["max_new_tokens"],
-                "temperature": gen_kwargs["temperature"],
-                "top_p": gen_kwargs["top_p"],
-            }
-            sampling_params = SamplingParams(**params)
+            sampling_params = SamplingParams(**self._build_sampling_params_dict(gen_kwargs))
 
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             if None in visuals: