Skip to content

Commit a8011a4

Browse files
authored
fix(serve): prefer hub document hosting config for Nova models (#5962)
* fix(serve): prefer hub document hosting config for Nova models * fix: Prefer an explicitly configured hub, default still public JumpStart hub
1 parent db63167 commit a8011a4

2 files changed

Lines changed: 311 additions & 13 deletions

File tree

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,13 @@ def _fetch_hub_document_for_custom_model(self) -> dict:
692692
base_model: CoreBaseModel = (
693693
self._fetch_model_package().inference_specification.containers[0].base_model
694694
)
695+
# Prefer an explicitly configured hub (e.g. a private hub used for
696+
# testing or for sourcing pre-release hosting configs); fall back to the
697+
# public JumpStart hub when none is set.
698+
hub_name = getattr(self, "hub_name", None) or "SageMakerPublicHub"
695699
hub_content = HubContent.get(
696700
hub_content_type="Model",
697-
hub_name="SageMakerPublicHub",
701+
hub_name=hub_name,
698702
hub_content_name=base_model.hub_content_name,
699703
hub_content_version=base_model.hub_content_version,
700704
)
@@ -1076,12 +1080,105 @@ def _is_nova_model_for_telemetry(self) -> bool:
10761080
except Exception:
10771081
return False
10781082

1083+
def _select_nova_hosting_config_entry(self, configs, instance_type, identifier):
1084+
"""Select a single hosting config entry from a list of Nova configs.
1085+
1086+
Picks the entry matching ``instance_type`` when provided, otherwise the
1087+
entry with ``Profile == "Default"`` (falling back to the first entry).
1088+
1089+
Args:
1090+
configs: List of hosting config dicts.
1091+
instance_type: Requested instance type, or None.
1092+
identifier: Model identifier used for error messages.
1093+
1094+
Returns:
1095+
The selected hosting config dict.
1096+
1097+
Raises:
1098+
ValueError: If ``instance_type`` is provided but no entry matches it.
1099+
"""
1100+
if instance_type:
1101+
config = next(
1102+
(c for c in configs if c.get("InstanceType") == instance_type), None
1103+
)
1104+
if not config:
1105+
supported = [c.get("InstanceType") for c in configs]
1106+
raise ValueError(
1107+
f"Instance type '{instance_type}' not supported for '{identifier}'. "
1108+
f"Supported: {supported}"
1109+
)
1110+
return config
1111+
return next((c for c in configs if c.get("Profile") == "Default"), configs[0])
1112+
1113+
def _get_nova_hosting_config_from_hub_document(self, instance_type=None):
1114+
"""Resolve Nova hosting config from the JumpStart hub document, if present.
1115+
1116+
Reads hosting configs published in the hub content document, matching the
1117+
standard schema used by other custom models. Looks first inside the
1118+
``RecipeCollection`` entry whose ``Name`` matches the recipe, then falls
1119+
back to the top-level ``HostingConfigs``.
1120+
1121+
Returns:
1122+
A dict with ``image_uri``, ``env_vars``, and ``instance_type`` when a
1123+
usable hosting config is found, otherwise ``None``.
1124+
"""
1125+
try:
1126+
hub_document = self._fetch_hub_document_for_custom_model()
1127+
except Exception as e: # pragma: no cover - defensive, hub may be unavailable
1128+
logger.debug(f"Could not fetch hub document for Nova hosting config: {e}")
1129+
return None
1130+
1131+
if not hub_document:
1132+
return None
1133+
1134+
container = self._fetch_model_package().inference_specification.containers[0]
1135+
recipe_name = getattr(container.base_model, "recipe_name", None) or ""
1136+
1137+
hosting_configs = None
1138+
for recipe in hub_document.get("RecipeCollection", []):
1139+
if recipe.get("Name") == recipe_name:
1140+
hosting_configs = recipe.get("HostingConfigs")
1141+
break
1142+
if not hosting_configs:
1143+
hosting_configs = hub_document.get("HostingConfigs")
1144+
1145+
if not hosting_configs:
1146+
return None
1147+
1148+
config = self._select_nova_hosting_config_entry(
1149+
hosting_configs, instance_type, recipe_name or "nova"
1150+
)
1151+
1152+
image_uri = config.get("EcrAddress")
1153+
if not image_uri:
1154+
# Hosting config present but no image override; let the hardcoded
1155+
# fallback supply the escrow image URI.
1156+
return None
1157+
1158+
resolved_instance_type = config.get("InstanceType") or config.get(
1159+
"DefaultInstanceType"
1160+
)
1161+
1162+
return {
1163+
"image_uri": image_uri,
1164+
"env_vars": config.get("Environment", {}),
1165+
"instance_type": resolved_instance_type,
1166+
}
1167+
10791168
def _get_nova_hosting_config(self, instance_type=None):
10801169
"""Get Nova hosting config (image URI, env vars, instance type).
10811170
1082-
Nova training recipes don't have hosting configs in the JumpStart hub document.
1083-
This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs().
1171+
Prefers hosting configs published in the JumpStart hub document (the
1172+
standard location used by other custom models). Falls back to the
1173+
hardcoded ``_NOVA_HOSTING_CONFIGS``, matching Rhinestone's
1174+
getNovaHostingConfigs(), when the hub document does not provide one.
10841175
"""
1176+
hub_config = self._get_nova_hosting_config_from_hub_document(
1177+
instance_type=instance_type
1178+
)
1179+
if hub_config:
1180+
return hub_config
1181+
10851182
model_package = self._fetch_model_package()
10861183
hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name
10871184

@@ -1102,16 +1199,9 @@ def _get_nova_hosting_config(self, instance_type=None):
11021199

11031200
image_uri = f"{escrow_account}.dkr.ecr.{region}.amazonaws.com/nova-inference-repo:SM-Inference-latest"
11041201

1105-
if instance_type:
1106-
config = next((c for c in configs if c["InstanceType"] == instance_type), None)
1107-
if not config:
1108-
supported = [c["InstanceType"] for c in configs]
1109-
raise ValueError(
1110-
f"Instance type '{instance_type}' not supported for '{hub_content_name}'. "
1111-
f"Supported: {supported}"
1112-
)
1113-
else:
1114-
config = next((c for c in configs if c.get("Profile") == "Default"), configs[0])
1202+
config = self._select_nova_hosting_config_entry(
1203+
configs, instance_type, hub_content_name
1204+
)
11151205

11161206
return {
11171207
"image_uri": image_uri,
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for Nova hosting config resolution in ModelBuilder.
14+
15+
Verifies that hosting configs published in the JumpStart hub document take
16+
priority over the hardcoded ``_NOVA_HOSTING_CONFIGS`` fallback.
17+
"""
18+
19+
import unittest
20+
from unittest.mock import MagicMock, patch
21+
22+
from sagemaker.serve.model_builder import ModelBuilder
23+
24+
25+
def _make_builder(region="us-east-1"):
    """Build a bare ModelBuilder (``__init__`` is skipped) with a mocked session.

    The session mock reports ``region`` as its ``boto_region_name``; the
    ``image_uri``, ``env_vars`` and ``instance_type`` attributes start as None.
    """
    builder = ModelBuilder.__new__(ModelBuilder)
    for attr_name in ("image_uri", "env_vars", "instance_type"):
        setattr(builder, attr_name, None)
    builder.sagemaker_session = MagicMock(boto_region_name=region)
    return builder
35+
36+
37+
def _make_model_package(recipe_name="", hub_content_name="nova-textgeneration-lite"):
38+
pkg = MagicMock()
39+
base_model = MagicMock()
40+
base_model.recipe_name = recipe_name
41+
base_model.hub_content_name = hub_content_name
42+
pkg.inference_specification.containers = [MagicMock(base_model=base_model)]
43+
return pkg
44+
45+
46+
class TestNovaHostingConfigResolution(unittest.TestCase):
    """Tests for ModelBuilder._get_nova_hosting_config priority behavior."""

    def _resolve(self, model_package, instance_type=None, **hub_fetch):
        """Call ``_get_nova_hosting_config`` with both fetchers patched.

        ``hub_fetch`` is forwarded to the patch of
        ``_fetch_hub_document_for_custom_model`` (``return_value=...`` or
        ``side_effect=...``); ``_fetch_model_package`` returns ``model_package``.
        """
        builder = _make_builder()
        with patch.object(
            ModelBuilder, "_fetch_hub_document_for_custom_model", **hub_fetch
        ):
            with patch.object(
                ModelBuilder, "_fetch_model_package", return_value=model_package
            ):
                return builder._get_nova_hosting_config(instance_type=instance_type)

    def test_hub_recipe_collection_config_takes_priority(self):
        """Hosting config from RecipeCollection in the hub doc is preferred."""
        default_entry = {
            "Profile": "Default",
            "EcrAddress": "111.dkr.ecr.us-east-1.amazonaws.com/custom:tag",
            "InstanceType": "ml.p5.48xlarge",
            "Environment": {
                "CONTEXT_LENGTH": "999",
                "MAX_CONCURRENCY": "3",
            },
        }
        hub_doc = {
            "RecipeCollection": [
                {"Name": "my-nova-recipe", "HostingConfigs": [default_entry]}
            ]
        }
        package = _make_model_package(
            recipe_name="my-nova-recipe", hub_content_name="nova-textgeneration-lite"
        )

        resolved = self._resolve(package, return_value=hub_doc)

        self.assertEqual(
            resolved["image_uri"], "111.dkr.ecr.us-east-1.amazonaws.com/custom:tag"
        )
        self.assertEqual(
            resolved["env_vars"], {"CONTEXT_LENGTH": "999", "MAX_CONCURRENCY": "3"}
        )
        self.assertEqual(resolved["instance_type"], "ml.p5.48xlarge")

    def test_top_level_hosting_configs_used_when_no_recipe_match(self):
        """Top-level HostingConfigs is used when no RecipeCollection matches."""
        top_level_entry = {
            "Profile": "Default",
            "EcrAddress": "222.dkr.ecr.us-east-1.amazonaws.com/top:tag",
            "InstanceType": "ml.g6.24xlarge",
            "Environment": {"CONTEXT_LENGTH": "100"},
        }
        hub_doc = {"HostingConfigs": [top_level_entry]}
        package = _make_model_package(
            recipe_name="unmatched", hub_content_name="nova-textgeneration-micro"
        )

        resolved = self._resolve(package, return_value=hub_doc)

        self.assertEqual(
            resolved["image_uri"], "222.dkr.ecr.us-east-1.amazonaws.com/top:tag"
        )

    def test_hardcoded_fallback_when_hub_has_no_hosting_config(self):
        """Hardcoded escrow config is used when the hub doc has no hosting config."""
        package = _make_model_package(hub_content_name="nova-textgeneration-lite")

        resolved = self._resolve(package, return_value={})

        self.assertIn("nova-inference-repo:SM-Inference-latest", resolved["image_uri"])
        self.assertEqual(resolved["instance_type"], "ml.g6.48xlarge")

    def test_hardcoded_fallback_when_hub_fetch_raises(self):
        """Hardcoded config is used defensively when hub fetch raises."""
        package = _make_model_package(hub_content_name="nova-textgeneration-pro")

        resolved = self._resolve(
            package, side_effect=RuntimeError("hub unavailable")
        )

        self.assertEqual(resolved["instance_type"], "ml.p5.48xlarge")
        self.assertIn("nova-inference-repo:SM-Inference-latest", resolved["image_uri"])

    def test_missing_ecr_address_falls_through_to_hardcoded(self):
        """A hub hosting config without EcrAddress falls back to the escrow image."""
        hub_doc = {
            "RecipeCollection": [
                {
                    "Name": "r",
                    "HostingConfigs": [
                        {"Profile": "Default", "InstanceType": "ml.p5.48xlarge"}
                    ],
                }
            ]
        }
        package = _make_model_package(
            recipe_name="r", hub_content_name="nova-textgeneration-pro"
        )

        resolved = self._resolve(package, return_value=hub_doc)

        self.assertIn("nova-inference-repo:SM-Inference-latest", resolved["image_uri"])

    def test_instance_type_match_in_hub_config(self):
        """A requested instance type selects the matching hub config entry."""
        default_entry = {
            "Profile": "Default",
            "EcrAddress": "333.dkr.ecr.us-east-1.amazonaws.com/a:tag",
            "InstanceType": "ml.p5.48xlarge",
            "Environment": {"CONTEXT_LENGTH": "1"},
        }
        alternate_entry = {
            "EcrAddress": "333.dkr.ecr.us-east-1.amazonaws.com/b:tag",
            "InstanceType": "ml.g6.48xlarge",
            "Environment": {"CONTEXT_LENGTH": "2"},
        }
        hub_doc = {
            "RecipeCollection": [
                {"Name": "r", "HostingConfigs": [default_entry, alternate_entry]}
            ]
        }
        package = _make_model_package(
            recipe_name="r", hub_content_name="nova-textgeneration-lite"
        )

        resolved = self._resolve(
            package, instance_type="ml.g6.48xlarge", return_value=hub_doc
        )

        self.assertEqual(
            resolved["image_uri"], "333.dkr.ecr.us-east-1.amazonaws.com/b:tag"
        )
        self.assertEqual(resolved["instance_type"], "ml.g6.48xlarge")

    def test_unsupported_instance_type_raises(self):
        """Requesting an unsupported instance type raises ValueError (fallback path)."""
        package = _make_model_package(hub_content_name="nova-textgeneration-pro")

        with self.assertRaises(ValueError):
            self._resolve(package, instance_type="ml.invalid.type", return_value={})
205+
206+
207+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)