|
57 | 57 | from models.api.requests import QueryRequest |
58 | 58 | from models.common.responses.types import InputTool, InputToolMCP |
59 | 59 | from models.config import ApprovalFilter, ByokRag, ModelContextProtocolServer |
| 60 | +from utils.query import normalize_vertex_ai_model_id |
60 | 61 | from utils.responses import ( |
61 | 62 | _build_chunk_attributes, |
62 | 63 | _merge_tools, |
@@ -3577,3 +3578,32 @@ async def test_merge_header_no_server_tools_returns_client_only( |
3577 | 3578 | ) |
3578 | 3579 | assert tools is not None |
3579 | 3580 | assert len(tools) == 1 |
| 3581 | + |
| 3582 | + |
| 3583 | +class TestNormalizeVertexAIModelId: |
| 3584 | + """Tests for normalize_vertex_ai_model_id function.""" |
| 3585 | + |
| 3586 | + def test_normalizes_vertex_ai_model_id(self) -> None: |
| 3587 | + """Test that Vertex AI model IDs are normalized correctly.""" |
| 3588 | + input_model = "publishers/google/models/gemini-2.5-flash" |
| 3589 | + expected = "google/gemini-2.5-flash" |
| 3590 | + assert normalize_vertex_ai_model_id(input_model) == expected |
| 3591 | + |
| 3592 | + def test_normalizes_vertex_ai_model_id_with_version(self) -> None: |
| 3593 | + """Test normalization with versioned Vertex AI model ID.""" |
| 3594 | + input_model = "publishers/google/models/gemini-1.5-pro-001" |
| 3595 | + expected = "google/gemini-1.5-pro-001" |
| 3596 | + assert normalize_vertex_ai_model_id(input_model) == expected |
| 3597 | + |
| 3598 | + def test_preserves_non_vertex_ai_model_ids(self) -> None: |
| 3599 | + """Test that non-Vertex AI model IDs are returned unchanged.""" |
| 3600 | + # Regular model IDs should pass through |
| 3601 | + assert normalize_vertex_ai_model_id("gpt-4") == "gpt-4" |
| 3602 | + assert normalize_vertex_ai_model_id("openai/gpt-4") == "openai/gpt-4" |
| 3603 | + assert normalize_vertex_ai_model_id("watsonx/model") == "watsonx/model" |
| 3604 | + |
| 3605 | + def test_preserves_gemini_api_format(self) -> None: |
| 3606 | + """Test that Gemini API format (models/...) is preserved.""" |
| 3607 | + # Gemini API format doesn't have the publishers prefix |
| 3608 | + gemini_api_format = "models/gemini-2.5-flash" |
| 3609 | + assert normalize_vertex_ai_model_id(gemini_api_format) == gemini_api_format |
0 commit comments