Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/lighteval/models/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,24 @@ def _parse_args(args: str) -> dict:
'model': {'model_name': 'gpt2', 'use_cache': True, 'generation_parameters': {'temperature': 0.7}},
}
"""
# Looking for generation_parameters in the model_args
generation_parameters_dict = None
pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
matches = pattern.findall(args)
for key, value in matches:
key = key.strip()
if key == "generation_parameters":
# Keys must be quoted (since they are strings)
gen_params = re.sub(r"(\w+):", r'"\1":', value)
# for k, v where v are strings, we quote them too
gen_params = re.sub(r":\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r':"\1"', gen_params)
generation_parameters_dict = json.loads(gen_params)

args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")}

if generation_parameters_dict is not None:
model_config["generation_parameters"] = generation_parameters_dict
# Extract all key={...} dict parameters (e.g. generation_parameters, extra_body)
dict_params = {}
dict_pattern = re.compile(r"(\w+)=(\{[^}]*\})")
for match in dict_pattern.finditer(args):
key = match.group(1)
value = match.group(2)
# Keys must be quoted (since they are strings)
parsed = re.sub(r"(\w+):", r'"\1":', value)
# for k, v where v are strings, we quote them too
parsed = re.sub(r":\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r':"\1"', parsed)
dict_params[key] = json.loads(parsed)

# Strip dict params from args string, then parse remaining simple key=value pairs
remaining = dict_pattern.sub("", args)
remaining = re.sub(r",+", ",", remaining).strip(",")
model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in remaining.split(",") if k}

model_config.update(dict_params)

return model_config

Expand Down
13 changes: 13 additions & 0 deletions src/lighteval/models/endpoints/litellm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ class LiteLLMModelConfig(ModelConfig):
api_retry_multiplier: float = 2.0
timeout: float | None = None

# Extra body parameters passed through to the API request body.
# Useful for provider-specific params not in the standard OpenAI API,
# e.g., vLLM supports top_k, min_p, repetition_penalty via extra_body.
extra_body: dict | None = None


@requires("litellm")
class LiteLLMClient(LightevalModel):
Expand All @@ -153,6 +158,7 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
self.API_RETRY_SLEEP = config.api_retry_sleep
self.API_RETRY_MULTIPLIER = config.api_retry_multiplier
self.timeout = config.timeout
self.extra_body = config.extra_body

self._tokenizer = encode
self.pairwise_tokenization = False
Expand All @@ -165,6 +171,12 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
# Initialize cache for tokenization and predictions
self._cache = SampleCache(config)

# Log sampling params so the user can verify what will be sent to the server
sampling_params = self.generation_parameters.to_litellm_dict()
if self.extra_body:
sampling_params["extra_body"] = self.extra_body
logger.info(f"Sampling parameters: {sampling_params}")

def _prepare_stop_sequence(self, stop_sequence):
"""Prepare and validate stop sequence."""
if self.provider == "anthropic":
Expand Down Expand Up @@ -210,6 +222,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
"n": num_samples,
"caching": True,
"timeout": self.timeout,
"extra_body": self.extra_body,
}

if "o1" in self.model:
Expand Down