Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions plugins/validation_tests/test_object_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_registry import SUTS

from modelgauge.suts.baseten_api import BasetenSUT
from modelgauge.suts.huggingface_chat_completion import HuggingFaceChatCompletionDedicatedSUT
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1 # see "workaround" below
from modelgauge_tests.fake_secrets import fake_all_secrets
Expand Down Expand Up @@ -101,9 +103,27 @@ def test_all_suts_construct_and_record_init(sut_name):
"nvidia-nemotron-mini-4b-instruct",
"nvidia-nemotron-4-340b-instruct",
"nvidia-llama-3.1-nemotron-70b-instruct",
"nvidia-llama-3.3-49b-nemotron-super", # too expensive
"nvidia-mistral-nemo-minitron-8b-8k-instruct",
}
TOO_EXPENSIVE_SUT_CLASSES = {
Comment thread
bkorycki marked this conversation as resolved.
BasetenSUT, # Dedicated server
HuggingFaceChatCompletionDedicatedSUT, # Dedicated server
}


def suts_to_test():
    """Return the SUT UIDs that the expensive tests should be parametrized over.

    Walks the SUTS registry in its iteration order and keeps at most one UID
    per SUT class (the first one encountered), so each class is exercised
    once rather than once per registered UID.

    A UID is skipped when:
      * it is listed in SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING, or
      * its class is, or derives from, a class in TOO_EXPENSIVE_SUT_CLASSES.

    Returns:
        A list of SUT UIDs, in registry iteration order.
    """
    # issubclass() accepts a tuple of candidate classes, so build it once
    # instead of looping over the set for every registered SUT.
    too_expensive = tuple(TOO_EXPENSIVE_SUT_CLASSES)
    suts = []
    # Classes that already have a representative UID in `suts`.
    represented_classes = set()
    for uid, sut_info in SUTS.items():
        if uid in SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING:
            continue
        cls = sut_info.cls
        if issubclass(cls, too_expensive):
            continue
        if cls not in represented_classes:
            represented_classes.add(cls)
            suts.append(uid)
    return suts


# This test can take a while, and we don't want a test run to fail
Expand All @@ -113,7 +133,7 @@ def test_all_suts_construct_and_record_init(sut_name):
# get a sense of a real user's experience.
@expensive_tests
@pytest.mark.timeout(TIMEOUT)
@pytest.mark.parametrize("sut_name", set(SUTS.keys()) - SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING)
@pytest.mark.parametrize("sut_name", suts_to_test())
def test_all_suts_can_evaluate(sut_name):
sut = SUTS.make_instance(sut_name, secrets=load_secrets_from_config())
assert isinstance(sut, PromptResponseSUT), "Update this test to handle other types."
Expand All @@ -132,7 +152,7 @@ def test_all_suts_can_evaluate(sut_name):

@expensive_tests
@pytest.mark.timeout(TIMEOUT)
@pytest.mark.parametrize("sut_name", set(SUTS.keys()) - SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING)
@pytest.mark.parametrize("sut_name", suts_to_test())
def test_can_cache_all_sut_responses(sut_name, tmpdir):
sut = SUTS.make_instance(sut_name, secrets=load_secrets_from_config())
assert isinstance(sut, PromptResponseSUT), "Update this test to handle other types."
Expand Down