Skip to content

Commit cb4bfe3

Browse files
committed
pre-commit fixes
Signed-off-by: Liana Koleva <43767763+lianakoleva@users.noreply.github.com>
1 parent 274336b commit cb4bfe3

4 files changed

Lines changed: 122 additions & 124 deletions

File tree

docs/cli-options.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ aiperf profile --model your_model --url localhost:8000 --goodput "request_latenc
118118

119119
#### `-m`, `--model-names`, `--model` `<list>`
120120

121-
Model name(s) to be benchmarked. Can be a comma-separated list or a single model name.
122-
If omitted, `aiperf profile` attempts to auto-detect a model from `GET {url}/v1/models`.
121+
Model name(s) to be benchmarked. Can be a comma-separated list or a single model name. If omitted, `aiperf profile` will attempt to auto-detect a model from `GET {url}/v1/models`.
123122

124123
#### `--model-selection-strategy` `<str>`
125124

src/aiperf/cli_commands/profile.py

Lines changed: 76 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""CLI command for running the Profile subcommand."""
44

55
import asyncio
6+
import logging
67

78
from cyclopts import App
89

@@ -11,6 +12,79 @@
1112
app = App(name="profile")
1213

1314

15+
def _ensure_stderr_logging() -> None:
16+
if not logging.getLogger().handlers:
17+
logging.basicConfig(
18+
level=logging.INFO,
19+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
20+
)
21+
22+
23+
def _request_headers_for_endpoint(user_config: UserConfig) -> dict[str, str]:
24+
raw_headers = user_config.input.headers or []
25+
headers = {str(k): str(v) for k, v in raw_headers}
26+
if user_config.endpoint.api_key:
27+
headers["Authorization"] = f"Bearer {user_config.endpoint.api_key}"
28+
return headers
29+
30+
31+
def _maybe_autodiscover_models(user_config: UserConfig) -> None:
32+
# If the user didn't provide --model/--model-names, try to discover
33+
# one from the server's OpenAI-compatible model list.
34+
if user_config.endpoint.model_names:
35+
return
36+
37+
from aiperf.common.config.config_defaults import OutputDefaults
38+
from aiperf.common.models.model_autodetect import (
39+
autodetect_names,
40+
)
41+
42+
# Install a basic stderr handler so the log message is visible even
43+
# when `--wait-for-model-timeout` is left at the default (0).
44+
_ensure_stderr_logging()
45+
46+
user_config.endpoint.model_names = asyncio.run(
47+
autodetect_names(
48+
urls=user_config.endpoint.urls,
49+
headers=_request_headers_for_endpoint(user_config),
50+
)
51+
)
52+
53+
# `UserConfig` computed an artifact directory during config-load.
54+
# If it used the default artifact directory (not overridden by the
55+
# user), update it to reflect the discovered model name.
56+
if "artifact_directory" not in user_config.output.model_fields_set:
57+
user_config.output.artifact_directory = OutputDefaults.ARTIFACT_DIRECTORY
58+
user_config.output.artifact_directory = (
59+
user_config._compute_artifact_directory()
60+
)
61+
62+
63+
def _maybe_wait_for_model(user_config: UserConfig) -> None:
64+
if user_config.endpoint.wait_for_model_timeout <= 0:
65+
return
66+
67+
from aiperf.common.readiness_probe import wait_for_endpoint
68+
69+
# The probe runs before `run_system_controller` (which installs
70+
# rich logging), so there are no handlers attached yet. Install
71+
# a basic stderr handler so probe log messages are visible.
72+
_ensure_stderr_logging()
73+
74+
asyncio.run(
75+
wait_for_endpoint(
76+
urls=user_config.endpoint.urls,
77+
model_names=user_config.endpoint.model_names,
78+
mode=user_config.endpoint.wait_for_model_mode,
79+
endpoint_type=str(user_config.endpoint.type),
80+
custom_endpoint=user_config.endpoint.custom_endpoint,
81+
timeout_s=user_config.endpoint.wait_for_model_timeout,
82+
interval_s=user_config.endpoint.wait_for_model_interval,
83+
headers=_request_headers_for_endpoint(user_config),
84+
)
85+
)
86+
87+
1488
@app.default
1589
def profile(
1690
user_config: UserConfig,
@@ -54,78 +128,6 @@ def profile(
54128
from aiperf.common.config.loader import load_service_config
55129

56130
service_config = service_config or load_service_config()
57-
58-
# If the user didn't provide --model/--model-names, try to discover
59-
# one from the server's OpenAI-compatible model list.
60-
if not user_config.endpoint.model_names:
61-
import logging
62-
63-
from aiperf.common.config.config_defaults import OutputDefaults
64-
from aiperf.common.models.model_autodetect import (
65-
autodetect_model_names_from_v1_models,
66-
)
67-
68-
# Install a basic stderr handler so the log message is visible even
69-
# when `--wait-for-model-timeout` is left at the default (0).
70-
if not logging.getLogger().handlers:
71-
logging.basicConfig(
72-
level=logging.INFO,
73-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
74-
)
75-
76-
raw_headers = user_config.input.headers or []
77-
headers = {str(k): str(v) for k, v in raw_headers}
78-
if user_config.endpoint.api_key:
79-
headers["Authorization"] = f"Bearer {user_config.endpoint.api_key}"
80-
81-
user_config.endpoint.model_names = asyncio.run(
82-
autodetect_model_names_from_v1_models(
83-
urls=user_config.endpoint.urls,
84-
headers=headers,
85-
)
86-
)
87-
88-
# `UserConfig` computed an artifact directory during config-load.
89-
# If it used the default artifact directory (not overridden by the
90-
# user), update it to reflect the discovered model name.
91-
if "artifact_directory" not in user_config.output.model_fields_set:
92-
user_config.output.artifact_directory = (
93-
OutputDefaults.ARTIFACT_DIRECTORY
94-
)
95-
user_config.output.artifact_directory = (
96-
user_config._compute_artifact_directory()
97-
)
98-
99-
if user_config.endpoint.wait_for_model_timeout > 0:
100-
import logging
101-
102-
from aiperf.common.readiness_probe import wait_for_endpoint
103-
104-
# The probe runs before `run_system_controller` (which installs
105-
# rich logging), so there are no handlers attached yet. Install
106-
# a basic stderr handler so probe log messages are visible.
107-
if not logging.getLogger().handlers:
108-
logging.basicConfig(
109-
level=logging.INFO,
110-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
111-
)
112-
113-
raw_headers = user_config.input.headers or []
114-
headers = {str(k): str(v) for k, v in raw_headers}
115-
if user_config.endpoint.api_key:
116-
headers["Authorization"] = f"Bearer {user_config.endpoint.api_key}"
117-
118-
asyncio.run(
119-
wait_for_endpoint(
120-
urls=user_config.endpoint.urls,
121-
model_names=user_config.endpoint.model_names,
122-
mode=user_config.endpoint.wait_for_model_mode,
123-
endpoint_type=str(user_config.endpoint.type),
124-
custom_endpoint=user_config.endpoint.custom_endpoint,
125-
timeout_s=user_config.endpoint.wait_for_model_timeout,
126-
interval_s=user_config.endpoint.wait_for_model_interval,
127-
headers=headers,
128-
)
129-
)
130-
131+
_maybe_autodiscover_models(user_config)
132+
_maybe_wait_for_model(user_config)
131133
run_system_controller(user_config, service_config)

src/aiperf/common/models/model_autodetect.py

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,71 +19,70 @@
1919
_logger = AIPerfLogger(__name__)
2020

2121

22-
async def autodetect_model_names_from_v1_models(
23-
*,
24-
urls: list[str],
25-
headers: dict[str, str],
26-
timeout_s: float = 10.0,
27-
) -> list[str]:
28-
"""Fetch `GET {url}/v1/models` and return a best-effort model list.
29-
30-
Selection strategy: return only the first discovered model id.
31-
"""
32-
33-
if not urls:
34-
raise ValueError("Autodetection requires at least one --url base URL")
35-
36-
# Use the first URL for discovery. If you have multiple URLs with
37-
# different model sets, you should pass --model explicitly.
38-
base_url = urls[0].rstrip("/")
39-
models_url = base_url + "/v1/models"
40-
41-
client = AioHttpClient(timeout=timeout_s)
42-
try:
43-
record = await client.get_request(models_url, headers=headers)
44-
finally:
45-
await client.close()
46-
47-
status = record.status
48-
if status != 200:
22+
def _extract_body_text(record: Any, models_url: str) -> str:
23+
if record.status != 200:
4924
raise ValueError(
50-
f"Failed to auto-detect models from {models_url}: HTTP status={status}"
25+
f"Failed to auto-detect models from {models_url}: HTTP status={record.status}"
5126
)
52-
5327
if not record.responses:
5428
raise ValueError(f"Empty response body while autodetecting {models_url}")
55-
56-
response_obj: Any = record.responses[0]
57-
body_text = getattr(response_obj, "text", None)
29+
body_text = getattr(record.responses[0], "text", None)
5830
if not isinstance(body_text, str) or not body_text:
5931
raise ValueError(f"Non-text response while autodetecting {models_url}")
32+
return body_text
33+
6034

35+
def _extract_ids_from_payload(body_text: str, models_url: str) -> list[str]:
6136
try:
6237
payload = orjson.loads(body_text)
6338
except orjson.JSONDecodeError as e:
6439
raise ValueError(
6540
f"Invalid JSON returned from {models_url} while autodetecting models"
6641
) from e
67-
6842
if not isinstance(payload, dict):
6943
raise ValueError(f"Unexpected /v1/models response shape from {models_url}")
70-
7144
data = payload.get("data")
7245
if not isinstance(data, list):
7346
raise ValueError(
7447
f"Unexpected /v1/models response: missing data[] in {models_url}"
7548
)
76-
77-
ids: list[str] = []
78-
for entry in data:
79-
if isinstance(entry, dict):
80-
model_id = entry.get("id")
81-
if isinstance(model_id, str) and model_id:
82-
ids.append(model_id)
83-
49+
ids = [
50+
entry["id"]
51+
for entry in data
52+
if isinstance(entry, dict)
53+
and isinstance(entry.get("id"), str)
54+
and entry.get("id")
55+
]
8456
if not ids:
8557
raise ValueError(f"No model ids found in /v1/models response from {models_url}")
58+
return ids
59+
60+
61+
async def autodetect_names(
62+
*,
63+
urls: list[str],
64+
headers: dict[str, str],
65+
timeout_s: float = 10.0,
66+
) -> list[str]:
67+
"""Fetch `GET {url}/v1/models` and return a best-effort model list.
68+
69+
Selection strategy: return only the first discovered model id.
70+
"""
71+
if not urls:
72+
raise ValueError("Autodetection requires at least one --url base URL")
73+
74+
# Use the first URL for discovery. If you have multiple URLs with
75+
# different model sets, you should pass --model explicitly.
76+
base_url = urls[0].rstrip("/")
77+
models_url = base_url + "/v1/models"
78+
79+
client = AioHttpClient(timeout=timeout_s)
80+
try:
81+
record = await client.get_request(models_url, headers=headers)
82+
finally:
83+
await client.close()
8684

85+
ids = _extract_ids_from_payload(_extract_body_text(record, models_url), models_url)
8786
chosen = ids[0]
8887
if len(ids) > 1:
8988
_logger.warning(

tests/unit/common/test_model_autodetect.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
import orjson
1010
import pytest
1111

12-
from aiperf.common.models.model_autodetect import (
13-
autodetect_model_names_from_v1_models,
14-
)
12+
from aiperf.common.models.model_autodetect import autodetect_names
1513

1614

1715
class _FakeRecord:
@@ -67,7 +65,7 @@ def test_autodetect_picks_first_id_from_data(
6765
fake = _install_fake_aiohttp(monkeypatch, status=200, body_text=body_text)
6866

6967
result = asyncio.run(
70-
autodetect_model_names_from_v1_models(
68+
autodetect_names(
7169
urls=["http://localhost:8000"],
7270
headers={"Authorization": "Bearer token"},
7371
timeout_s=1.0,
@@ -93,7 +91,7 @@ def test_autodetect_single_model_logs_info_not_warning(
9391
_install_fake_aiohttp(monkeypatch, status=200, body_text=body_text)
9492

9593
asyncio.run(
96-
autodetect_model_names_from_v1_models(
94+
autodetect_names(
9795
urls=["http://localhost:8000"],
9896
headers={},
9997
timeout_s=1.0,
@@ -110,7 +108,7 @@ def test_autodetect_raises_on_non_200(monkeypatch: pytest.MonkeyPatch) -> None:
110108

111109
with pytest.raises(ValueError, match="Failed to auto-detect models"):
112110
asyncio.run(
113-
autodetect_model_names_from_v1_models(
111+
autodetect_names(
114112
urls=["http://localhost:8000"],
115113
headers={},
116114
timeout_s=1.0,

0 commit comments

Comments
 (0)