Skip to content

Commit 638234f

Browse files
committed
feat(serve): add SageMaker GenAI inference benchmarking and recommendation
Adds sagemaker.serve.ai_inference_recommender, a thin ergonomic layer over sagemaker-core's AIBenchmarkJob, AIRecommendationJob, and AIWorkloadConfig resources. ModelBuilder gains a new entry point and extends two existing verbs: # Benchmark a deployed endpoint job = mb.start_benchmark(endpoint=ep, workload=Workload.synthetic(...)) result = BenchmarkResult.from_job(job) # Recommendation flow extends optimize() and deploy() mb.optimize(workload=..., performance_target="throughput", instance_types=["ml.g6.12xlarge"]) endpoint = mb.deploy(role=role) # top recommendation endpoint = mb.deploy(role=role, recommendation_index=2) # alternative print(result) and print(mb.recommendations[0]) render their data as tables. Public surface added under sagemaker.serve: * Workload -- typed factory; extras pass through **params, validated server-side. * BenchmarkResult / BenchmarkMetrics / BenchmarkMetric -- parses the AIPerf output.tar.gz from S3. * Secret -- opt-in helper for tokens >512 chars (Secrets Manager). * BenchmarkJob, RecommendationJob -- re-exports without the AI prefix. * FeatureGatedError, WorkloadValidationError -- typed exceptions. Pin-mode and workload-mode optimize() kwargs are mutually exclusive. Recommendation deploy uses the ModelPackage path (auto-approves the package the rec job publishes). Includes 51 unit tests and 2 slow_test integ tests (tests/integ/test_ai_inference_recommender_integration.py) verified end-to-end against real AWS. Rebased onto upstream to pick up #5860 (preserve falsy values in sagemaker-core serialize), required so optimize_model=False reaches the wire.
1 parent ccb0e37 commit 638234f

12 files changed

Lines changed: 448 additions & 390 deletions

File tree

sagemaker-serve/src/sagemaker/serve/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Workload,
3838
FeatureGatedError,
3939
WorkloadValidationError,
40+
start_benchmark,
4041
)
4142

4243
__all__ = [
@@ -50,4 +51,5 @@
5051
"Workload",
5152
"FeatureGatedError",
5253
"WorkloadValidationError",
54+
"start_benchmark",
5355
]

sagemaker-serve/src/sagemaker/serve/ai_inference_recommender/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
)
2929
from sagemaker.serve.ai_inference_recommender.secrets import Secret
3030
from sagemaker.serve.ai_inference_recommender.workload import Workload
31+
from sagemaker.serve.ai_inference_recommender._model_builder_methods import (
32+
start_benchmark,
33+
)
3134

3235

3336
__all__ = [
@@ -40,4 +43,5 @@
4043
"Secret",
4144
"Workload",
4245
"WorkloadValidationError",
46+
"start_benchmark",
4347
]

