Skip to content

Commit 7b0ff46

Browse files
committed
feat(aliases): allow setting reasoning level
Signed-off-by: Linus Schlumberger <linus.schlumberger@siemens.com>
1 parent 7a9ed88 commit 7b0ff46

10 files changed

Lines changed: 439 additions & 26 deletions

File tree

src/tests/test_alias_routing.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import json
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
6+
from vllm_router.routers.routing_logic import RoundRobinRouter
7+
from vllm_router.utils import AliasConfig, SingletonABCMeta
8+
9+
10+
class FakeEndpointInfo:
11+
def __init__(self, url, model_names=None, sleep=False, Id=None):
12+
self.url = url
13+
self.model_names = model_names or ["deepseek-r1"]
14+
self.sleep = sleep
15+
self.Id = Id
16+
17+
18+
ENDPOINTS = [FakeEndpointInfo(url="http://engine1")]
19+
20+
MOCK_HEADERS = MagicMock()
21+
MOCK_HEADERS.items.return_value = [("content-type", "text/event-stream")]
22+
23+
24+
@pytest.fixture(autouse=True)
25+
def cleanup_singletons():
26+
yield
27+
for cls in list(SingletonABCMeta._instances.keys()):
28+
del SingletonABCMeta._instances[cls]
29+
30+
31+
def _make_service_discovery(aliases):
32+
sd = MagicMock()
33+
sd.get_endpoint_info.return_value = ENDPOINTS
34+
sd.aliases = aliases
35+
sd.has_ever_seen_model.return_value = True
36+
return sd
37+
38+
39+
def _make_request(body_dict, router):
40+
state = MagicMock()
41+
state.router = router
42+
state.engine_stats_scraper.get_engine_stats.return_value = {}
43+
state.request_stats_monitor.get_request_stats.return_value = {}
44+
state.otel_enabled = False
45+
state.semantic_cache_available = False
46+
state.callbacks = None
47+
state.external_provider_registry = None
48+
49+
req = MagicMock()
50+
req.headers = {"content-type": "application/json"}
51+
req.query_params = {}
52+
req.method = "POST"
53+
req.url = "http://router/v1/chat/completions"
54+
req.app.state = state
55+
56+
raw = json.dumps(body_dict).encode()
57+
58+
async def body():
59+
return raw
60+
61+
req.body = body
62+
return req
63+
64+
65+
async def _run_routing_test(aliases, request_body, expect_model, expect_reasoning=None):
66+
"""Route a request through route_general_request and verify the forwarded body."""
67+
router = RoundRobinRouter()
68+
setattr(router, "max_instance_failover_reroute_attempts", 0)
69+
req = _make_request(request_body, router)
70+
captured = {}
71+
72+
async def fake_process(request, body, server_url, *a, **kw):
73+
captured["body"] = json.loads(body)
74+
yield MOCK_HEADERS, 200
75+
yield b'{"id":"x"}'
76+
77+
with (
78+
patch(
79+
"vllm_router.services.request_service.request.get_service_discovery",
80+
return_value=_make_service_discovery(aliases),
81+
),
82+
patch(
83+
"vllm_router.services.request_service.request.is_request_rewriter_initialized",
84+
return_value=False,
85+
),
86+
patch(
87+
"vllm_router.services.request_service.request.process_request",
88+
side_effect=fake_process,
89+
),
90+
):
91+
from vllm_router.services.request_service.request import route_general_request
92+
93+
resp = await route_general_request(req, "/v1/chat/completions", MagicMock())
94+
95+
assert resp.status_code == 200
96+
assert captured["body"]["model"] == expect_model
97+
if expect_reasoning is not None:
98+
assert captured["body"]["reasoning_effort"] == expect_reasoning
99+
else:
100+
assert "reasoning_effort" not in captured["body"]
101+
102+
103+
_MESSAGES = [{"role": "user", "content": "hi"}]
104+
105+
106+
@pytest.mark.asyncio
107+
async def test_alias_injects_reasoning_effort():
108+
"""When alias has reasoning_effort and request doesn't, it should be injected."""
109+
await _run_routing_test(
110+
aliases={
111+
"reasoning-model": AliasConfig(model="deepseek-r1", reasoning_effort="high")
112+
},
113+
request_body={
114+
"model": "reasoning-model",
115+
"stream": False,
116+
"messages": _MESSAGES,
117+
},
118+
expect_model="deepseek-r1",
119+
expect_reasoning="high",
120+
)
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_client_reasoning_effort_not_overwritten():
125+
"""When client already provides reasoning_effort, alias should NOT overwrite it."""
126+
await _run_routing_test(
127+
aliases={
128+
"reasoning-model": AliasConfig(model="deepseek-r1", reasoning_effort="high")
129+
},
130+
request_body={
131+
"model": "reasoning-model",
132+
"stream": False,
133+
"reasoning_effort": "low",
134+
"messages": _MESSAGES,
135+
},
136+
expect_model="deepseek-r1",
137+
expect_reasoning="low",
138+
)
139+
140+
141+
@pytest.mark.asyncio
142+
async def test_plain_alias_no_reasoning_effort():
143+
"""A plain alias (no reasoning_effort) should not inject reasoning_effort."""
144+
await _run_routing_test(
145+
aliases={"short-name": AliasConfig(model="deepseek-r1")},
146+
request_body={"model": "short-name", "stream": False, "messages": _MESSAGES},
147+
expect_model="deepseek-r1",
148+
)
149+
150+
151+
@pytest.mark.asyncio
152+
async def test_legacy_plain_string_alias():
153+
"""A plain-string alias value (from a custom ServiceDiscovery) must still work."""
154+
await _run_routing_test(
155+
aliases={"short-name": "deepseek-r1"},
156+
request_body={"model": "short-name", "stream": False, "messages": _MESSAGES},
157+
expect_model="deepseek-r1",
158+
)

