Skip to content

Commit 31f40bd

Browse files
committed
feat(serve): add SageMaker GenAI inference benchmarking and recommendation
Adds sagemaker.serve.ai_inference_recommender, a thin ergonomic layer over sagemaker-core's AIBenchmarkJob, AIRecommendationJob, and AIWorkloadConfig resources. ModelBuilder gains a new entry point and extends two existing verbs: # Benchmark a deployed endpoint job = mb.start_benchmark(endpoint=ep, workload=Workload.synthetic(...)) result = BenchmarkResult.from_job(job) # Recommendation flow extends optimize() and deploy() mb.optimize(workload=..., performance_target="throughput", instance_types=["ml.g6.12xlarge"]) endpoint = mb.deploy(role=role) # top recommendation endpoint = mb.deploy(role=role, recommendation_index=2) # alternative print(result) and print(mb.recommendations[0]) render their data as tables. Public surface added under sagemaker.serve: * Workload -- typed factory; extras pass through **params, validated server-side. * BenchmarkResult / BenchmarkMetrics / BenchmarkMetric -- parses the AIPerf output.tar.gz from S3. * Secret -- opt-in helper for tokens >512 chars (Secrets Manager). * BenchmarkJob, RecommendationJob -- re-exports without the AI prefix. * FeatureGatedError, WorkloadValidationError -- typed exceptions. Pin-mode and workload-mode optimize() kwargs are mutually exclusive. Recommendation deploy uses the ModelPackage path (auto-approves the package the rec job publishes). Includes 51 unit tests and 2 slow_test integ tests (tests/integ/test_ai_inference_recommender_integration.py) verified end-to-end against real AWS. Rebased onto upstream to pick up #5860 (preserve falsy values in sagemaker-core serialize), required so optimize_model=False reaches the wire.
1 parent addb37b commit 31f40bd

19 files changed

Lines changed: 2460 additions & 4 deletions

sagemaker-serve/src/sagemaker/serve/__init__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,25 @@
2929
from sagemaker.serve.utils.types import ModelServer
3030
from sagemaker.serve.model_builder import ModelBuilder
3131

