Skip to content

Commit a2dcf43

Browse files
committed
feat: autodetect model name
Signed-off-by: Liana Koleva <43767763+lianakoleva@users.noreply.github.com>
1 parent 629a2d5 commit a2dcf43

7 files changed

Lines changed: 275 additions & 5 deletions

File tree

docs/cli-options.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,10 @@ aiperf profile --model your_model --url localhost:8000 --goodput "request_latenc
116116

117117
### Endpoint
118118

119-
#### `-m`, `--model-names`, `--model` `<list>` _(Required)_
119+
#### `-m`, `--model-names`, `--model` `<list>`
120120

121121
Model name(s) to be benchmarked. Can be a comma-separated list or a single model name.
122+
If omitted, `aiperf profile` attempts to auto-detect a model from `GET {url}/v1/models`.
122123

123124
#### `--model-selection-strategy` `<str>`
124125

src/aiperf/cli_commands/profile.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# SPDX-License-Identifier: Apache-2.0
33
"""CLI command for running the Profile subcommand."""
44

5+
import asyncio
56
from cyclopts import App
67

78
from aiperf.common.config import ServiceConfig, UserConfig
9+
from aiperf.common.config.cli_parameter import CLIParameter
810

911
app = App(name="profile")
1012

@@ -53,8 +55,46 @@ def profile(
5355

5456
service_config = service_config or load_service_config()
5557

58+
# If the user didn't provide --model/--model-names, try to discover
59+
# one from the server's OpenAI-compatible model list.
60+
if not user_config.endpoint.model_names:
61+
import logging
62+
63+
from aiperf.common.config.config_defaults import OutputDefaults
64+
from aiperf.common.models.model_autodetect import (
65+
autodetect_model_names_from_v1_models,
66+
)
67+
68+
# Install a basic stderr handler so the log message is visible even
69+
# when `--wait-for-model-timeout` is left at the default (0).
70+
if not logging.getLogger().handlers:
71+
logging.basicConfig(
72+
level=logging.INFO,
73+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
74+
)
75+
76+
raw_headers = user_config.input.headers or []
77+
headers = {str(k): str(v) for k, v in raw_headers}
78+
if user_config.endpoint.api_key:
79+
headers["Authorization"] = f"Bearer {user_config.endpoint.api_key}"
80+
81+
user_config.endpoint.model_names = asyncio.run(
82+
autodetect_model_names_from_v1_models(
83+
urls=user_config.endpoint.urls,
84+
headers=headers,
85+
)
86+
)
87+
88+
# `UserConfig` computed an artifact directory during config-load.
89+
# If it used the default artifact directory (not overridden by the
90+
# user), update it to reflect the discovered model name.
91+
if "artifact_directory" not in user_config.output.model_fields_set:
92+
user_config.output.artifact_directory = OutputDefaults.ARTIFACT_DIRECTORY
93+
user_config.output.artifact_directory = (
94+
user_config._compute_artifact_directory()
95+
)
96+
5697
if user_config.endpoint.wait_for_model_timeout > 0:
57-
import asyncio
5898
import logging
5999

60100
from aiperf.common.readiness_probe import wait_for_endpoint

src/aiperf/common/config/endpoint_config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ def validate_wait_for_model_coherent(self) -> Self:
8585
model_names: Annotated[
8686
list[str],
8787
Field(
88-
..., # This must be set by the user
89-
description="Model name(s) to be benchmarked. Can be a comma-separated list or a single model name.",
88+
default_factory=list,
89+
description=(
90+
"Model name(s) to be benchmarked. Can be a comma-separated list or a "
91+
"single model name. If omitted, `aiperf profile` will attempt to "
92+
"auto-detect a model from `GET {url}/v1/models`."
93+
),
9094
),
9195
BeforeValidator(parse_str_or_list),
9296
CLIParameter(

src/aiperf/common/config/user_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,11 @@ def _compute_artifact_directory(self) -> Path:
760760

761761
def _get_artifact_model_name(self) -> str:
762762
"""Get the artifact model name based on the user selected options."""
763+
if not self.endpoint.model_names:
764+
# When --model is omitted, `aiperf profile` will auto-detect models
765+
# later. Use a safe placeholder so config-load doesn't crash.
766+
return "auto"
767+
763768
model_name: str = self.endpoint.model_names[0]
764769
if len(self.endpoint.model_names) > 1:
765770
model_name = f"{model_name}_multi"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""
4+
Model autodetection helpers
5+
6+
Used by `aiperf profile` when `--model/--model-names` is omitted:
7+
attempt to fetch the server's model list from `GET {base_url}/v1/models`.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from typing import Any
13+
14+
import orjson
15+
16+
from aiperf.common.aiperf_logger import AIPerfLogger
17+
from aiperf.transports.aiohttp_client import AioHttpClient
18+
19+
_logger = AIPerfLogger(__name__)
20+
21+
22+
async def autodetect_model_names_from_v1_models(
23+
*,
24+
urls: list[str],
25+
headers: dict[str, str],
26+
timeout_s: float = 10.0,
27+
) -> list[str]:
28+
"""Fetch `GET {url}/v1/models` and return a best-effort model list.
29+
30+
Selection strategy: return only the first discovered model id.
31+
"""
32+
33+
if not urls:
34+
raise ValueError("Autodetection requires at least one --url base URL")
35+
36+
# Use the first URL for discovery. If you have multiple URLs with
37+
# different model sets, you should pass --model explicitly.
38+
base_url = urls[0].rstrip("/")
39+
models_url = base_url + "/v1/models"
40+
41+
client = AioHttpClient(timeout=timeout_s)
42+
try:
43+
record = await client.get_request(models_url, headers=headers)
44+
finally:
45+
await client.close()
46+
47+
status = record.status
48+
if status != 200:
49+
raise ValueError(
50+
f"Failed to auto-detect models from {models_url}: HTTP status={status}"
51+
)
52+
53+
if not record.responses:
54+
raise ValueError(f"Empty response body while autodetecting {models_url}")
55+
56+
response_obj: Any = record.responses[0]
57+
body_text = getattr(response_obj, "text", None)
58+
if not isinstance(body_text, str) or not body_text:
59+
raise ValueError(f"Non-text response while autodetecting {models_url}")
60+
61+
try:
62+
payload = orjson.loads(body_text)
63+
except orjson.JSONDecodeError as e:
64+
raise ValueError(
65+
f"Invalid JSON returned from {models_url} while autodetecting models"
66+
) from e
67+
68+
if not isinstance(payload, dict):
69+
raise ValueError(f"Unexpected /v1/models response shape from {models_url}")
70+
71+
data = payload.get("data")
72+
if not isinstance(data, list):
73+
raise ValueError(f"Unexpected /v1/models response: missing data[] in {models_url}")
74+
75+
ids: list[str] = []
76+
for entry in data:
77+
if isinstance(entry, dict):
78+
model_id = entry.get("id")
79+
if isinstance(model_id, str) and model_id:
80+
ids.append(model_id)
81+
82+
if not ids:
83+
raise ValueError(f"No model ids found in /v1/models response from {models_url}")
84+
85+
chosen = ids[0]
86+
if len(ids) > 1:
87+
_logger.warning(
88+
f"{len(ids)} models returned by {models_url}; "
89+
"pass --model to select one explicitly"
90+
)
91+
_logger.warning(
92+
f"No --model provided; using first listed model '{chosen}'"
93+
)
94+
else:
95+
_logger.info(f"Auto-detected model '{chosen}' from {models_url}")
96+
return [chosen]
97+

tests/unit/common/config/test_endpoint_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def test_endpoint_config_defaults():
1818
the configuration is initialized correctly with expected default values.
1919
"""
2020

21-
# NOTE: Model names must be filled out
21+
# Model names default to [] so config-load doesn't crash when
22+
# `--model` is omitted (e.g. `aiperf profile` autodetects).
2223
config = EndpointConfig(model_names=["gpt2"])
2324

2425
assert config.model_selection_strategy == EndpointDefaults.MODEL_SELECTION_STRATEGY
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import asyncio
7+
from typing import Any
8+
9+
import orjson
10+
import pytest
11+
12+
from aiperf.common.models.model_autodetect import (
13+
autodetect_model_names_from_v1_models,
14+
)
15+
16+
17+
class _FakeRecord:
18+
def __init__(self, *, status: int, body_text: str) -> None:
19+
self.status = status
20+
resp = type("_Resp", (), {"text": body_text})()
21+
self.responses = [resp]
22+
23+
24+
class _FakeClient:
25+
def __init__(self, *, status: int, body_text: str) -> None:
26+
self._status = status
27+
self._body_text = body_text
28+
self.urls: list[str] = []
29+
self.headers: list[dict[str, str]] = []
30+
self.closed = False
31+
32+
async def get_request(
33+
self, url: str, headers: dict[str, str], **_: Any
34+
) -> _FakeRecord:
35+
self.urls.append(url)
36+
self.headers.append(headers)
37+
return _FakeRecord(status=self._status, body_text=self._body_text)
38+
39+
async def close(self) -> None:
40+
self.closed = True
41+
42+
43+
def _install_fake_aiohttp(
44+
monkeypatch: pytest.MonkeyPatch, *, status: int, body_text: str
45+
) -> _FakeClient:
46+
fake = _FakeClient(status=status, body_text=body_text)
47+
48+
def _factory(*_: Any, **__: Any) -> _FakeClient:
49+
return fake
50+
51+
monkeypatch.setattr(
52+
"aiperf.common.models.model_autodetect.AioHttpClient",
53+
_factory,
54+
)
55+
return fake
56+
57+
58+
def test_autodetect_picks_first_id_from_data(
59+
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
60+
) -> None:
61+
import logging
62+
63+
caplog.set_level(logging.WARNING, logger="aiperf.common.models.model_autodetect")
64+
body_text = orjson.dumps(
65+
{"data": [{"id": "model-a"}, {"id": "model-b"}]}
66+
).decode("utf-8")
67+
fake = _install_fake_aiohttp(
68+
monkeypatch, status=200, body_text=body_text
69+
)
70+
71+
result = asyncio.run(
72+
autodetect_model_names_from_v1_models(
73+
urls=["http://localhost:8000"],
74+
headers={"Authorization": "Bearer token"},
75+
timeout_s=1.0,
76+
)
77+
)
78+
79+
assert result == ["model-a"]
80+
assert fake.urls == ["http://localhost:8000/v1/models"]
81+
assert fake.headers[0]["Authorization"] == "Bearer token"
82+
assert fake.closed is True
83+
assert "2 models returned" in caplog.text
84+
assert "pass --model" in caplog.text
85+
assert "first listed model 'model-a'" in caplog.text
86+
87+
88+
def test_autodetect_single_model_logs_info_not_warning(
89+
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
90+
) -> None:
91+
import logging
92+
93+
caplog.set_level(logging.INFO, logger="aiperf.common.models.model_autodetect")
94+
body_text = orjson.dumps({"data": [{"id": "only-one"}]}).decode("utf-8")
95+
_install_fake_aiohttp(monkeypatch, status=200, body_text=body_text)
96+
97+
asyncio.run(
98+
autodetect_model_names_from_v1_models(
99+
urls=["http://localhost:8000"],
100+
headers={},
101+
timeout_s=1.0,
102+
)
103+
)
104+
105+
assert "Auto-detected model 'only-one'" in caplog.text
106+
assert "pass --model" not in caplog.text
107+
108+
109+
def test_autodetect_raises_on_non_200(monkeypatch: pytest.MonkeyPatch) -> None:
110+
fake_body_text = "oops"
111+
_install_fake_aiohttp(
112+
monkeypatch, status=404, body_text=fake_body_text
113+
)
114+
115+
with pytest.raises(ValueError, match="Failed to auto-detect models"):
116+
asyncio.run(
117+
autodetect_model_names_from_v1_models(
118+
urls=["http://localhost:8000"],
119+
headers={},
120+
timeout_s=1.0,
121+
)
122+
)

0 commit comments

Comments
 (0)