Skip to content

Commit c0cfc77

Browse files
committed
feat(serve): add SageMaker GenAI inference benchmarking and recommendation
Adds sagemaker.serve.ai_inference_recommender, a thin ergonomic layer over the auto-generated AIBenchmarkJob, AIRecommendationJob, and AIWorkloadConfig resources in sagemaker-core. ModelBuilder gains two methods: job = mb.start_benchmark(endpoint=ep, workload=Workload.synthetic(...)) job = mb.start_inference_recommendation(workload, throughput, instance_types=[ml.g6.12xlarge]) After the job reaches a terminal state, customers retrieve results via constructors that wrap the auto-gen job resource: result = BenchmarkResult.from_job(job) rec = Recommendation.from_job(job) endpoint = rec.deploy(role=...) Public surface added under sagemaker.serve: * Workload — typed factory (synthetic) that builds the WorkloadSpec inline JSON envelope. Extra AIPerf parameters flow through **params unchecked and are validated server-side. * BenchmarkResult / BenchmarkMetrics / BenchmarkMetric — parses the AIPerf profile_export_aiperf.json out of the output.tar.gz artifact. * Recommendation — wrapper around one row of an AIRecommendationJob's recommendations list. .deploy() prefers the ModelPackage path, falls back to a raw image_uri + S3 channels container definition. * Secret — helper around AWS Secrets Manager for hf_token round-trip. * BenchmarkJob, RecommendationJob — re-exports of the auto-gen classes without the AI prefix. * FeatureGatedError, WorkloadValidationError — typed exceptions.
1 parent 1572b32 commit c0cfc77

19 files changed

Lines changed: 2159 additions & 1 deletion

sagemaker-serve/src/sagemaker/serve/__init__.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,27 @@
2929
from sagemaker.serve.utils.types import ModelServer
3030
from sagemaker.serve.model_builder import ModelBuilder
3131