src/tests/test_parser.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,63 @@ def test_load_initial_config_from_config_file_if_required_when_yaml_config_file_
9292
assert args.static_aliases == "text-embedding-3-small:bge-m3"
9393

9494

95+
def test_load_initial_config_from_config_file_if_required_when_yaml_config_with_extended_aliases_is_provided(
96+
monkeypatch: pytest.MonkeyPatch,
97+
) -> None:
98+
with tempfile.NamedTemporaryFile() as f:
99+
monkeypatch.setattr(sys, "argv", [sys.argv[0], "--dynamic-config-yaml", f.name])
100+
f.write(
101+
yaml.safe_dump(
102+
{
103+
"static_aliases": {
104+
"text": "llama3",
105+
"reasoning": {"model": "llama3", "reasoning_effort": "high"},
106+
},
107+
}
108+
).encode()
109+
)
110+
f.seek(0)
111+
test_parser = argparse.ArgumentParser("test")
112+
test_parser.add_argument("--dynamic-config-yaml", type=str)
113+
test_parser.add_argument("--dynamic-config-json", type=str)
114+
args = test_parser.parse_args()
115+
args = parser.load_initial_config_from_config_file_if_required(
116+
test_parser, args
117+
)
118+
assert "text:llama3" in args.static_aliases
119+
assert "reasoning:llama3|reasoning_effort=high" in args.static_aliases
120+
121+
122+
def test_generate_static_aliases_rejects_unknown_key() -> None:
123+
from vllm_router.parsers.yaml_utils import generate_static_aliases
124+
125+
with pytest.raises(ValueError, match="unknown keys"):
126+
generate_static_aliases({"r1": {"model": "llama3", "reasoning_effrot": "high"}})
127+
128+
129+
def test_generate_static_aliases_rejects_missing_model() -> None:
130+
from vllm_router.parsers.yaml_utils import generate_static_aliases
131+
132+
with pytest.raises(ValueError, match="missing required key 'model'"):
133+
generate_static_aliases({"r1": {"reasoning_effort": "high"}})
134+
135+
136+
def test_generate_static_aliases_rejects_invalid_type() -> None:
137+
from vllm_router.parsers.yaml_utils import generate_static_aliases
138+
139+
with pytest.raises(ValueError, match="expected string or dict"):
140+
generate_static_aliases({"bad": 42})
141+
142+
143+
def test_generate_static_aliases_rejects_invalid_reasoning_effort() -> None:
144+
from vllm_router.parsers.yaml_utils import generate_static_aliases
145+
146+
with pytest.raises(ValueError, match="Invalid reasoning_effort"):
147+
generate_static_aliases(
148+
{"r1": {"model": "llama3", "reasoning_effort": "urgent"}}
149+
)
150+
151+
95152
def test_load_initial_config_from_config_file_if_required_when_json_config_file_is_provided_adds_values_to_args(
96153
monkeypatch: pytest.MonkeyPatch,
97154
) -> None:

