setup.py (4 changes: 2 additions & 2 deletions)
@@ -181,8 +181,8 @@
"jsonschema",
"ruamel.yaml",
"pyyaml",
"litellm>=1.75.5, <=1.82.6",
# For LiteLLM tests. Upper bound pinned: versions 1.82.7+ compromised in supply chain attack.
"litellm>=1.83.0, <2",
# For LiteLLM tests. Versions >=1.82.7,<1.83.0 compromised in supply chain attack.
]

langchain_extra_require = [
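Note on the new pin: ">=1.83.0, <2" simultaneously skips the compromised releases and guards against a future 2.x API break. For codebases that cannot bump the dependency through packaging alone, a runtime guard can assert the installed version sits outside the affected range. A minimal sketch, assuming the packaging library is available (the specifier mirrors the comment above, not an official advisory feed):

# Hypothetical runtime guard: fail fast if the installed litellm falls in
# the compromised range called out in setup.py (>=1.82.7, <1.83.0).
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

COMPROMISED = SpecifierSet(">=1.82.7,<1.83.0")

installed = version("litellm")
if installed in COMPROMISED:
    raise RuntimeError(
        f"litellm {installed} is in the compromised range; upgrade to >=1.83.0, <2."
    )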
tests/unit/vertexai/genai/test_evals.py (26 changes: 11 additions & 15 deletions)
@@ -3624,8 +3624,7 @@ def test_run_inference_with_litellm_string_prompt_format(
         ) as mock_litellm, mock.patch(
             "vertexai._genai._evals_common._call_litellm_completion"
         ) as mock_call_litellm_completion:
-            # fmt: on
-            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
             prompt_df = pd.DataFrame([{"prompt": "What is LiteLLM?"}])
             expected_messages = [{"role": "user", "content": "What is LiteLLM?"}]
@@ -3676,17 +3675,12 @@ def test_run_inference_with_litellm_openai_request_format(
         mock_api_client_fixture,
     ):
         """Tests inference with LiteLLM where the row contains a chat completion request body."""
-        # fmt: off
-        with (
-            mock.patch(
-                "vertexai._genai._evals_common.litellm"
-            ) as mock_litellm,
-            mock.patch(
-                "vertexai._genai._evals_common._call_litellm_completion"
-            ) as mock_call_litellm_completion,
-        ):
-            # fmt: on
-            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+        with mock.patch(
+            "vertexai._genai._evals_common.litellm"
+        ) as mock_litellm, mock.patch(
+            "vertexai._genai._evals_common._call_litellm_completion"
+        ) as mock_call_litellm_completion:
+            mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
             prompt_df = pd.DataFrame(
                 [
                     {
@@ -3755,7 +3749,9 @@ def test_run_inference_with_unsupported_model_string(
         with mock.patch(
             "vertexai._genai._evals_common.litellm"
         ) as mock_litellm_package:
-            mock_litellm_package.utils.get_valid_models.return_value = []
+            mock_litellm_package.get_llm_provider.side_effect = ValueError(
+                "unsupported model"
+            )
             evals_module = evals.Evals(api_client_=mock_api_client_fixture)
             prompt_df = pd.DataFrame([{"prompt": "test"}])
@@ -3822,7 +3818,7 @@ def test_run_inference_with_litellm_parsing(
         # fmt: off
         with mock.patch("vertexai._genai._evals_common.litellm") as mock_litellm:
             # fmt: on
-            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
             inference_result = self.client.evals.run_inference(
                 model="gpt-4o",
                 src=mock_df,
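All four test updates follow one pattern: because _evals_common references litellm as a module-level attribute, patching "vertexai._genai._evals_common.litellm" swaps the whole module for a MagicMock, and the stubbed get_llm_provider either returns a provider tuple (supported model) or raises (unsupported). A condensed, self-contained sketch of that pattern against a toy stand-in module (names below are illustrative, not part of the SDK):

import types
from unittest import mock

# Toy stand-in for vertexai._genai._evals_common.
fake_evals_common = types.SimpleNamespace(litellm=None)

def _is_litellm_model(model: str) -> bool:
    # Mirrors the helper in _evals_common.py below.
    try:
        fake_evals_common.litellm.get_llm_provider(model)
        return True
    except ValueError:
        return False

with mock.patch.object(fake_evals_common, "litellm") as mock_litellm:
    # Supported model: stub litellm's 4-tuple (model, provider, api_key, api_base).
    mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
    assert _is_litellm_model("gpt-4o")

    # Unsupported model: raising makes the helper report False.
    mock_litellm.get_llm_provider.side_effect = ValueError("unsupported model")
    assert not _is_litellm_model("not-a-model")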
vertexai/_genai/_evals_common.py (9 changes: 8 additions & 1 deletion)
@@ -735,7 +735,14 @@ def _is_litellm_vertex_maas_model(model: str) -> bool:

 def _is_litellm_model(model: str) -> bool:
     """Checks if the model name corresponds to a valid LiteLLM model name."""
-    return model in litellm.utils.get_valid_models(model)
+    if litellm is None:
+        return False
+
+    try:
+        litellm.get_llm_provider(model)
+        return True
+    except ValueError:
+        return False


 def _is_gemini_model(model: str) -> bool:
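The rewritten _is_litellm_model delegates model detection to litellm's own routing: get_llm_provider resolves a model string to (model, custom_llm_provider, dynamic_api_key, api_base) and raises when it cannot. A rough usage sketch of that behavior (note that recent litellm versions may raise their own BadRequestError subclass rather than a bare ValueError for unknown models, so the except clause above is worth double-checking against the pinned release):

import litellm

# Known model: resolves to its provider.
_, provider, _, _ = litellm.get_llm_provider("gpt-4o")
print(provider)  # "openai"

# Unknown model: get_llm_provider raises, so _is_litellm_model returns False.
try:
    litellm.get_llm_provider("definitely-not-a-model")
except Exception as exc:  # may be BadRequestError rather than ValueError
    print(type(exc).__name__)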