def suts_to_test():
    """Return one SUT uid per distinct SUT class, for parametrizing tests.

    Walks the ``SUTS`` registry in its iteration order and skips:

    * uids listed in ``SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING``;
    * SUTs whose class is, or derives from, any class in
      ``TOO_EXPENSIVE_SUT_CLASSES``;
    * SUTs whose exact class has already contributed a uid.

    Returns:
        list: uids in registry order; only the first uid seen for each
        class is kept.
    """
    # issubclass() accepts a tuple of candidate classes, so build it once
    # instead of materializing a list of booleans for every registry entry.
    too_expensive = tuple(TOO_EXPENSIVE_SUT_CLASSES)
    # A set gives O(1) "already covered?" checks; ordering is carried by `suts`.
    represented_classes = set()
    suts = []
    for uid, sut_info in SUTS.items():
        if uid in SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING:
            continue
        cls = sut_info.cls
        if issubclass(cls, too_expensive):
            continue
        # Exact-class match: a subclass of a covered class is still treated
        # as a new class and gets its own representative uid.
        if cls in represented_classes:
            continue
        represented_classes.add(cls)
        suts.append(uid)
    return suts
@expensive_tests @pytest.mark.timeout(TIMEOUT) -@pytest.mark.parametrize("sut_name", set(SUTS.keys()) - SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING) +@pytest.mark.parametrize("sut_name", suts_to_test()) def test_all_suts_can_evaluate(sut_name): sut = SUTS.make_instance(sut_name, secrets=load_secrets_from_config()) assert isinstance(sut, PromptResponseSUT), "Update this test to handle other types." @@ -132,7 +152,7 @@ def test_all_suts_can_evaluate(sut_name): @expensive_tests @pytest.mark.timeout(TIMEOUT) -@pytest.mark.parametrize("sut_name", set(SUTS.keys()) - SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING) +@pytest.mark.parametrize("sut_name", suts_to_test()) def test_can_cache_all_sut_responses(sut_name, tmpdir): sut = SUTS.make_instance(sut_name, secrets=load_secrets_from_config()) assert isinstance(sut, PromptResponseSUT), "Update this test to handle other types."