32-
__all__ = ["InferenceSpec", "ModelServer", "ModelBuilder"]
32+
from sagemaker.serve.ai_inference_recommender import (
33+
BenchmarkJob,
34+
BenchmarkResult,
35+
RecommendationJob,
36+
Secret,
37+
Workload,
38+
FeatureGatedError,
39+
WorkloadValidationError,
40+
)
41+
42+
__all__ = [
43+
"InferenceSpec",
44+
"ModelServer",
45+
"ModelBuilder",
46+
"BenchmarkJob",
47+
"BenchmarkResult",
48+
"RecommendationJob",
49+
"Secret",
50+
"Workload",
51+
"FeatureGatedError",
52+
"WorkloadValidationError",
53+
]
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""SageMaker GenAI inference benchmarking and recommendation."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.serve.ai_inference_recommender.exceptions import (
17+
FeatureGatedError,
18+
WorkloadValidationError,
19+
)
20+
from sagemaker.serve.ai_inference_recommender.jobs import (
21+
BenchmarkJob,
22+
RecommendationJob,
23+
)
24+
from sagemaker.serve.ai_inference_recommender.result import (
25+
BenchmarkMetric,
26+
BenchmarkMetrics,
27+
BenchmarkResult,
28+
)
29+
from sagemaker.serve.ai_inference_recommender.secrets import Secret
30+
from sagemaker.serve.ai_inference_recommender.workload import Workload
31+
32+
33+
__all__ = [
34+
"BenchmarkJob",
35+
"BenchmarkMetric",
36+
"BenchmarkMetrics",
37+
"BenchmarkResult",
38+
"FeatureGatedError",
39+
"RecommendationJob",
40+
"Secret",
41+
"Workload",
42+
"WorkloadValidationError",
43+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Constants for the AI inference recommender module."""
14+
from __future__ import absolute_import
15+
16+
MAX_INSTANCE_TYPES = 3
17+
18+
FEATURE_GATING_RUNBOOK_URL = (
19+
"https://docs.aws.amazon.com/sagemaker/latest/dg/"
20+
"generative-ai-inference-recommendations.html"
21+
)
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Internal helpers backing ModelBuilder.start_benchmark and the recommendation branch of ModelBuilder.optimize."""
14+
from __future__ import absolute_import
15+
16+
import time
17+
import uuid
18+
from typing import List, Optional, Union
19+
20+
from sagemaker.core.helper.session_helper import Session, get_execution_role
21+
from sagemaker.core.resources import (
22+
AIBenchmarkJob,
23+
AIRecommendationJob,
24+
AIWorkloadConfig,
25+
Endpoint,
26+
)
27+
from sagemaker.core.shapes.shapes import (
28+
AIBenchmarkEndpoint,
29+
AIBenchmarkInferenceComponent,
30+
AIBenchmarkNetworkConfig,
31+
AIBenchmarkOutputConfig,
32+
AIBenchmarkTarget,
33+
AICapacityReservationConfig,
34+
AIModelSource,
35+
AIModelSourceS3,
36+
AIRecommendationComputeSpec,
37+
AIRecommendationConstraint,
38+
AIRecommendationInferenceSpecification,
39+
AIRecommendationOutputConfig,
40+
AIRecommendationPerformanceTarget,
41+
AIWorkloadConfigs,
42+
Tag,
43+
VpcConfig,
44+
WorkloadSpec,
45+
)
46+
from sagemaker.serve.ai_inference_recommender._constants import MAX_INSTANCE_TYPES
47+
from sagemaker.serve.ai_inference_recommender.workload import Workload
48+
49+
50+
def start_benchmark(
51+
builder, # ModelBuilder; not annotated to avoid a circular import.
52+
endpoint: Union[Endpoint, str],
53+
workload: Union[Workload, str],
54+
*,
55+
output_path: Optional[str] = None,
56+
role: Optional[str] = None,
57+
inference_components: Optional[List[str]] = None,
58+
vpc_config: Optional[VpcConfig] = None,
59+
tags: Optional[List[Tag]] = None,
60+
name: Optional[str] = None,
61+
workload_config_name: Optional[str] = None,
62+
wait: bool = True,
63+
) -> AIBenchmarkJob:
64+
"""Start an AI benchmark job against a SageMaker endpoint.
65+
66+
Args:
67+
endpoint: An ``Endpoint`` resource, or the name/ARN of an existing
68+
endpoint to benchmark.
69+
workload: Either a ``Workload`` (auto-creates a workload config) or
70+
the name/ARN of an existing ``AIWorkloadConfig``.
71+
output_path: ``s3://`` URI for benchmark output. Defaults to the
72+
session's default bucket.
73+
role: IAM execution role ARN. Defaults to the SageMaker execution
74+
role from the ambient session.
75+
inference_components: Optional list of inference component names to
76+
target on the endpoint.
77+
vpc_config: Optional ``VpcConfig`` for VPC-only endpoints.
78+
tags: Optional resource tags.
79+
name: Optional benchmark job name. Auto-generated if omitted.
80+
workload_config_name: Optional name for the auto-created workload
81+
config. Auto-generated if omitted.
82+
wait: If True, block until the job reaches a terminal state.
83+
84+
Returns:
85+
The created ``AIBenchmarkJob`` resource. After it reaches a terminal
86+
state, pass it to ``BenchmarkResult.from_job(job)`` to retrieve the
87+
parsed metrics.
88+
"""
89+
sagemaker_session = Session()
90+
role_arn = role or get_execution_role(sagemaker_session=sagemaker_session)
91+
output_location = output_path or _default_output_path(sagemaker_session, "benchmarks")
92+
93+
workload_config_id = _ensure_workload_config(workload, workload_config_name, tags=tags)
94+
95+
endpoint_name = endpoint.endpoint_name if isinstance(endpoint, Endpoint) else endpoint
96+
components = (
97+
[AIBenchmarkInferenceComponent(identifier=ic) for ic in inference_components]
98+
if inference_components
99+
else None
100+
)
101+
target = AIBenchmarkTarget(
102+
endpoint=AIBenchmarkEndpoint(
103+
identifier=endpoint_name,
104+
inference_components=components,
105+
)
106+
)
107+
network_config = (
108+
AIBenchmarkNetworkConfig(vpc_config=vpc_config) if vpc_config else None
109+
)
110+
111+
suffix = uuid.uuid4().hex[:8]
112+
job_name = name or f"sm-bench-{int(time.time())}-{suffix}"
113+
114+
job = AIBenchmarkJob.create(
115+
ai_benchmark_job_name=job_name,
116+
benchmark_target=target,
117+
output_config=AIBenchmarkOutputConfig(s3_output_location=output_location),
118+
ai_workload_config_identifier=workload_config_id,
119+
role_arn=role_arn,
120+
network_config=network_config,
121+
tags=tags,
122+
)
123+
if wait:
124+
job.wait()
125+
return job
126+
127+
128+
def run_recommendation_job(
129+
builder, # ModelBuilder; not annotated to avoid a circular import.
130+
workload: Union[Workload, str],
131+
performance_target: str,
132+
*,
133+
output_path: Optional[str] = None,
134+
role_arn: Optional[str] = None,
135+
instance_types: Optional[List[str]] = None,
136+
capacity_reservation_arns: Optional[List[str]] = None,
137+
optimize_model: bool = True,
138+
framework: Optional[str] = None,
139+
model_package_group: Optional[str] = None,
140+
tags: Optional[List[Tag]] = None,
141+
name: Optional[str] = None,
142+
workload_config_name: Optional[str] = None,
143+
wait: bool = True,
144+
) -> AIRecommendationJob:
145+
"""Submit an ``AIRecommendationJob`` for the model configured on this builder.
146+
147+
Backs the recommendation branch of :meth:`ModelBuilder.optimize`. Customers
148+
do not call this directly; pass ``workload`` and ``performance_target`` to
149+
:meth:`ModelBuilder.optimize` instead.
150+
151+
Args:
152+
workload: Either a ``Workload`` (auto-creates a workload config) or
153+
the name/ARN of an existing ``AIWorkloadConfig``.
154+
performance_target: One of ``"throughput"``, ``"ttft-ms"``, or
155+
``"cost"``.
156+
output_path: ``s3://`` URI for recommendation output. Defaults to
157+
the session's default bucket.
158+
role_arn: IAM execution role ARN. Defaults to the SageMaker execution
159+
role from the ambient session.
160+
instance_types: Up to 3 instance types to consider.
161+
capacity_reservation_arns: Optional list of ML reservation ARNs.
162+
optimize_model: If True (default), allow optimization techniques
163+
like speculative decoding and kernel tuning.
164+
framework: Inference framework. ``"LMI"`` or ``"VLLM"``.
165+
model_package_group: Optional model package group identifier in
166+
which to register the optimized model.
167+
tags: Optional resource tags.
168+
name: Optional recommendation job name. Auto-generated if omitted.
169+
workload_config_name: Optional name for the auto-created workload
170+
config. Auto-generated if omitted.
171+
wait: If True (default), block until the job reaches a terminal state.
172+
173+
Returns:
174+
The created ``AIRecommendationJob`` resource.
175+
"""
176+
sagemaker_session = Session()
177+
resolved_role_arn = role_arn or get_execution_role(sagemaker_session=sagemaker_session)
178+
output_location = output_path or _default_output_path(
179+
sagemaker_session, "recommendations"
180+
)
181+
182+
s3_uri = _resolve_model_s3_uri(builder)
183+
if not s3_uri:
184+
raise ValueError(
185+
"ModelBuilder must be configured with an S3 model_path before "
186+
"calling optimize() with a workload. Call build() first."
187+
)
188+
189+
if instance_types and len(instance_types) > MAX_INSTANCE_TYPES:
190+
raise ValueError(
191+
f"At most {MAX_INSTANCE_TYPES} instance_types are accepted; "
192+
f"got {len(instance_types)}."
193+
)
194+
195+
workload_config_id = _ensure_workload_config(workload, workload_config_name, tags=tags)
196+
197+
suffix = uuid.uuid4().hex[:8]
198+
job_name = name or f"sm-rec-{int(time.time())}-{suffix}"
199+
200+
compute_spec = None
201+
if instance_types or capacity_reservation_arns:
202+
capacity = (
203+
AICapacityReservationConfig(
204+
capacity_reservation_preference="capacity-reservations-only",
205+
ml_reservation_arns=capacity_reservation_arns,
206+
)
207+
if capacity_reservation_arns
208+
else None
209+
)
210+
compute_spec = AIRecommendationComputeSpec(
211+
instance_types=instance_types,
212+
capacity_reservation_config=capacity,
213+
)
214+
215+
inference_spec = (
216+
AIRecommendationInferenceSpecification(framework=framework) if framework else None
217+
)
218+
219+
job = AIRecommendationJob.create(
220+
ai_recommendation_job_name=job_name,
221+
model_source=AIModelSource(s3=AIModelSourceS3(s3_uri=s3_uri)),
222+
output_config=AIRecommendationOutputConfig(
223+
s3_output_location=output_location,
224+
model_package_group_identifier=model_package_group,
225+
),
226+
ai_workload_config_identifier=workload_config_id,
227+
performance_target=AIRecommendationPerformanceTarget(
228+
constraints=[AIRecommendationConstraint(metric=performance_target)],
229+
),
230+
role_arn=resolved_role_arn,
231+
inference_specification=inference_spec,
232+
optimize_model=optimize_model,
233+
compute_spec=compute_spec,
234+
tags=tags,
235+
)
236+
if wait:
237+
job.wait()
238+
return job
239+
240+
241+
def _resolve_model_s3_uri(builder) -> Optional[str]:
242+
for attr in ("model_path", "s3_upload_path", "s3_model_data_url"):
243+
candidate = getattr(builder, attr, None)
244+
if isinstance(candidate, str) and candidate.startswith("s3://"):
245+
return candidate
246+
return None
247+
248+
249+
def _ensure_workload_config(
250+
workload: Union[Workload, str],
251+
name: Optional[str],
252+
*,
253+
tags: Optional[List[Tag]] = None,
254+
) -> str:
255+
if isinstance(workload, str):
256+
return workload
257+
258+
config_name = name or f"sm-wl-{int(time.time())}-{uuid.uuid4().hex[:8]}"
259+
AIWorkloadConfig.create(
260+
ai_workload_config_name=config_name,
261+
ai_workload_configs=AIWorkloadConfigs(
262+
workload_spec=WorkloadSpec(inline=workload.to_inline()),
263+
),
264+
tags=tags,
265+
)
266+
return config_name
267+
268+
269+
def _default_output_path(session: Session, prefix: str) -> str:
270+
bucket = session.default_bucket()
271+
return f"s3://{bucket}/{prefix}/"

0 commit comments

Comments
 (0)