@@ -337,6 +337,25 @@ def _select_max_new_tokens(self, request_max_new_tokens: Any) -> int:
337337 return self .max_new_tokens
338338 return max (request_max_new_tokens , self .max_new_tokens )
339339
340+ @staticmethod
341+ def _normalize_top_p_for_vllm (top_p : Any ) -> Any :
342+ if isinstance (top_p , bool ):
343+ return top_p
344+ try :
345+ numeric_top_p = float (top_p )
346+ except (TypeError , ValueError ):
347+ return top_p
348+ if numeric_top_p == 0.0 :
349+ return 1.0
350+ return top_p
351+
352+ def _build_sampling_params_dict (self , gen_kwargs : dict [str , Any ]) -> dict [str , Any ]:
353+ return {
354+ "max_tokens" : gen_kwargs ["max_new_tokens" ],
355+ "temperature" : gen_kwargs ["temperature" ],
356+ "top_p" : self ._normalize_top_p_for_vllm (gen_kwargs ["top_p" ]),
357+ }
358+
340359 def _run_tp_synced (
341360 self ,
342361 local_inputs : list [Any ],
@@ -456,13 +475,7 @@ def generate_until(self, requests) -> List[str]:
456475 gen_kwargs ["max_new_tokens" ] = self ._select_max_new_tokens (gen_kwargs .get ("max_new_tokens" ))
457476 gen_kwargs .setdefault ("temperature" , 0 )
458477 gen_kwargs .setdefault ("top_p" , 0.95 )
459-
460- params = {
461- "max_tokens" : gen_kwargs ["max_new_tokens" ],
462- "temperature" : gen_kwargs ["temperature" ],
463- "top_p" : gen_kwargs ["top_p" ],
464- }
465- sampling_params = SamplingParams (** params )
478+ sampling_params = SamplingParams (** self ._build_sampling_params_dict (gen_kwargs ))
466479
467480 visuals = [doc_to_visual (self .task_dict [task ][split ][doc_id ])]
468481 if None in visuals :
0 commit comments