Skip to content

Commit e2bcf95

Browse files
authored
Merge pull request lightspeed-core#156 from asamal4/judge-panel-config
add llm pool & judge panel config
2 parents 3402ec9 + 8aacdca commit e2bcf95

4 files changed

Lines changed: 773 additions & 242 deletions

File tree

src/lightspeed_evaluation/core/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
APIConfig,
1919
CoreConfig,
2020
EmbeddingConfig,
21+
JudgePanelConfig,
2122
LLMConfig,
23+
LLMPoolConfig,
2224
LoggingConfig,
2325
OutputConfig,
2426
SystemConfig,
@@ -35,7 +37,9 @@
3537
"EvaluationScope",
3638
# System config models
3739
"CoreConfig",
40+
"JudgePanelConfig",
3841
"LLMConfig",
42+
"LLMPoolConfig",
3943
"EmbeddingConfig",
4044
"APIConfig",
4145
"OutputConfig",

src/lightspeed_evaluation/core/models/system.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
77

8+
from lightspeed_evaluation.core.system.exceptions import ConfigurationError
89
from lightspeed_evaluation.core.constants import (
910
DEFAULT_API_BASE,
1011
DEFAULT_API_CACHE_DIR,
@@ -318,6 +319,289 @@ class CoreConfig(BaseModel):
318319
)
319320

320321

322+
class LLMParametersConfig(BaseModel):
323+
"""Dynamic parameters passed to LLM API calls.
324+
325+
These parameters are passed directly to the LLM provider.
326+
All fields are optional - unset fields inherit from parent level.
327+
Uses extra="allow" to pass through any provider-specific parameters.
328+
"""
329+
330+
model_config = ConfigDict(extra="allow")
331+
332+
temperature: Optional[float] = Field(
333+
default=None,
334+
ge=0.0,
335+
le=2.0,
336+
description="Sampling temperature",
337+
)
338+
max_completion_tokens: Optional[int] = Field(
339+
default=None,
340+
ge=1,
341+
description="Maximum tokens in response",
342+
)
343+
344+
def to_dict(self, exclude_none: bool = True) -> dict[str, Any]:
345+
"""Convert parameters to dict for passing to LLM.
346+
347+
Args:
348+
exclude_none: If True, exclude None values from output
349+
350+
Returns:
351+
Dict of parameters ready for LLM API call
352+
"""
353+
params = self.model_dump()
354+
if exclude_none:
355+
return {k: v for k, v in params.items() if v is not None}
356+
return params
357+
358+
359+
class LLMDefaultsConfig(BaseModel):
360+
"""Global default settings for all LLMs in the pool.
361+
362+
These are shared defaults that apply to all LLMs unless overridden
363+
at the provider or model level.
364+
"""
365+
366+
model_config = ConfigDict(extra="forbid")
367+
368+
cache_enabled: bool = Field(
369+
default=True,
370+
description="Is caching of LLM queries enabled?",
371+
)
372+
cache_dir: str = Field(
373+
default=DEFAULT_LLM_CACHE_DIR,
374+
min_length=1,
375+
description="Base cache directory",
376+
)
377+
378+
timeout: int = Field(
379+
default=DEFAULT_API_TIMEOUT,
380+
ge=1,
381+
description="Request timeout in seconds",
382+
)
383+
384+
num_retries: int = Field(
385+
default=DEFAULT_LLM_RETRIES,
386+
ge=0,
387+
description="Retry attempts for failed requests",
388+
)
389+
390+
# Default dynamic parameters
391+
parameters: LLMParametersConfig = Field(
392+
default_factory=lambda: LLMParametersConfig(
393+
temperature=DEFAULT_LLM_TEMPERATURE,
394+
max_completion_tokens=DEFAULT_LLM_MAX_TOKENS,
395+
),
396+
description="Default dynamic parameters for LLM calls",
397+
)
398+
399+
400+
class LLMProviderConfig(BaseModel):
401+
"""Configuration for a single LLM provider/model in the pool.
402+
403+
Contains model-specific settings. Cache and retry settings are managed
404+
at the pool defaults level, not per-model.
405+
406+
The dict key is the unique model ID used for referencing.
407+
"""
408+
409+
model_config = ConfigDict(extra="forbid")
410+
411+
# Required: Provider type
412+
provider: str = Field(
413+
min_length=1,
414+
description="Provider type (e.g., openai, watsonx, gemini, hosted_vllm)",
415+
)
416+
417+
# Model identity (optional - defaults to dict key)
418+
model: Optional[str] = Field(
419+
default=None,
420+
min_length=1,
421+
description="Actual model name. If not set, uses the dict key as model name.",
422+
)
423+
424+
# SSL settings (optional - inherit from defaults or use system defaults)
425+
ssl_verify: Optional[bool] = Field(
426+
default=None,
427+
description="Verify SSL certificates. Inherits from defaults if not set.",
428+
)
429+
ssl_cert_file: Optional[str] = Field(
430+
default=None,
431+
description="Path to custom CA certificate file",
432+
)
433+
434+
# API endpoint/key configuration (optional - falls back to environment variable)
435+
api_base: Optional[str] = Field(
436+
default=None,
437+
min_length=1,
438+
description=(
439+
"Base URL for the API endpoint. "
440+
"If not set, falls back to provider-specific environment variable."
441+
),
442+
)
443+
api_key_path: Optional[str] = Field(
444+
default=None,
445+
min_length=1,
446+
description=(
447+
"Path to text file containing the API key for this model. "
448+
"If not set, falls back to provider-specific environment variable."
449+
),
450+
)
451+
452+
# Dynamic parameters (passed to LLM API)
453+
parameters: LLMParametersConfig = Field(
454+
default_factory=LLMParametersConfig,
455+
description="Dynamic parameters for this model (merged with defaults)",
456+
)
457+
458+
# Timeout can be model-specific (some models are slower)
459+
timeout: Optional[int] = Field(
460+
default=None,
461+
ge=1,
462+
description="Override timeout for this model",
463+
)
464+
465+
466+
class LLMPoolConfig(BaseModel):
467+
"""Pool of LLM configurations for reuse across the system.
468+
469+
Provides a centralized place to define all LLM configurations,
470+
which can be referenced by judge_panel, agents, or other components.
471+
472+
Cache and retry settings are managed at the defaults level only.
473+
Model entries contain model-specific settings (provider, parameters, SSL).
474+
"""
475+
476+
model_config = ConfigDict(extra="forbid")
477+
478+
defaults: LLMDefaultsConfig = Field(
479+
default_factory=LLMDefaultsConfig,
480+
description="Global default settings for all LLMs (cache, retry, parameters)",
481+
)
482+
models: dict[str, LLMProviderConfig] = Field(
483+
default_factory=dict,
484+
description="Model configurations. Key is unique model ID for referencing.",
485+
)
486+
487+
def get_model_ids(self) -> list[str]:
488+
"""Get all available model IDs."""
489+
return list(self.models.keys())
490+
491+
def resolve_llm_config(
492+
self, model_id: str, cache_suffix: Optional[str] = None
493+
) -> LLMConfig:
494+
"""Resolve a model ID to a fully configured LLMConfig.
495+
496+
Resolution order: defaults -> model entry (for model-specific fields)
497+
498+
Args:
499+
model_id: Model identifier (key in models dict)
500+
cache_suffix: Optional suffix for cache directory (e.g., "judge_0")
501+
502+
Returns:
503+
Fully resolved LLMConfig
504+
505+
Raises:
506+
ValueError: If model_id not found
507+
"""
508+
if model_id not in self.models:
509+
raise ValueError(
510+
f"Model '{model_id}' not found in llm_pool.models. "
511+
f"Available: {list(self.models.keys())}"
512+
)
513+
entry = self.models[model_id]
514+
515+
# Merge parameters: defaults -> model entry
516+
merged_params: dict[str, Any] = {}
517+
merged_params.update(self.defaults.parameters.to_dict(exclude_none=True))
518+
merged_params.update(entry.parameters.to_dict(exclude_none=True))
519+
520+
# Build cache_dir from defaults with model-specific suffix
521+
suffix = cache_suffix if cache_suffix else model_id
522+
cache_dir = os.path.join(self.defaults.cache_dir, suffix)
523+
524+
return LLMConfig(
525+
provider=entry.provider,
526+
model=entry.model or model_id,
527+
temperature=merged_params.get("temperature", DEFAULT_LLM_TEMPERATURE),
528+
max_tokens=merged_params.get(
529+
"max_completion_tokens", DEFAULT_LLM_MAX_TOKENS
530+
),
531+
timeout=(
532+
entry.timeout if entry.timeout is not None else self.defaults.timeout
533+
),
534+
num_retries=self.defaults.num_retries,
535+
ssl_verify=(
536+
entry.ssl_verify if entry.ssl_verify is not None else DEFAULT_SSL_VERIFY
537+
),
538+
ssl_cert_file=entry.ssl_cert_file,
539+
cache_enabled=self.defaults.cache_enabled,
540+
cache_dir=cache_dir,
541+
# Note: api_base and api_key_path are not propagated yet - requires LLMConfig extension
542+
)
543+
544+
545+
class JudgePanelConfig(BaseModel):
546+
"""Judge panel configuration for multi-LLM evaluation.
547+
548+
References models from LLM pool by model ID (the key in llm_pool.models).
549+
Each judge ID must correspond to a key in the llm_pool.models dictionary.
550+
"""
551+
552+
model_config = ConfigDict(extra="forbid")
553+
554+
judges: list[str] = Field(
555+
...,
556+
min_length=1,
557+
description="List of model IDs (keys from llm_pool.models). At least one required.",
558+
)
559+
enabled_metrics: Optional[list[str]] = Field(
560+
default=None,
561+
description=(
562+
"Metrics that should use the judge panel. "
563+
"If None, all metrics use the panel. "
564+
"If empty list, no metrics use the panel."
565+
),
566+
)
567+
aggregation_strategy: str = Field(
568+
default="average",
569+
description=(
570+
"Strategy for aggregating scores from multiple judges. "
571+
"Options: 'max', 'average', 'majority_vote'. "
572+
"Note: Currently unused - will be implemented later."
573+
),
574+
)
575+
576+
@field_validator("enabled_metrics")
577+
@classmethod
578+
def validate_enabled_metrics(cls, v: Optional[list[str]]) -> Optional[list[str]]:
579+
"""Validate enabled_metrics format (framework:metric_name)."""
580+
if v is not None:
581+
for metric in v:
582+
if not metric or ":" not in metric:
583+
raise ValueError(
584+
f'Metric "{metric}" must be in format "framework:metric_name"'
585+
)
586+
parts = metric.split(":", 1)
587+
if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():
588+
raise ValueError(
589+
f'Metric "{metric}" must be in format "framework:metric_name"'
590+
)
591+
return v
592+
593+
@field_validator("aggregation_strategy")
594+
@classmethod
595+
def validate_aggregation_strategy(cls, v: str) -> str:
596+
"""Validate aggregation_strategy is a supported value."""
597+
allowed = ["max", "average", "majority_vote"]
598+
if v not in allowed:
599+
raise ValueError(
600+
f"Unsupported aggregation_strategy '{v}'. Allowed: {allowed}"
601+
)
602+
return v
603+
604+
321605
class SystemConfig(BaseModel):
322606
"""System configuration using individual config models."""
323607

@@ -328,6 +612,25 @@ class SystemConfig(BaseModel):
328612
default_factory=CoreConfig, description="Core eval configuration"
329613
)
330614
llm: LLMConfig = Field(default_factory=LLMConfig, description="LLM configuration")
615+
616+
# LLM Pool - shared pool of LLM configurations
617+
llm_pool: Optional[LLMPoolConfig] = Field(
618+
default=None,
619+
description=(
620+
"Pool of LLM configurations. Define models once, "
621+
"reference by ID in judge_panel or other components."
622+
),
623+
)
624+
625+
# Judge Panel - references models from llm_pool
626+
judge_panel: Optional[JudgePanelConfig] = Field(
627+
default=None,
628+
description=(
629+
"Optional judge panel configuration. "
630+
"References models from 'llm_pool' by ID. "
631+
"If not provided, the single 'llm' configuration is used."
632+
),
633+
)
331634
embedding: EmbeddingConfig = Field(
332635
default_factory=EmbeddingConfig, description="Embeddings configuration"
333636
)
@@ -349,3 +652,50 @@ class SystemConfig(BaseModel):
349652
default_conversation_metrics_metadata: dict[str, dict[str, Any]] = Field(
350653
default_factory=dict, description="Default conversation metrics metadata"
351654
)
655+
656+
def get_judge_configs(self) -> list[LLMConfig]:
657+
"""Get resolved LLMConfig for all judges.
658+
659+
Returns:
660+
List of LLMConfig objects for each judge.
661+
If judge_panel is configured, resolves from llm_pool.
662+
Otherwise, returns single llm config.
663+
"""
664+
if not self.judge_panel:
665+
return [self.llm]
666+
667+
if not self.llm_pool:
668+
raise ConfigurationError(
669+
"judge_panel is configured but 'llm_pool' is not defined. "
670+
"Please define the llm_pool section with models."
671+
)
672+
673+
configs = []
674+
for idx, judge_id in enumerate(self.judge_panel.judges):
675+
cache_suffix = f"judge_{idx}"
676+
config = self.llm_pool.resolve_llm_config(
677+
judge_id, cache_suffix=cache_suffix
678+
)
679+
configs.append(config)
680+
return configs
681+
682+
def get_llm_config(
683+
self, model_id: str, cache_suffix: Optional[str] = None
684+
) -> LLMConfig:
685+
"""Get resolved LLMConfig for a specific model from the pool.
686+
687+
Args:
688+
model_id: Model identifier (key in llm_pool.models)
689+
cache_suffix: Optional suffix for cache directory
690+
691+
Returns:
692+
Fully resolved LLMConfig
693+
694+
Raises:
695+
ConfigurationError: If llm_pool not configured or model not found
696+
"""
697+
if not self.llm_pool:
698+
raise ConfigurationError(
699+
f"Cannot resolve model '{model_id}' - 'llm_pool' is not configured."
700+
)
701+
return self.llm_pool.resolve_llm_config(model_id, cache_suffix=cache_suffix)

0 commit comments

Comments
 (0)