Skip to content

Commit 543d53c

Browse files
committed
feat(serve): add SageMaker GenAI inference benchmarking and recommendation
Adds sagemaker.serve.ai_inference_recommender, a thin ergonomic layer over sagemaker-core's AIBenchmarkJob, AIRecommendationJob, and AIWorkloadConfig resources.

ModelBuilder gains a new entry point and extends two existing verbs:

    # Benchmark a deployed endpoint
    job = mb.start_benchmark(endpoint=ep, workload=Workload.synthetic(...))
    result = BenchmarkResult.from_job(job)

    # Recommendation flow extends optimize() and deploy()
    mb.optimize(workload=..., performance_target="throughput", instance_types=["ml.g6.12xlarge"])
    endpoint = mb.deploy(role=role)  # top recommendation
    endpoint = mb.deploy(role=role, recommendation_index=2)  # alternative

print(result) and print(mb.recommendations[0]) render their data as tables.

Public surface added under sagemaker.serve:

* Workload -- typed factory; extras pass through **params and are validated server-side.
* BenchmarkResult / BenchmarkMetrics / BenchmarkMetric -- parse the AIPerf output.tar.gz from S3.
* Secret -- opt-in helper for tokens >512 chars (Secrets Manager).
* BenchmarkJob, RecommendationJob -- re-exports without the AI prefix.
* FeatureGatedError, WorkloadValidationError -- typed exceptions.

Pin-mode and workload-mode optimize() kwargs are mutually exclusive. Recommendation deploy uses the ModelPackage path (auto-approves the package the recommendation job publishes).

Includes 51 unit tests and 2 slow_test integration tests (tests/integ/test_ai_inference_recommender_integration.py) verified end-to-end against real AWS.

Rebased onto upstream to pick up #5860 (preserve falsy values in sagemaker-core serialize), required so optimize_model=False reaches the wire.
1 parent 1b68ecf commit 543d53c

9 files changed

Lines changed: 774 additions & 10 deletions

File tree

