Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions lmms_eval/models/chat/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]:
_gen["max_new_tokens"] = self._select_max_new_tokens(_gen.get("max_new_tokens"))
_gen.setdefault("temperature", 0)
_gen.setdefault("top_p", 0.95)

params = {
"temperature": _gen["temperature"],
"max_tokens": _gen["max_new_tokens"],
"top_p": _gen["top_p"],
}
params = self._build_sampling_params_dict(_gen)

video_kwargs = {
"max_pixels": self.max_pixels,
Expand Down
7 changes: 1 addition & 6 deletions lmms_eval/models/chat/vllm_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]:
_gen["max_new_tokens"] = self._select_max_new_tokens(_gen.get("max_new_tokens"))
_gen.setdefault("temperature", 0)
_gen.setdefault("top_p", 0.95)

params = {
"temperature": _gen["temperature"],
"max_tokens": _gen["max_new_tokens"],
"top_p": _gen["top_p"],
}
params = self._build_sampling_params_dict(_gen)

video_kwargs = {
"max_pixels": self.max_pixels,
Expand Down
27 changes: 20 additions & 7 deletions lmms_eval/models/simple/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,25 @@ def _select_max_new_tokens(self, request_max_new_tokens: Any) -> int:
return self.max_new_tokens
return max(request_max_new_tokens, self.max_new_tokens)

@staticmethod
def _normalize_top_p_for_vllm(top_p: Any) -> Any:
if isinstance(top_p, bool):
return top_p
try:
numeric_top_p = float(top_p)
except (TypeError, ValueError):
return top_p
if numeric_top_p == 0.0:
return 1.0
return top_p

def _build_sampling_params_dict(self, gen_kwargs: dict[str, Any]) -> dict[str, Any]:
return {
"max_tokens": gen_kwargs["max_new_tokens"],
"temperature": gen_kwargs["temperature"],
"top_p": self._normalize_top_p_for_vllm(gen_kwargs["top_p"]),
}

def _run_tp_synced(
self,
local_inputs: list[Any],
Expand Down Expand Up @@ -456,13 +475,7 @@ def generate_until(self, requests) -> List[str]:
gen_kwargs["max_new_tokens"] = self._select_max_new_tokens(gen_kwargs.get("max_new_tokens"))
gen_kwargs.setdefault("temperature", 0)
gen_kwargs.setdefault("top_p", 0.95)

params = {
"max_tokens": gen_kwargs["max_new_tokens"],
"temperature": gen_kwargs["temperature"],
"top_p": gen_kwargs["top_p"],
}
sampling_params = SamplingParams(**params)
sampling_params = SamplingParams(**self._build_sampling_params_dict(gen_kwargs))

visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
if None in visuals:
Expand Down
Loading