sagemaker-serve/src/sagemaker/serve/ai_inference_recommender/_model_builder_methods.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Internal helpers backing ModelBuilder.start_benchmark and the recommendation branch of ModelBuilder.optimize."""
13+
"""Internal helpers backing the public start_benchmark function and ModelBuilder.generate_deployment_recommendations."""
1414
from __future__ import absolute_import
1515

1616
import time
1717
import uuid
18-
from typing import List, Optional, Union
18+
from typing import Any, List, Optional, Union
1919

2020
from sagemaker.core.helper.session_helper import Session, get_execution_role
2121
from sagemaker.core.resources import (
@@ -24,6 +24,8 @@
2424
AIWorkloadConfig,
2525
Endpoint,
2626
)
27+
from sagemaker.core.telemetry.constants import Feature
28+
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2729
from sagemaker.core.shapes.shapes import (
2830
AIBenchmarkEndpoint,
2931
AIBenchmarkInferenceComponent,
@@ -51,10 +53,12 @@
5153
from sagemaker.serve.ai_inference_recommender.workload import Workload
5254

5355

56+
@_telemetry_emitter(
57+
feature=Feature.MODEL_CUSTOMIZATION, func_name="ai_inference_recommender.start_benchmark"
58+
)
5459
def start_benchmark(
55-
builder, # ModelBuilder; not annotated to avoid a circular import.
5660
endpoint: Union[Endpoint, str],
57-
workload: Union[Workload, str],
61+
workload: Optional[Union[Workload, str]] = None,
5862
*,
5963
output_path: Optional[str] = None,
6064
role: Optional[str] = None,
@@ -64,14 +68,17 @@ def start_benchmark(
6468
name: Optional[str] = None,
6569
workload_config_name: Optional[str] = None,
6670
wait: bool = True,
71+
**workload_kwargs: Any,
6772
) -> AIBenchmarkJob:
6873
"""Start an AI benchmark job against a SageMaker endpoint.
6974
7075
Args:
7176
endpoint: An ``Endpoint`` resource, or the name/ARN of an existing
7277
endpoint to benchmark.
73-
workload: Either a ``Workload`` (auto-creates a workload config) or
74-
the name/ARN of an existing ``AIWorkloadConfig``.
78+
workload: Optional. A ``Workload`` instance, or the name/ARN of an
79+
existing ``AIWorkloadConfig``. Omit this and pass workload
80+
keyword arguments inline (``tokenizer=``, ``concurrency=``,
81+
etc.) to construct a synthetic workload on the fly.
7582
output_path: ``s3://`` URI for benchmark output. Defaults to the
7683
session's default bucket.
7784
role: IAM execution role ARN. Defaults to the SageMaker execution
@@ -83,13 +90,28 @@ def start_benchmark(
8390
name: Optional benchmark job name. Auto-generated if omitted.
8491
workload_config_name: Optional name for the auto-created workload
8592
config. Auto-generated if omitted.
86-
wait: If True, block until the job reaches a terminal state.
93+
wait: If True (default), block until the job reaches a terminal
94+
state.
95+
**workload_kwargs: Inline workload parameters. Only used when
96+
``workload`` is omitted; forwarded to ``Workload.synthetic``.
8797
8898
Returns:
89-
The created ``AIBenchmarkJob`` resource. After it reaches a terminal
90-
state, pass it to ``BenchmarkResult.from_job(job)`` to retrieve the
91-
parsed metrics.
99+
The created :class:`BenchmarkJob`. Once terminal, call
100+
``job.show_result()`` to download and parse the metrics.
92101
"""
102+
if workload is None:
103+
if not workload_kwargs:
104+
raise ValueError(
105+
"start_benchmark requires either a workload= argument or "
106+
"inline workload keyword arguments (e.g. tokenizer=...)."
107+
)
108+
workload = Workload.synthetic(**workload_kwargs)
109+
elif workload_kwargs:
110+
raise ValueError(
111+
"start_benchmark accepts either workload= or inline workload "
112+
"keyword arguments, not both."
113+
)
114+
93115
sagemaker_session = Session()
94116
role_arn = role or get_execution_role(sagemaker_session=sagemaker_session)
95117
output_location = output_path or _default_output_path(sagemaker_session, "benchmarks")
@@ -124,8 +146,11 @@ def start_benchmark(
124146
network_config=network_config,
125147
tags=tags,
126148
)
127-
if builder is not None:
128-
builder._benchmark_job = job
149+
# Surface the BenchmarkJob subclass (which adds show_result) on the
150+
# returned instance.
151+
from sagemaker.serve.ai_inference_recommender.jobs import BenchmarkJob
152+
153+
job.__class__ = BenchmarkJob
129154
if wait:
130155
job.wait()
131156
return job
@@ -140,7 +165,7 @@ def run_recommendation_job(
140165
role_arn: Optional[str] = None,
141166
instance_types: Optional[List[str]] = None,
142167
capacity_reservation_arns: Optional[List[str]] = None,
143-
optimize_model: bool = True,
168+
advanced_optimization: bool = True,
144169
framework: Optional[str] = None,
145170
model_package_group: Optional[str] = None,
146171
tags: Optional[List[Tag]] = None,
@@ -150,9 +175,8 @@ def run_recommendation_job(
150175
) -> AIRecommendationJob:
151176
"""Submit an ``AIRecommendationJob`` for the model configured on this builder.
152177
153-
Backs the recommendation branch of :meth:`ModelBuilder.optimize`. Not
154-
intended to be called directly; pass ``workload`` and ``performance_target``
155-
to :meth:`ModelBuilder.optimize` instead.
178+
Backs :meth:`ModelBuilder.generate_deployment_recommendations`. Not intended
179+
to be called directly.
156180
157181
Args:
158182
workload: Either a ``Workload`` (auto-creates a workload config) or
@@ -165,8 +189,9 @@ def run_recommendation_job(
165189
role from the ambient session.
166190
instance_types: Up to 3 instance types to evaluate.
167191
capacity_reservation_arns: Optional list of ML reservation ARNs.
168-
optimize_model: If True (default), allow the service to apply model
169-
optimizations such as speculative decoding and kernel tuning.
192+
advanced_optimization: If True (default), allow the service to apply
193+
model optimizations such as speculative decoding and kernel
194+
tuning.
170195
framework: Inference framework. ``"LMI"`` or ``"VLLM"``.
171196
model_package_group: Optional model package group identifier in
172197
which to register the optimized model.
@@ -189,7 +214,7 @@ def run_recommendation_job(
189214
if not s3_uri:
190215
raise ValueError(
191216
"ModelBuilder must be configured with an S3 model_path before "
192-
"calling optimize() with a workload. Call build() first."
217+
"calling generate_deployment_recommendations. Call build() first."
193218
)
194219

195220
if instance_types and len(instance_types) > MAX_INSTANCE_TYPES:
@@ -235,10 +260,15 @@ def run_recommendation_job(
235260
),
236261
role_arn=resolved_role_arn,
237262
inference_specification=inference_spec,
238-
optimize_model=optimize_model,
263+
optimize_model=advanced_optimization,
239264
compute_spec=compute_spec,
240265
tags=tags,
241266
)
267+
# Surface the RecommendationJob subclass (which adds show_result) on the
268+
# returned instance.
269+
from sagemaker.serve.ai_inference_recommender.jobs import RecommendationJob
270+
271+
job.__class__ = RecommendationJob
242272
if wait:
243273
job.wait()
244274
return job

sagemaker-serve/src/sagemaker/serve/ai_inference_recommender/_recommendation_view.py

Lines changed: 148 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,178 @@
1212
# language governing permissions and limitations under the License.
1313
"""Pretty-printing wrapper over an AIRecommendation row.
1414
15-
Exists only because the auto-generated ``AIRecommendation`` shape has
16-
Pydantic's default repr (dumps every field). Wrapping each row in this
17-
class swaps the repr without owning the data — attribute access forwards
18-
to the raw shape transparently.
15+
Wraps each row to replace the default repr without owning the data;
16+
attribute access forwards to the underlying shape transparently.
1917
"""
2018
from __future__ import absolute_import
2119

20+
from collections import defaultdict
21+
from typing import Any, Dict, List, Optional
22+
2223
from sagemaker.serve.ai_inference_recommender.result import (
2324
_fmt_number,
2425
_format_table,
2526
_indent,
2627
)
2728

2829

30+
class _ExpectedPerformanceMetric:
31+
"""Aggregated stats for a single metric in ``expected_performance``.
32+
33+
Each metric on the recommendation row is reported as one or more rows
34+
keyed by ``stat`` (avg, p50, p90, p99, ...). This view groups the rows
35+
so customers can do ``rec.expected_performance.request_throughput.avg``
36+
or ``.p99`` directly.
37+
"""
38+
39+
__slots__ = ("_stats", "unit")
40+
41+
def __init__(self, stats: Dict[str, float], unit: Optional[str]):
42+
object.__setattr__(self, "_stats", stats)
43+
object.__setattr__(self, "unit", unit)
44+
45+
@property
46+
def avg(self) -> Optional[float]:
47+
return self._stats.get("avg")
48+
49+
@property
50+
def p50(self) -> Optional[float]:
51+
return self._stats.get("p50")
52+
53+
@property
54+
def p90(self) -> Optional[float]:
55+
return self._stats.get("p90")
56+
57+
@property
58+
def p99(self) -> Optional[float]:
59+
return self._stats.get("p99")
60+
61+
@property
62+
def stats(self) -> Dict[str, float]:
63+
return dict(self._stats)
64+
65+
def __repr__(self) -> str:
66+
parts = ", ".join(
67+
f"{stat}={_fmt_number(v)}" for stat, v in self._stats.items()
68+
)
69+
unit = f" {self.unit}" if self.unit else ""
70+
return f"<{parts}{unit}>"
71+
72+
73+
class _ExpectedPerformanceView:
74+
"""Typed + dict-style accessor over a recommendation's expected_performance.
75+
76+
Service shape is ``List[AIRecommendationPerformanceMetric]`` with one row
77+
per (metric, stat). This view groups rows by metric name so customers
78+
can do ``view.request_throughput.avg`` (snake_case attribute), or
79+
``view.get("RequestThroughput").p99`` (raw service name).
80+
"""
81+
82+
__slots__ = ("_by_metric",)
83+
84+
def __init__(self, raw_rows: Optional[List[Any]]):
85+
by_metric: Dict[str, Dict[str, Any]] = defaultdict(
86+
lambda: {"unit": None, "stats": {}}
87+
)
88+
for row in raw_rows or []:
89+
metric = getattr(row, "metric", None)
90+
if not metric:
91+
continue
92+
stat = getattr(row, "stat", None) or "value"
93+
value = _to_float(getattr(row, "value", None))
94+
if value is None:
95+
continue
96+
entry = by_metric[metric]
97+
entry["stats"][stat] = value
98+
unit = getattr(row, "unit", None)
99+
if unit and not entry["unit"]:
100+
entry["unit"] = unit
101+
102+
compiled: Dict[str, _ExpectedPerformanceMetric] = {
103+
name: _ExpectedPerformanceMetric(entry["stats"], entry["unit"])
104+
for name, entry in by_metric.items()
105+
}
106+
object.__setattr__(self, "_by_metric", compiled)
107+
108+
def get(self, name: str) -> Optional[_ExpectedPerformanceMetric]:
109+
"""Look up a metric by raw service name (e.g. ``"RequestThroughput"``)."""
110+
return self._by_metric.get(name)
111+
112+
def __getattr__(self, name: str) -> _ExpectedPerformanceMetric:
113+
# snake_case attribute access. Translate to CamelCase service name.
114+
service_name = _snake_to_camel(name)
115+
metric = self._by_metric.get(service_name) or self._by_metric.get(name)
116+
if metric is None:
117+
raise AttributeError(
118+
f"No expected_performance metric named {name!r}. "
119+
f"Available: {sorted(self._by_metric)}"
120+
)
121+
return metric
122+
123+
def __contains__(self, name: str) -> bool:
124+
return name in self._by_metric or _snake_to_camel(name) in self._by_metric
125+
126+
def __iter__(self):
127+
return iter(self._by_metric)
128+
129+
def keys(self):
130+
return self._by_metric.keys()
131+
132+
def items(self):
133+
return self._by_metric.items()
134+
135+
def values(self):
136+
return self._by_metric.values()
137+
138+
def __len__(self) -> int:
139+
return len(self._by_metric)
140+
141+
def __repr__(self) -> str:
142+
return "{" + ", ".join(
143+
f"{name}: {metric!r}" for name, metric in self._by_metric.items()
144+
) + "}"
145+
146+
147+
def _to_float(value):
148+
try:
149+
return float(value) if value is not None else None
150+
except (TypeError, ValueError):
151+
return None
152+
153+
154+
def _snake_to_camel(name: str) -> str:
155+
return "".join(word.capitalize() for word in name.split("_"))
156+
157+
29158
class _RecommendationView:
30159
"""Read-only view over a single recommendation row."""
31160

32-
__slots__ = ("_raw", "_index")
161+
__slots__ = ("_raw", "_index", "_expected_performance")
33162

34163
def __init__(self, raw, index: int = 0):
35164
# Use object.__setattr__ to avoid triggering __getattr__ during init.
36165
object.__setattr__(self, "_raw", raw)
37166
object.__setattr__(self, "_index", index)
167+
object.__setattr__(
168+
self,
169+
"_expected_performance",
170+
_ExpectedPerformanceView(getattr(raw, "expected_performance", None)),
171+
)
38172

39173
@property
40174
def raw(self):
41-
"""The underlying auto-generated ``AIRecommendation`` shape."""
175+
"""The underlying ``AIRecommendation`` shape."""
42176
return self._raw
43177

178+
@property
179+
def expected_performance(self) -> _ExpectedPerformanceView:
180+
"""Typed + dict-style accessor for the recommendation's expected metrics."""
181+
return self._expected_performance
182+
44183
def __getattr__(self, name):
45-
# Fall through to the underlying shape so ``view.model_details``,
46-
# ``view.deployment_configuration``, and ``view.expected_performance``
47-
# work as if the customer held the raw row.
184+
# Fall through to the underlying shape so ``view.model_details`` and
185+
# ``view.deployment_configuration`` work as if the customer held the
186+
# raw row. ``expected_performance`` is intercepted by the property above.
48187
return getattr(self._raw, name)
49188

50189
def __repr__(self) -> str:

0 commit comments

Comments
 (0)