sagemaker-serve/src/sagemaker/serve/ai_inference_recommender/_model_builder_methods.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
AIBenchmarkOutputConfig,
3232
AIBenchmarkTarget,
3333
AICapacityReservationConfig,
34+
AIDatasetConfig,
3435
AIModelSource,
3536
AIModelSourceS3,
3637
AIRecommendationComputeSpec,
@@ -39,6 +40,9 @@
3940
AIRecommendationOutputConfig,
4041
AIRecommendationPerformanceTarget,
4142
AIWorkloadConfigs,
43+
AIWorkloadDataSource,
44+
AIWorkloadInputDataConfig,
45+
AIWorkloadS3DataSource,
4246
Tag,
4347
VpcConfig,
4448
WorkloadSpec,
@@ -120,6 +124,8 @@ def start_benchmark(
120124
network_config=network_config,
121125
tags=tags,
122126
)
127+
if builder is not None:
128+
builder._benchmark_job = job
123129
if wait:
124130
job.wait()
125131
return job
@@ -256,11 +262,25 @@ def _ensure_workload_config(
256262
return workload
257263

258264
config_name = name or f"sm-wl-{int(time.time())}-{uuid.uuid4().hex[:8]}"
265+
dataset_config = None
266+
if workload.dataset_channels:
267+
dataset_config = AIDatasetConfig(
268+
input_data_config=[
269+
AIWorkloadInputDataConfig(
270+
channel_name=channel.channel_name,
271+
data_source=AIWorkloadDataSource(
272+
s3_data_source=AIWorkloadS3DataSource(s3_uri=channel.s3_uri),
273+
),
274+
)
275+
for channel in workload.dataset_channels
276+
],
277+
)
259278
AIWorkloadConfig.create(
260279
ai_workload_config_name=config_name,
261280
ai_workload_configs=AIWorkloadConfigs(
262281
workload_spec=WorkloadSpec(inline=workload.to_inline()),
263282
),
283+
dataset_config=dataset_config,
264284
tags=tags,
265285
)
266286
return config_name

sagemaker-serve/src/sagemaker/serve/ai_inference_recommender/result.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ class BenchmarkResult:
115115

116116
metrics: BenchmarkMetrics
117117
s3_output_location: str
118+
endpoint: Optional[str] = None
119+
workload_config: Optional[str] = None
120+
tool_version: Optional[str] = None
118121
profile: Dict[str, Any] = field(default_factory=dict)
119122

120123
def __repr__(self) -> str:
@@ -134,6 +137,9 @@ def __repr__(self) -> str:
134137
table = _format_metrics_table(ordered)
135138
return (
136139
f"BenchmarkResult\n"
140+
f" endpoint: {self.endpoint or '-'}\n"
141+
f" workload_config: {self.workload_config or '-'}\n"
142+
f" tool_version: {self.tool_version or '-'}\n"
137143
f" s3_output_location: {self.s3_output_location}\n"
138144
f" metrics:\n{_indent(table, ' ')}\n"
139145
f" raw profile available via .profile"
@@ -148,6 +154,10 @@ def from_job(
148154
) -> "BenchmarkResult":
149155
"""Download and parse the benchmark output for a completed ``AIBenchmarkJob``.
150156
157+
Populates ``endpoint``, ``workload_config``, and ``tool_version`` from
158+
the job's ``BenchmarkTarget`` and ``WorkloadConfigIdentifier`` plus the
159+
AIPerf profile metadata so the parsed result is self-describing.
160+
151161
Args:
152162
job: An ``AIBenchmarkJob`` (or ``BenchmarkJob`` re-export) that has
153163
reached a terminal state.
@@ -183,21 +193,35 @@ def from_job(
183193
f"AIBenchmarkJob {job.get_name()} has no S3OutputLocation "
184194
f"(status={status}). {hint}"
185195
)
186-
return cls.from_s3(job.output_config.s3_output_location, session=session)
196+
workload_config = getattr(job, "ai_workload_config_identifier", None)
197+
return cls.from_s3(
198+
job.output_config.s3_output_location,
199+
session=session,
200+
endpoint=_extract_endpoint(job),
201+
# Normalize falsy sentinels (e.g. unset optional fields) to None
202+
# so the result renders cleanly when fields are missing.
203+
workload_config=workload_config or None,
204+
)
187205

188206
@classmethod
189207
def from_s3(
190208
cls,
191209
s3_output_location: str,
192210
*,
193211
session: Optional[boto3.session.Session] = None,
212+
endpoint: Optional[str] = None,
213+
workload_config: Optional[str] = None,
194214
) -> "BenchmarkResult":
195215
"""Download and parse the benchmark output artifact from S3.
196216
197217
Args:
198218
s3_output_location: ``s3://bucket/prefix/`` location written by
199219
the benchmark job.
200220
session: Optional boto3 session. Defaults to the ambient session.
221+
endpoint: Optional endpoint identifier to attach to the result.
222+
Threaded through by :meth:`from_job`.
223+
workload_config: Optional workload-config identifier to attach.
224+
Threaded through by :meth:`from_job`.
201225
202226
Returns:
203227
A parsed ``BenchmarkResult``.
@@ -216,10 +240,39 @@ def from_s3(
216240
return cls(
217241
metrics=BenchmarkMetrics.from_profile_json(profile),
218242
s3_output_location=s3_output_location,
243+
endpoint=endpoint,
244+
workload_config=workload_config,
245+
tool_version=_extract_tool_version(profile),
219246
profile=profile,
220247
)
221248

222249

250+
def _extract_endpoint(job) -> Optional[str]:
251+
target = getattr(job, "benchmark_target", None) or None
252+
endpoint = (getattr(target, "endpoint", None) or None) if target else None
253+
identifier = getattr(endpoint, "identifier", None) if endpoint else None
254+
return identifier or None
255+
256+
257+
def _extract_tool_version(profile: Dict[str, Any]) -> Optional[str]:
258+
"""Best-effort lookup of the AIPerf tool version from the profile JSON.
259+
260+
AIPerf has no single canonical key; we check a few plausible top-level
261+
locations and return the first string we find.
262+
"""
263+
for key in ("aiperf_version", "tool_version", "version"):
264+
value = profile.get(key)
265+
if isinstance(value, str):
266+
return value
267+
meta = profile.get("metadata") or profile.get("meta") or {}
268+
if isinstance(meta, dict):
269+
for key in ("aiperf_version", "tool_version", "version"):
270+
value = meta.get(key)
271+
if isinstance(value, str):
272+
return value
273+
return None
274+
275+
223276
def _parse_s3_uri(uri: str) -> tuple:
224277
parsed = urlparse(uri)
225278
if parsed.scheme != "s3":

sagemaker-serve/src/sagemaker/serve/ai_inference_recommender/workload.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,27 @@
1414
from __future__ import absolute_import
1515

1616
import json
17-
from typing import Any, Dict, Optional, Union
17+
from typing import Any, Dict, List, Optional, Union
1818

1919
from pydantic import BaseModel, ConfigDict, Field
2020

2121
from sagemaker.serve.ai_inference_recommender.secrets import Secret
2222

2323

24+
class _DatasetChannel(BaseModel):
    """Pairing of a channel name with the S3 URI that backs it.

    ``Workload.from_dataset`` records one of these so that the matching
    ``DatasetConfig`` can be supplied when the :class:`AIWorkloadConfig` is
    created. The record only carries the S3-side configuration through to
    the service; the file AIPerf actually reads is the container path
    declared separately in ``parameters.input_file``.
    """

    # Logical channel name for the mounted dataset.
    channel_name: str
    # S3 location of the dataset to mount for that channel.
    s3_uri: str
36+
37+
2438
class Workload(BaseModel):
2539
"""A workload specification used by benchmark and recommendation jobs."""
2640

@@ -29,6 +43,7 @@ class Workload(BaseModel):
2943
parameters: Dict[str, Any]
3044
secrets: Dict[str, Union[str, Secret]] = Field(default_factory=dict)
3145
tooling: Dict[str, Any] = Field(default_factory=lambda: {"api_standard": "openai"})
46+
dataset_channels: List[_DatasetChannel] = Field(default_factory=list)
3247

3348
@classmethod
3449
def synthetic(
@@ -47,6 +62,11 @@ def synthetic(
4762
) -> "Workload":
4863
"""Build a workload that uses synthetic prompts.
4964
65+
Synthetic prompts are generated by AIPerf from the Sonnet dataset,
66+
producing realistic token distributions. Use
67+
:meth:`Workload.from_dataset` to drive the benchmark from a real
68+
request trace instead.
69+
5070
Args:
5171
tokenizer: HuggingFace tokenizer id (e.g. ``meta-llama/Llama-3.2-1B``).
5272
concurrency: Number of in-flight requests.
@@ -77,6 +97,88 @@ def synthetic(
7797
secrets["hf_token"] = hf_token
7898
return cls(parameters=parameters, secrets=secrets)
7999

100+
@classmethod
def from_dataset(
    cls,
    *,
    s3_uri: str,
    channel_name: str,
    input_file: str,
    custom_dataset_type: Optional[str] = None,
    tokenizer: Optional[str] = None,
    concurrency: int = 1,
    request_count: int = 100,
    streaming: bool = True,
    hf_token: Optional[Union[str, Secret]] = None,
    **params: Any,
) -> "Workload":
    """Build a workload that drives traffic from an S3-hosted dataset.

    The S3 location is mounted into the AIPerf container at
    ``/opt/ml/input/data/{channel_name}/`` by the SageMaker AI inference
    recommender service via the workload config's ``DatasetConfig``.
    AIPerf reads the path declared by ``input_file`` from inside that
    mount.

    Args:
        s3_uri: ``s3://bucket/prefix/`` (or single-object) URI containing
            the dataset. Mounted under the container at
            ``/opt/ml/input/data/{channel_name}/``.
        channel_name: Non-empty logical channel name. Used as the
            directory under ``/opt/ml/input/data/`` where the S3 contents
            are mounted.
        input_file: Container-internal path AIPerf should read, e.g.
            ``/opt/ml/input/data/traffic/requests.jsonl``. Must live
            under the channel's mount directory.
        custom_dataset_type: Optional AIPerf custom-dataset format
            (e.g. ``"openai-chat"``).
        tokenizer: Optional HuggingFace tokenizer id; required for some
            AIPerf metrics that compute per-token statistics.
        concurrency: Number of in-flight requests.
        request_count: Total number of requests to issue.
        streaming: Whether to use streaming chat completions.
        hf_token: HuggingFace access token for gated tokenizers. Accepts
            a ``Secret`` or a Secrets Manager ARN string.
        **params: Additional parameters merged into the workload's
            ``parameters`` map.

    Returns:
        A ``Workload`` whose inline payload references ``input_file`` and
        whose ``dataset_channels`` carry the S3 URI for the service to
        mount.

    Raises:
        ValueError: If ``channel_name`` is empty, ``s3_uri`` is not an
            ``s3://`` URI with something after the scheme, or
            ``input_file`` is not under the channel's mount directory.
    """
    # An empty channel name would make the expected prefix
    # "/opt/ml/input/data//", letting the mount check below pass vacuously.
    if not channel_name:
        raise ValueError("channel_name must be a non-empty string.")
    if not s3_uri.startswith("s3://"):
        raise ValueError(
            f"s3_uri must start with 's3://'; got {s3_uri!r}."
        )
    # A bare scheme names no bucket, so there is nothing to mount.
    if len(s3_uri) == len("s3://"):
        raise ValueError(
            f"s3_uri must include a bucket name; got {s3_uri!r}."
        )
    expected_prefix = f"/opt/ml/input/data/{channel_name}/"
    if not input_file.startswith(expected_prefix):
        raise ValueError(
            f"input_file must live under {expected_prefix!r} so the "
            f"mounted channel contains it; got {input_file!r}."
        )

    parameters: Dict[str, Any] = {
        "input_file": input_file,
        "concurrency": concurrency,
        "request_count": request_count,
        "streaming": streaming,
        **params,
    }
    # Optional settings are only emitted when the caller supplied them.
    if custom_dataset_type is not None:
        parameters["custom_dataset_type"] = custom_dataset_type
    if tokenizer is not None:
        parameters["tokenizer"] = tokenizer

    secrets: Dict[str, Union[str, Secret]] = {}
    if hf_token is not None:
        secrets["hf_token"] = hf_token

    return cls(
        parameters=parameters,
        secrets=secrets,
        dataset_channels=[
            _DatasetChannel(channel_name=channel_name, s3_uri=s3_uri),
        ],
    )
181+
80182
def to_inline(self) -> str:
81183
"""Serialize the workload to a JSON string.
82184

0 commit comments

Comments
 (0)