diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py
index d9d5b4100..8d924069b 100644
--- a/src/lighteval/models/abstract_model.py
+++ b/src/lighteval/models/abstract_model.py
@@ -133,24 +133,24 @@ def _parse_args(args: str) -> dict:
         'model': {'model_name': 'gpt2', 'use_cache': True, 'generation_parameters': {'temperature': 0.7}},
     }
     """
-    # Looking for generation_parameters in the model_args
-    generation_parameters_dict = None
-    pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
-    matches = pattern.findall(args)
-    for key, value in matches:
-        key = key.strip()
-        if key == "generation_parameters":
-            # Keys must be quoted (since they are strings)
-            gen_params = re.sub(r"(\w+):", r'"\1":', value)
-            # for k, v where v are strings, we quote them too
-            gen_params = re.sub(r":\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r':"\1"', gen_params)
-            generation_parameters_dict = json.loads(gen_params)
-
-    args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
-    model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")}
-
-    if generation_parameters_dict is not None:
-        model_config["generation_parameters"] = generation_parameters_dict
+    # Extract all key={...} dict parameters (e.g. generation_parameters, extra_body)
+    dict_params = {}
+    dict_pattern = re.compile(r"(\w+)=(\{[^}]*\})")
+    for match in dict_pattern.finditer(args):
+        key = match.group(1)
+        value = match.group(2)
+        # Keys must be quoted (since they are strings)
+        parsed = re.sub(r"(\w+):", r'"\1":', value)
+        # for k, v where v are strings, we quote them too
+        parsed = re.sub(r":\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r':"\1"', parsed)
+        dict_params[key] = json.loads(parsed)
+
+    # Strip dict params from args string, then parse remaining simple key=value pairs
+    remaining = dict_pattern.sub("", args)
+    remaining = re.sub(r",+", ",", remaining).strip(",")
+    model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in remaining.split(",") if k}
+
+    model_config.update(dict_params)
 
     return model_config
 
diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py
index 87332d1d7..18f95eab1 100644
--- a/src/lighteval/models/endpoints/litellm_model.py
+++ b/src/lighteval/models/endpoints/litellm_model.py
@@ -131,6 +131,11 @@ class LiteLLMModelConfig(ModelConfig):
     api_retry_multiplier: float = 2.0
     timeout: float | None = None
 
+    # Extra body parameters passed through to the API request body.
+    # Useful for provider-specific params not in the standard OpenAI API,
+    # e.g., vLLM supports top_k, min_p, repetition_penalty via extra_body.
+    extra_body: dict | None = None
+
 
 @requires("litellm")
 class LiteLLMClient(LightevalModel):
@@ -153,6 +158,7 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
         self.API_RETRY_SLEEP = config.api_retry_sleep
         self.API_RETRY_MULTIPLIER = config.api_retry_multiplier
         self.timeout = config.timeout
+        self.extra_body = config.extra_body
         self._tokenizer = encode
         self.pairwise_tokenization = False
 
@@ -165,6 +171,12 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
         # Initialize cache for tokenization and predictions
         self._cache = SampleCache(config)
 
+        # Log sampling params so the user can verify what will be sent to the server
+        sampling_params = self.generation_parameters.to_litellm_dict()
+        if self.extra_body:
+            sampling_params["extra_body"] = self.extra_body
+        logger.info(f"Sampling parameters: {sampling_params}")
+
     def _prepare_stop_sequence(self, stop_sequence):
         """Prepare and validate stop sequence."""
         if self.provider == "anthropic":
@@ -210,6 +222,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
             "n": num_samples,
             "caching": True,
             "timeout": self.timeout,
+            "extra_body": self.extra_body,
         }
 
         if "o1" in self.model: