diff --git a/src/modelgauge/cli.py b/src/modelgauge/cli.py index ac2d98e7..bd4d2eb3 100644 --- a/src/modelgauge/cli.py +++ b/src/modelgauge/cli.py @@ -132,7 +132,12 @@ def list_secrets() -> None: @cli.command() @LOCAL_PLUGIN_DIR_OPTION -@click.option("--sut", "-s", help="Which SUT to run.", required=True) +@click.option( + "--sut", + "-s", + help="Which SUT to run. Please quote the value to ensure that dynamic parameterizations are included.", + required=True, +) @sut_options_options @click.option("--prompt", help="The full text to send to the SUT.", required=True) def run_sut( diff --git a/src/modelgauge/dynamic_sut_factory.py b/src/modelgauge/dynamic_sut_factory.py index 4511385c..7c300e4b 100644 --- a/src/modelgauge/dynamic_sut_factory.py +++ b/src/modelgauge/dynamic_sut_factory.py @@ -41,4 +41,5 @@ def get_secrets(self) -> list[InjectSecret]: @abstractmethod def make_sut(self, sut_definition: SUTDefinition) -> SUT: + """Factories that handle special SUT config parameters (e.g. moderated, reasoning) must accept them as kwargs.""" pass diff --git a/src/modelgauge/sut_definition.py b/src/modelgauge/sut_definition.py index d40effc6..e646a246 100644 --- a/src/modelgauge/sut_definition.py +++ b/src/modelgauge/sut_definition.py @@ -23,7 +23,19 @@ def name_for_label(self, label): raise (ValueError(f"for static elements, {label} must match {self.label}")) -class PrefixSUTSpecificationElement(SUTSpecificationElement): +class SUTMetadataElement(SUTSpecificationElement): + """Core SUT information.""" + + pass + + +class SUTConfigElement(SUTSpecificationElement): + """Optional configuration data passed to a SUT on initialization.""" + + required = False + + +class PrefixSUTSpecificationElement(SUTConfigElement): def matches(self, field_name): return field_name.startswith(self.name) @@ -37,15 +49,14 @@ class SUTSpecification: def __init__(self): fields = [ - SUTSpecificationElement("model", "m", str, True), - SUTSpecificationElement("driver", "d", str, True), - SUTSpecificationElement("maker", "mk", str), - SUTSpecificationElement("provider", "pr", str), - SUTSpecificationElement("display_name", "dn", str), - SUTSpecificationElement("reasoning", "reas", bool), - SUTSpecificationElement("moderated", "mod", bool), - SUTSpecificationElement("date", "dt", str), - SUTSpecificationElement("base_url", "url", str), + SUTMetadataElement("model", "m", str, True), + SUTMetadataElement("driver", "d", str, True), + SUTMetadataElement("maker", "mk", str), + SUTMetadataElement("provider", "pr", str), + SUTMetadataElement("date", "dt", str), + SUTConfigElement("reasoning", "reas", bool), + SUTConfigElement("moderated", "mod", bool), + SUTConfigElement("base_url", "url", str), ] self._wildcard_fields = [PrefixSUTSpecificationElement("vllm-", "vllm", str)] @@ -59,9 +70,12 @@ def knows(self, name: str): def requires(self, name: str): return self.knows(name) and self._fields_by_name[name].required - def validate(self, data: dict) -> bool: + def validate(self, metadata: dict, config_data: dict) -> bool: for field_spec in self._fields_by_name.values(): - value = data.get(field_spec.name, None) + if isinstance(field_spec, SUTMetadataElement): + value = metadata.get(field_spec.name, None) + else: + value = config_data.get(field_spec.name, None) if field_spec.required and value is None: raise ValueError(f"Field {field_spec.name} is required.") if value is not None and not isinstance(value, field_spec.value_type): @@ -79,6 +93,14 @@ def element_for_label(self, label: str): return element return None + def element_for_name(self, name: str): + if name in self._fields_by_name: + return self._fields_by_name[name] + for element in self._wildcard_fields: + if element.matches(name): + return element + return None + DEFINITION_VALUE_TYPES = Union[str, int, float, bool, None] @@ -86,19 +108,19 @@ def element_for_label(self, label: str): class SUTDefinition: """The data in a SUT configuration file or JSON blob""" - _data: dict[str, DEFINITION_VALUE_TYPES] + _metadata: dict[str, DEFINITION_VALUE_TYPES] def __init__(self, data=None, **kwargs): self.spec = SUTSpecification() - self._data = {} + self._metadata = {} # Core SUT information. + self.config_data = {} # Everything that comes after ";" if data: for k, v in data.items(): self._add(k, v) for k, v in kwargs.items(): self._add(k, v) - if not self.spec.validate(self._data): - raise ValueError(f"Invalid data: {self._data}") + self.spec.validate(self._metadata, self.config_data) generator = SUTUIDGenerator(self) self.uid = generator.uid @@ -110,31 +132,37 @@ def __str__(self): def _add(self, key: str, value: DEFINITION_VALUE_TYPES): if isinstance(value, str): value = value.strip() - if self.spec.knows(key): - self._data[key] = value - else: + if not self.spec.knows(key): raise ValueError(f"Don't know what to do with {key}") + spec_element = self.spec.element_for_name(key) + if isinstance(spec_element, SUTMetadataElement): + self._metadata[key] = value + elif isinstance(spec_element, SUTConfigElement): + self.config_data[key] = value + else: + raise ValueError(f"Unknown spec element type {spec_element} for {key}") def get(self, field: str, default=None) -> DEFINITION_VALUE_TYPES: - return self._data.get(field, default) + return self._metadata.get(field, self.config_data.get(field, default)) def get_matching(self, label: str) -> Mapping[str, DEFINITION_VALUE_TYPES] | None: element = self.spec.element_for_label(label) if not element: return None result = {} - for k, v in self._data.items(): - if element.matches(k): - result[k] = v + for items in (self._metadata.items(), self.config_data.items()): + for k, v in items: + if element.matches(k): + result[k] = v return result def to_dynamic_sut_metadata(self) -> DynamicSUTMetadata: return DynamicSUTMetadata( - model=self._data["model"], # type: ignore - driver=self._data["driver"], # type: ignore - maker=self._data.get("maker", None), # type: ignore - provider=self._data.get("provider", None), # type: ignore - date=self._data.get("date", None), # type: ignore + model=self._metadata["model"], # type: ignore + driver=self._metadata["driver"], # type: ignore + maker=self._metadata.get("maker", None), # type: ignore + provider=self._metadata.get("provider", None), # type: ignore + date=self._metadata.get("date", None), # type: ignore ) def external_model_name(self) -> str: @@ -231,7 +259,6 @@ class SUTUIDGenerator: order = ( "moderated", "reasoning", - "display_name", "base_url", # for OpenAI-compatible SUTs ) field_separator = RICH_UID_FIELD_SEPARATOR diff --git a/src/modelgauge/sut_factory.py b/src/modelgauge/sut_factory.py index 2581cab0..567ffa09 100644 --- a/src/modelgauge/sut_factory.py +++ b/src/modelgauge/sut_factory.py @@ -20,6 +20,10 @@ class SUTNotFoundException(Exception): pass +class IncompatibleSUTParamsError(Exception): + pass + + class SUTType(Enum): DYNAMIC = "dynamic" KNOWN = "known" @@ -183,7 +187,13 @@ def _make_dynamic_sut(self, uid: str) -> SUT: factory = self.dynamic_sut_factories.get(sut_definition.get("driver")) # type: ignore if not factory: raise UnknownSUTMakerError(f'Don\'t know how to make dynamic sut "{uid}"') - return factory.make_sut(sut_definition) + try: + sut = factory.make_sut(sut_definition, **sut_definition.config_data) + except TypeError: + raise IncompatibleSUTParamsError( + f"The {factory.__class__.__name__} factory cannot handle some dynamic SUT parameters specified in the uid: {sut_definition.config_data}." + ) + return sut def keys(self) -> list[str]: """Mimic the registry interface.""" diff --git a/tests/modelgauge_tests/test_dynamic_sut_factory.py b/tests/modelgauge_tests/test_dynamic_sut_factory.py index 09fb9ab8..3a1bf1bc 100644 --- a/tests/modelgauge_tests/test_dynamic_sut_factory.py +++ b/tests/modelgauge_tests/test_dynamic_sut_factory.py @@ -15,6 +15,11 @@ def make_sut(self, sut_definition: SUTDefinition): return FakeSUT(sut_definition.dynamic_uid) +class FakeDynamicFactoryHandlesMod(FakeDynamicFactory): + def make_sut(self, sut_definition: SUTDefinition, moderated: bool = False): + return FakeSUT(sut_definition.dynamic_uid) + + def test_injected_secrets(): factory = FakeDynamicFactory( {"some-scope": {"some-key": "some-value"}, "optional-scope": {"optional-key": "optional-value"}} diff --git a/tests/modelgauge_tests/test_sut_definition.py b/tests/modelgauge_tests/test_sut_definition.py index e2c1d6e1..3930a612 100644 --- a/tests/modelgauge_tests/test_sut_definition.py +++ b/tests/modelgauge_tests/test_sut_definition.py @@ -32,13 +32,16 @@ def test_to_dynamic_sut_metadata(): def test_parse_rich_sut_uid(): - uid = "google/gemma-3-27b-it:nebius:hfrelay;url=https://example.com/" + uid = "google/gemma-3-27b-it:nebius:hfrelay;reas=y;url=https://example.com/" definition = SUTDefinition.parse(uid) assert definition.get("model") == "gemma-3-27b-it" assert definition.get("maker") == "google" assert definition.get("driver") == "hfrelay" assert definition.get("provider") == "nebius" assert definition.get("base_url") == "https://example.com/" + assert definition.get("reasoning") is True + + assert definition.uid == uid def test_vllm_parameters(): diff --git a/tests/modelgauge_tests/test_sut_factory.py b/tests/modelgauge_tests/test_sut_factory.py index 6bba9359..e0a6f541 100644 --- a/tests/modelgauge_tests/test_sut_factory.py +++ b/tests/modelgauge_tests/test_sut_factory.py @@ -4,9 +4,9 @@ from modelgauge.dynamic_sut_factory import UnknownSUTMakerError from modelgauge.instance_factory import InstanceFactory from modelgauge.sut import SUT -from modelgauge.sut_factory import SUTFactory, SUTNotFoundException, SUTType +from modelgauge.sut_factory import IncompatibleSUTParamsError, SUTFactory, SUTNotFoundException, SUTType from modelgauge_tests.fake_sut import FakeSUT -from modelgauge_tests.test_dynamic_sut_factory import FakeDynamicFactory +from modelgauge_tests.test_dynamic_sut_factory import FakeDynamicFactory, FakeDynamicFactoryHandlesMod KNOWN_UID = "known" UNKNOWN_UID = "pleasedontregisterasutwiththisuid" @@ -26,7 +26,11 @@ def sut_factory(): def sut_factory_dynamic(): """SUT factory that patches the dynamic SUT factories.""" registry = InstanceFactory[SUT]() - dynamic_factories = {"driver1": FakeDynamicFactory({}), "driver2": FakeDynamicFactory({})} + dynamic_factories = { + "driver1": FakeDynamicFactory({}), + "driver2": FakeDynamicFactory({}), + "mod_driver": FakeDynamicFactoryHandlesMod({}), + } with patch( "modelgauge.sut_factory.SUTFactory._load_dynamic_sut_factories", return_value=dynamic_factories, @@ -67,6 +71,16 @@ def test_make_instance_dynamic_unknown_driver(sut_factory_dynamic): sut_factory_dynamic.make_instance("google/gemma:unknown", secrets={}) +def test_make_instance_dynamic_with_params(sut_factory_dynamic): + sut = sut_factory_dynamic.make_instance("google/gemma:mod_driver;mod=y", secrets={}) + assert isinstance(sut, FakeSUT) + + +def test_make_instance_dynamic_incompatible_params(sut_factory_dynamic): + with pytest.raises(IncompatibleSUTParamsError): + sut_factory_dynamic.make_instance("google/gemma:driver1;mod=y", secrets={}) + + def test_make_instance_unknown_type(sut_factory): with pytest.raises(SUTNotFoundException): sut_factory.make_instance(UNKNOWN_UID, secrets={})