src/tests/test_static_service_discovery.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from vllm_router.service_discovery import StaticServiceDiscovery
6+
from vllm_router.utils import AliasConfig
67

78

89
def test_init_when_static_backend_health_checks_calls_start_health_checks(
@@ -162,7 +163,7 @@ def test_has_ever_seen_model_when_model_is_alias_returns_true():
162163
None,
163164
["http://localhost.com"],
164165
["llama3"],
165-
{"llama": "llama3"},
166+
{"llama": AliasConfig(model="llama3")},
166167
None,
167168
["chat"],
168169
static_backend_health_checks=False,
@@ -172,3 +173,36 @@ def test_has_ever_seen_model_when_model_is_alias_returns_true():
172173
assert discovery_instance.has_ever_seen_model("llama") is True
173174
assert discovery_instance.has_ever_seen_model("llama3") is True
174175
assert discovery_instance.has_ever_seen_model("unknown-model") is False
176+
177+
178+
def _make_discovery(aliases=None):
179+
return StaticServiceDiscovery(
180+
None,
181+
["http://localhost.com"],
182+
["llama3"],
183+
aliases,
184+
None,
185+
["chat"],
186+
static_backend_health_checks=False,
187+
prefill_model_labels=None,
188+
decode_model_labels=None,
189+
)
190+
191+
192+
def test_init_normalizes_legacy_str_aliases_to_alias_config():
193+
"""Programmatic callers passing dict[str, str] should still work."""
194+
d = _make_discovery({"llama": "llama3"})
195+
assert d.aliases == {"llama": AliasConfig(model="llama3")}
196+
assert d.has_ever_seen_model("llama") is True
197+
198+
199+
def test_init_accepts_alias_config_values():
200+
d = _make_discovery(
201+
{"reasoning": AliasConfig(model="llama3", reasoning_effort="high")}
202+
)
203+
assert d.aliases["reasoning"].reasoning_effort == "high"
204+
205+
206+
def test_init_rejects_invalid_alias_value_type():
207+
with pytest.raises(TypeError, match="expected str or AliasConfig"):
208+
_make_discovery({"bad": 123})

src/tests/test_utils.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,82 @@
66
from starlette.datastructures import MutableHeaders
77

88
from vllm_router import utils
9+
from vllm_router.utils import AliasConfig, normalize_alias_config
910

1011

1112
@pytest.mark.parametrize(
1213
"aliases,expected_result",
1314
(
14-
("gpt-4:mistral-nemo-instruct-2407", {"gpt-4": "mistral-nemo-instruct-2407"}),
15+
(
16+
"gpt-4:mistral-nemo-instruct-2407",
17+
{"gpt-4": AliasConfig(model="mistral-nemo-instruct-2407")},
18+
),
1519
(
1620
"gpt-4:mistral-nemo-instruct-2407,gpt-3.5:mistral-nemo-instruct-2407",
1721
{
18-
"gpt-4": "mistral-nemo-instruct-2407",
19-
"gpt-3.5": "mistral-nemo-instruct-2407",
22+
"gpt-4": AliasConfig(model="mistral-nemo-instruct-2407"),
23+
"gpt-3.5": AliasConfig(model="mistral-nemo-instruct-2407"),
2024
},
2125
),
2226
(
2327
"gpt-4:deepseek-r1-distill-qwen-7b,mistral-7b-instruct:mistral-nemo-instruct-2407",
2428
{
25-
"gpt-4": "deepseek-r1-distill-qwen-7b",
26-
"mistral-7b-instruct": "mistral-nemo-instruct-2407",
29+
"gpt-4": AliasConfig(model="deepseek-r1-distill-qwen-7b"),
30+
"mistral-7b-instruct": AliasConfig(model="mistral-nemo-instruct-2407"),
31+
},
32+
),
33+
(
34+
"reasoning:deepseek-r1-distill-qwen-7b|reasoning_effort=high",
35+
{
36+
"reasoning": AliasConfig(
37+
model="deepseek-r1-distill-qwen-7b", reasoning_effort="high"
38+
)
39+
},
40+
),
41+
(
42+
"text:mistral-nemo-instruct-2407,reasoning:deepseek-r1-distill-qwen-7b|reasoning_effort=low",
43+
{
44+
"text": AliasConfig(model="mistral-nemo-instruct-2407"),
45+
"reasoning": AliasConfig(
46+
model="deepseek-r1-distill-qwen-7b", reasoning_effort="low"
47+
),
2748
},
2849
),
2950
),
3051
)
31-
def test_parse_static_aliases_when_aliases_as_string_supplied_returns_dict(
32-
aliases: str, expected_result: dict
33-
) -> None:
52+
def test_parse_static_aliases(aliases: str, expected_result: dict) -> None:
3453
assert utils.parse_static_aliases(aliases) == expected_result
3554

3655

56+
def test_alias_config_rejects_invalid_reasoning_effort() -> None:
57+
with pytest.raises(ValueError, match="Invalid reasoning_effort"):
58+
AliasConfig(model="test", reasoning_effort="invalid")
59+
60+
61+
def test_normalize_alias_config_accepts_plain_string() -> None:
62+
assert normalize_alias_config("gpt-4", "llama3") == AliasConfig(model="llama3")
63+
64+
65+
def test_normalize_alias_config_accepts_alias_config() -> None:
66+
config = AliasConfig(model="llama3", reasoning_effort="high")
67+
assert normalize_alias_config("reasoning", config) == config
68+
69+
70+
def test_normalize_alias_config_rejects_invalid_value() -> None:
71+
with pytest.raises(TypeError, match="expected str or AliasConfig"):
72+
normalize_alias_config("bad", 123)
73+
74+
75+
def test_parse_static_aliases_rejects_unknown_parameter() -> None:
76+
with pytest.raises(ValueError, match="Unknown alias parameter 'reasoning_effrot'"):
77+
utils.parse_static_aliases("r1:llama3|reasoning_effrot=high")
78+
79+
80+
def test_parse_static_aliases_rejects_invalid_entry() -> None:
81+
with pytest.raises(ValueError, match="Invalid alias entry"):
82+
utils.parse_static_aliases("missing-model")
83+
84+
3785
def test_replace_model_in_request_body_replaces_model() -> None:
3886
model = "mistral-nemo-instruct-2407"
3987
result = json.loads(
@@ -110,7 +158,6 @@ def test_is_model_healthy_when_requests_raises_exception_returns_false(
110158
def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
111159
monkeypatch: pytest.MonkeyPatch,
112160
) -> None:
113-
114161
# Mock an internal server error response
115162
mock_response = MagicMock(status_code=500)
116163

0 commit comments

Comments
 (0)