32-
__all__ = ["InferenceSpec", "ModelServer", "ModelBuilder"]
32+
from sagemaker.serve.ai_inference_recommender import (
33+
BenchmarkJob,
34+
BenchmarkResult,
35+
Recommendation,
36+
RecommendationJob,
37+
Secret,
38+
Workload,
39+
FeatureGatedError,
40+
WorkloadValidationError,
41+
)
42+
43+
__all__ = [
44+
"InferenceSpec",
45+
"ModelServer",
46+
"ModelBuilder",
47+
"BenchmarkJob",
48+
"BenchmarkResult",
49+
"Recommendation",
50+
"RecommendationJob",
51+
"Secret",
52+
"Workload",
53+
"FeatureGatedError",
54+
"WorkloadValidationError",
55+
]
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""SageMaker GenAI inference benchmarking and recommendation."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.serve.ai_inference_recommender.exceptions import (
17+
FeatureGatedError,
18+
WorkloadValidationError,
19+
)
20+
from sagemaker.serve.ai_inference_recommender.jobs import (
21+
BenchmarkJob,
22+
RecommendationJob,
23+
)
24+
from sagemaker.serve.ai_inference_recommender.recommendation import Recommendation
25+
from sagemaker.serve.ai_inference_recommender.result import (
26+
BenchmarkMetric,
27+
BenchmarkMetrics,
28+
BenchmarkResult,
29+
)
30+
from sagemaker.serve.ai_inference_recommender.secrets import Secret
31+
from sagemaker.serve.ai_inference_recommender.workload import Workload
32+
33+
34+
__all__ = [
35+
"BenchmarkJob",
36+
"BenchmarkMetric",
37+
"BenchmarkMetrics",
38+
"BenchmarkResult",
39+
"FeatureGatedError",
40+
"Recommendation",
41+
"RecommendationJob",
42+
"Secret",
43+
"Workload",
44+
"WorkloadValidationError",
45+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Constants for the AI inference recommender module."""
14+
from __future__ import absolute_import
15+
16+
MAX_INSTANCE_TYPES = 3
17+
18+
FEATURE_GATING_RUNBOOK_URL = (
19+
"https://docs.aws.amazon.com/sagemaker/latest/dg/"
20+
"generative-ai-inference-recommendations.html"
21+
)
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Implementation of ModelBuilder.start_benchmark and start_inference_recommendation."""
14+
from __future__ import absolute_import
15+
16+
import time
17+
import uuid
18+
from typing import List, Optional, Union
19+
20+
from sagemaker.core.helper.session_helper import Session, get_execution_role
21+
from sagemaker.core.resources import (
22+
AIBenchmarkJob,
23+
AIRecommendationJob,
24+
AIWorkloadConfig,
25+
Endpoint,
26+
)
27+
from sagemaker.core.shapes.shapes import (
28+
AIBenchmarkEndpoint,
29+
AIBenchmarkInferenceComponent,
30+
AIBenchmarkNetworkConfig,
31+
AIBenchmarkOutputConfig,
32+
AIBenchmarkTarget,
33+
AICapacityReservationConfig,
34+
AIModelSource,
35+
AIModelSourceS3,
36+
AIRecommendationComputeSpec,
37+
AIRecommendationConstraint,
38+
AIRecommendationInferenceSpecification,
39+
AIRecommendationOutputConfig,
40+
AIRecommendationPerformanceTarget,
41+
AIWorkloadConfigs,
42+
Tag,
43+
VpcConfig,
44+
WorkloadSpec,
45+
)
46+
from sagemaker.serve.ai_inference_recommender._constants import MAX_INSTANCE_TYPES
47+
from sagemaker.serve.ai_inference_recommender.workload import Workload
48+
49+
50+
def start_benchmark(
51+
builder, # ModelBuilder; not annotated to avoid a circular import.
52+
endpoint: Union[Endpoint, str],
53+
workload: Union[Workload, str],
54+
*,
55+
output_path: Optional[str] = None,
56+
role: Optional[str] = None,
57+
inference_components: Optional[List[str]] = None,
58+
vpc_config: Optional[VpcConfig] = None,
59+
tags: Optional[List[Tag]] = None,
60+
name: Optional[str] = None,
61+
workload_config_name: Optional[str] = None,
62+
wait: bool = False,
63+
) -> AIBenchmarkJob:
64+
"""Start an AI benchmark job against a SageMaker endpoint.
65+
66+
Args:
67+
endpoint: An ``Endpoint`` resource, or the name/ARN of an existing
68+
endpoint to benchmark.
69+
workload: Either a ``Workload`` (auto-creates a workload config) or
70+
the name/ARN of an existing ``AIWorkloadConfig``.
71+
output_path: ``s3://`` URI for benchmark output. Defaults to the
72+
session's default bucket.
73+
role: IAM execution role ARN. Defaults to the SageMaker execution
74+
role from the ambient session.
75+
inference_components: Optional list of inference component names to
76+
target on the endpoint.
77+
vpc_config: Optional ``VpcConfig`` for VPC-only endpoints.
78+
tags: Optional resource tags.
79+
name: Optional benchmark job name. Auto-generated if omitted.
80+
workload_config_name: Optional name for the auto-created workload
81+
config. Auto-generated if omitted.
82+
wait: If True, block until the job reaches a terminal state.
83+
84+
Returns:
85+
The created ``AIBenchmarkJob`` resource. After it reaches a terminal
86+
state, pass it to ``BenchmarkResult.from_job(job)`` to retrieve the
87+
parsed metrics.
88+
"""
89+
sagemaker_session = Session()
90+
role_arn = role or get_execution_role(sagemaker_session=sagemaker_session)
91+
output_location = output_path or _default_output_path(sagemaker_session, "benchmarks")
92+
93+
workload_config_id = _ensure_workload_config(workload, workload_config_name, tags=tags)
94+
95+
endpoint_name = endpoint.endpoint_name if isinstance(endpoint, Endpoint) else endpoint
96+
components = (
97+
[AIBenchmarkInferenceComponent(identifier=ic) for ic in inference_components]
98+
if inference_components
99+
else None
100+
)
101+
target = AIBenchmarkTarget(
102+
endpoint=AIBenchmarkEndpoint(
103+
identifier=endpoint_name,
104+
inference_components=components,
105+
)
106+
)
107+
network_config = (
108+
AIBenchmarkNetworkConfig(vpc_config=vpc_config) if vpc_config else None
109+
)
110+
111+
suffix = uuid.uuid4().hex[:8]
112+
job_name = name or f"sm-bench-{int(time.time())}-{suffix}"
113+
114+
job = AIBenchmarkJob.create(
115+
ai_benchmark_job_name=job_name,
116+
benchmark_target=target,
117+
output_config=AIBenchmarkOutputConfig(s3_output_location=output_location),
118+
ai_workload_config_identifier=workload_config_id,
119+
role_arn=role_arn,
120+
network_config=network_config,
121+
tags=tags,
122+
)
123+
if wait:
124+
job.wait()
125+
return job
126+
127+
128+
def start_inference_recommendation(
129+
builder, # ModelBuilder; not annotated to avoid a circular import.
130+
workload: Union[Workload, str],
131+
performance_target: str,
132+
*,
133+
output_path: Optional[str] = None,
134+
role: Optional[str] = None,
135+
instance_types: Optional[List[str]] = None,
136+
capacity_reservation_arns: Optional[List[str]] = None,
137+
optimize_model: bool = True,
138+
framework: Optional[str] = None,
139+
model_package_group: Optional[str] = None,
140+
tags: Optional[List[Tag]] = None,
141+
name: Optional[str] = None,
142+
workload_config_name: Optional[str] = None,
143+
wait: bool = False,
144+
) -> AIRecommendationJob:
145+
"""Start an AI recommendation job for the model configured on this builder.
146+
147+
Args:
148+
workload: Either a ``Workload`` (auto-creates a workload config) or
149+
the name/ARN of an existing ``AIWorkloadConfig``.
150+
performance_target: One of ``"throughput"``, ``"ttft-ms"``, or
151+
``"cost"``.
152+
output_path: ``s3://`` URI for recommendation output. Defaults to
153+
the session's default bucket.
154+
role: IAM execution role ARN. Defaults to the SageMaker execution
155+
role from the ambient session.
156+
instance_types: Up to 3 instance types to consider.
157+
capacity_reservation_arns: Optional list of ML reservation ARNs.
158+
optimize_model: If True (default), allow optimization techniques
159+
like speculative decoding and kernel tuning.
160+
framework: Inference framework. ``"LMI"`` or ``"VLLM"``.
161+
model_package_group: Optional model package group identifier in
162+
which to register the optimized model.
163+
tags: Optional resource tags.
164+
name: Optional recommendation job name. Auto-generated if omitted.
165+
workload_config_name: Optional name for the auto-created workload
166+
config. Auto-generated if omitted.
167+
wait: If True, block until the job reaches a terminal state.
168+
169+
Returns:
170+
The created ``AIRecommendationJob`` resource.
171+
"""
172+
sagemaker_session = Session()
173+
role_arn = role or get_execution_role(sagemaker_session=sagemaker_session)
174+
output_location = output_path or _default_output_path(
175+
sagemaker_session, "recommendations"
176+
)
177+
178+
s3_uri = _resolve_model_s3_uri(builder)
179+
if not s3_uri:
180+
raise ValueError(
181+
"ModelBuilder must be configured with an S3 model_path before "
182+
"calling start_inference_recommendation()."
183+
)
184+
185+
if instance_types and len(instance_types) > MAX_INSTANCE_TYPES:
186+
raise ValueError(
187+
f"At most {MAX_INSTANCE_TYPES} instance_types are accepted; "
188+
f"got {len(instance_types)}."
189+
)
190+
191+
workload_config_id = _ensure_workload_config(workload, workload_config_name, tags=tags)
192+
193+
suffix = uuid.uuid4().hex[:8]
194+
job_name = name or f"sm-rec-{int(time.time())}-{suffix}"
195+
196+
compute_spec = None
197+
if instance_types or capacity_reservation_arns:
198+
capacity = (
199+
AICapacityReservationConfig(
200+
capacity_reservation_preference="capacity-reservations-only",
201+
ml_reservation_arns=capacity_reservation_arns,
202+
)
203+
if capacity_reservation_arns
204+
else None
205+
)
206+
compute_spec = AIRecommendationComputeSpec(
207+
instance_types=instance_types,
208+
capacity_reservation_config=capacity,
209+
)
210+
211+
inference_spec = (
212+
AIRecommendationInferenceSpecification(framework=framework) if framework else None
213+
)
214+
215+
job = AIRecommendationJob.create(
216+
ai_recommendation_job_name=job_name,
217+
model_source=AIModelSource(s3=AIModelSourceS3(s3_uri=s3_uri)),
218+
output_config=AIRecommendationOutputConfig(
219+
s3_output_location=output_location,
220+
model_package_group_identifier=model_package_group,
221+
),
222+
ai_workload_config_identifier=workload_config_id,
223+
performance_target=AIRecommendationPerformanceTarget(
224+
constraints=[AIRecommendationConstraint(metric=performance_target)],
225+
),
226+
role_arn=role_arn,
227+
inference_specification=inference_spec,
228+
optimize_model=optimize_model,
229+
compute_spec=compute_spec,
230+
tags=tags,
231+
)
232+
if wait:
233+
job.wait()
234+
return job
235+
236+
237+
def _resolve_model_s3_uri(builder) -> Optional[str]:
238+
candidate = getattr(builder, "model_path", None) or getattr(
239+
builder, "s3_upload_path", None
240+
)
241+
if isinstance(candidate, str) and candidate.startswith("s3://"):
242+
return candidate
243+
return None
244+
245+
246+
def _ensure_workload_config(
247+
workload: Union[Workload, str],
248+
name: Optional[str],
249+
*,
250+
tags: Optional[List[Tag]] = None,
251+
) -> str:
252+
if isinstance(workload, str):
253+
return workload
254+
255+
config_name = name or f"sm-wl-{int(time.time())}-{uuid.uuid4().hex[:8]}"
256+
AIWorkloadConfig.create(
257+
ai_workload_config_name=config_name,
258+
ai_workload_configs=AIWorkloadConfigs(
259+
workload_spec=WorkloadSpec(inline=workload.to_inline()),
260+
),
261+
tags=tags,
262+
)
263+
return config_name
264+
265+
266+
def _default_output_path(session: Session, prefix: str) -> str:
267+
bucket = session.default_bucket()
268+
return f"s3://{bucket}/{prefix}/"

0 commit comments

Comments
 (0)