@@ -357,23 +357,28 @@ class SamplingParams(ctypes.Structure):
357357
358358 def init (self , tokenizer , ** kwargs ):
359359 super ().__init__ ()
360- self .best_of = kwargs .get ("best_of" , 1 )
361- self .n = kwargs .get ("n" , self .best_of )
362- self .do_sample = kwargs .get ("do_sample" , SamplingParams ._do_sample )
363- self .presence_penalty = kwargs .get ("presence_penalty" , SamplingParams ._presence_penalty )
364- self .frequency_penalty = kwargs .get ("frequency_penalty" , SamplingParams ._frequency_penalty )
365- self .repetition_penalty = kwargs .get ("repetition_penalty" , SamplingParams ._repetition_penalty )
366- self .temperature = kwargs .get ("temperature" , SamplingParams ._temperature )
367- self .top_p = kwargs .get ("top_p" , SamplingParams ._top_p )
368- self .top_k = kwargs .get ("top_k" , SamplingParams ._top_k )
369- self .ignore_eos = kwargs .get ("ignore_eos" , False )
370- self .min_pixels = kwargs .get ("min_pixels" , - 1 )
371- self .max_pixels = kwargs .get ("max_pixels" , - 1 )
372- self .max_new_tokens = kwargs .get ("max_new_tokens" , 16 )
373- self .min_new_tokens = kwargs .get ("min_new_tokens" , 1 )
374- self .input_penalty = kwargs .get ("input_penalty" , DEFAULT_INPUT_PENALTY )
375- self .group_request_id = kwargs .get ("group_request_id" , - 1 )
376- self .suggested_dp_index = kwargs .get ("suggested_dp_index" , - 1 )
360+
361+ def _get (key , default ):
362+ v = kwargs .get (key )
363+ return v if v is not None else default
364+
365+ self .best_of = _get ("best_of" , 1 )
366+ self .n = _get ("n" , self .best_of )
367+ self .do_sample = _get ("do_sample" , SamplingParams ._do_sample )
368+ self .presence_penalty = _get ("presence_penalty" , SamplingParams ._presence_penalty )
369+ self .frequency_penalty = _get ("frequency_penalty" , SamplingParams ._frequency_penalty )
370+ self .repetition_penalty = _get ("repetition_penalty" , SamplingParams ._repetition_penalty )
371+ self .temperature = _get ("temperature" , SamplingParams ._temperature )
372+ self .top_p = _get ("top_p" , SamplingParams ._top_p )
373+ self .top_k = _get ("top_k" , SamplingParams ._top_k )
374+ self .ignore_eos = _get ("ignore_eos" , False )
375+ self .min_pixels = _get ("min_pixels" , - 1 )
376+ self .max_pixels = _get ("max_pixels" , - 1 )
377+ self .max_new_tokens = _get ("max_new_tokens" , 16 )
378+ self .min_new_tokens = _get ("min_new_tokens" , 1 )
379+ self .input_penalty = _get ("input_penalty" , DEFAULT_INPUT_PENALTY )
380+ self .group_request_id = _get ("group_request_id" , - 1 )
381+ self .suggested_dp_index = _get ("suggested_dp_index" , - 1 )
377382
378383 self .skip_special_tokens = kwargs .get ("skip_special_tokens" , SKIP_SPECIAL_TOKENS )
379384 self .disable_prompt_cache = kwargs .get ("disable_prompt_cache" , False )
0 commit comments