Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 16 additions & 68 deletions tests/test_unit_test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,73 +784,20 @@ def test_unittest_discovery_with_pytest_class_fixture():

import hashlib
import json
from typing import Any, Dict, List, Optional, Required, TypedDict, Union # noqa: UP035


class LiteLLMParamsTypedDict(TypedDict, total=False):
model: str
custom_llm_provider: Optional[str]
tpm: Optional[int]
rpm: Optional[int]
order: Optional[int]
weight: Optional[int]
max_parallel_requests: Optional[int]
api_key: Optional[str]
api_base: Optional[str]
api_version: Optional[str]
stream_timeout: Optional[Union[float, str]]
max_retries: Optional[int]
organization: Optional[Union[list, str]] # for openai orgs
## DROP PARAMS ##
drop_params: Optional[bool]
## UNIFIED PROJECT/REGION ##
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if all these are part of the original code and important to it, we should keep it, just skip this test for < 3.10

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not super important actually, doesn't impact the actual test

region_name: Optional[str]
## VERTEX AI ##
vertex_project: Optional[str]
vertex_location: Optional[str]
## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]
aws_region_name: Optional[str]
## IBM WATSONX ##
watsonx_region_name: Optional[str]
## CUSTOM PRICING ##
input_cost_per_token: Optional[float]
output_cost_per_token: Optional[float]
input_cost_per_second: Optional[float]
output_cost_per_second: Optional[float]
num_retries: Optional[int]
## MOCK RESPONSES ##

# routing params
# use this for tag-based routing
tags: Optional[list[str]]

# deployment budgets
max_budget: Optional[float]
budget_duration: Optional[str]

class DeploymentTypedDict(TypedDict, total=False):
model_name: Required[str]
litellm_params: Required[LiteLLMParamsTypedDict]
model_info: dict

class Router:
model_names: set = set() # noqa: RUF012
cache_responses: Optional[bool] = False
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
model_names: list
cache_responses = False
tenacity = None

def __init__( # noqa: PLR0915
self,
model_list: Optional[
Union[list[DeploymentTypedDict], list[dict[str, Any]]]
] = None,
model_list = None,
) -> None:
self.model_list = model_list # noqa: ARG002
self.model_id_to_deployment_index_map: dict[str, int] = {}
self.model_name_to_deployment_indices: dict[str, list[int]] = {}
def _generate_model_id(self, model_group: str, litellm_params: dict): # noqa: ANN202
self.model_list = model_list
self.model_id_to_deployment_index_map = {}
self.model_name_to_deployment_indices = {}
def _generate_model_id(self, model_group, litellm_params):
# Optimized: Use list and join instead of string concatenation in loop
# This avoids creating many temporary string objects (O(n) vs O(n²) complexity)
parts = [model_group]
Expand All @@ -874,7 +821,7 @@ def _generate_model_id(self, model_group: str, litellm_params: dict): # noqa: A

return hash_object.hexdigest()
def _add_model_to_list_and_index_map(
self, model: dict, model_id: Optional[str] = None
self, model, model_id = None
) -> None:
idx = len(self.model_list)
self.model_list.append(model)
Expand All @@ -892,7 +839,7 @@ def _add_model_to_list_and_index_map(
self.model_name_to_deployment_indices[model_name] = []
self.model_name_to_deployment_indices[model_name].append(idx)

def _build_model_id_to_deployment_index_map(self, model_list: list) -> None:
def _build_model_id_to_deployment_index_map(self, model_list):
# First populate the model_list
self.model_list = []
for _, model in enumerate(model_list):
Expand All @@ -911,6 +858,7 @@ def _build_model_id_to_deployment_index_map(self, model_list: list) -> None:
model["model_info"]["id"] = model_id

self._add_model_to_list_and_index_map(model=model, model_id=model_id)

"""
code_file_path.write_text(code_file_content)

Expand All @@ -924,9 +872,9 @@ def _build_model_id_to_deployment_index_map(self, model_list: list) -> None:

class TestRouterIndexManagement:
@pytest.fixture
def router(self): # noqa: ANN201
def router(self):
return Router(model_list=[])
def test_build_model_id_to_deployment_index_map(self, router) -> None: # noqa: ANN001
def test_build_model_id_to_deployment_index_map(self, router):
model_list = [
{
"model_name": "gpt-3.5-turbo",
Expand All @@ -941,7 +889,7 @@ def test_build_model_id_to_deployment_index_map(self, router) -> None: # noqa:
]

# Test: Build index from model list
router._build_model_id_to_deployment_index_map(model_list) # noqa: SLF001
router._build_model_id_to_deployment_index_map(model_list)

# Verify: model_list is populated
assert len(router.model_list) == 2
Expand Down Expand Up @@ -1630,9 +1578,9 @@ def test_analyze_imports_class_fixture():

class TestRouterIndexManagement:
@pytest.fixture
def router(self): # noqa: ANN201
def router(self):
return Router(model_list=[])
def test_build_model_id_to_deployment_index_map(self, router) -> None: # noqa: ANN001
def test_build_model_id_to_deployment_index_map(self, router):
model_list = [
{
"model_name": "gpt-3.5-turbo",
Expand All @@ -1647,7 +1595,7 @@ def test_build_model_id_to_deployment_index_map(self, router) -> None: # noqa:
]

# Test: Build index from model list
router._build_model_id_to_deployment_index_map(model_list) # noqa: SLF001
router._build_model_id_to_deployment_index_map(model_list)

# Verify: model_list is populated
assert len(router.model_list) == 2
Expand Down