Skip to content

Commit 7cdb047

Browse files
committed
fix: Model builder unable to (5667)
1 parent 272fdbf commit 7cdb047

2 files changed

Lines changed: 301 additions & 1 deletion

File tree

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def build(self):
5454

5555
# SageMaker core imports
5656
from sagemaker.core.helper.session_helper import Session
57-
from sagemaker.core.utils.utils import logger
57+
from sagemaker.core.utils.utils import logger, Unassigned
5858

5959
from sagemaker.train import ModelTrainer
6060

@@ -137,6 +137,98 @@ def build(self):
137137
from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE
138138

139139
SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources"
140+
141+
142+
def resolve_base_model_fields(base_model):
143+
"""Resolve missing BaseModel fields (hub_content_version, recipe_name).
144+
145+
When a ModelPackage's BaseModel has hub_content_name set but is missing
146+
hub_content_version and/or recipe_name (returned as Unassigned from the
147+
DescribeModelPackage API), this function attempts to resolve them
148+
automatically by querying SageMakerPublicHub.
149+
150+
Args:
151+
base_model: A BaseModel object with hub_content_name, hub_content_version,
152+
and recipe_name attributes.
153+
154+
Returns:
155+
The mutated base_model with resolved fields where possible.
156+
"""
157+
if base_model is None:
158+
return base_model
159+
160+
# Check if hub_content_name is present and valid
161+
hub_content_name = getattr(base_model, "hub_content_name", None)
162+
if hub_content_name is None or isinstance(hub_content_name, Unassigned):
163+
return base_model
164+
165+
if not hub_content_name or not str(hub_content_name).strip():
166+
return base_model
167+
168+
hub_content_version = getattr(base_model, "hub_content_version", None)
169+
recipe_name = getattr(base_model, "recipe_name", None)
170+
171+
version_missing = (
172+
hub_content_version is None
173+
or isinstance(hub_content_version, Unassigned)
174+
or not str(hub_content_version).strip()
175+
)
176+
recipe_missing = (
177+
recipe_name is None
178+
or isinstance(recipe_name, Unassigned)
179+
or not str(recipe_name).strip()
180+
)
181+
182+
if not version_missing and not recipe_missing:
183+
return base_model
184+
185+
# Attempt to resolve from SageMakerPublicHub
186+
if version_missing:
187+
try:
188+
from sagemaker.core.resources import HubContent
189+
190+
logger.info(
191+
"Resolving missing hub_content_version for hub_content_name='%s' "
192+
"from SageMakerPublicHub...",
193+
hub_content_name,
194+
)
195+
hc = HubContent.get(
196+
hub_content_type="Model",
197+
hub_name="SageMakerPublicHub",
198+
hub_content_name=str(hub_content_name),
199+
)
200+
if hasattr(hc, "hub_content_version") and not isinstance(
201+
hc.hub_content_version, Unassigned
202+
):
203+
base_model.hub_content_version = hc.hub_content_version
204+
logger.info(
205+
"Resolved hub_content_version='%s' for hub_content_name='%s'",
206+
base_model.hub_content_version,
207+
hub_content_name,
208+
)
209+
else:
210+
logger.warning(
211+
"Could not resolve hub_content_version for hub_content_name='%s'. "
212+
"The HubContent response did not contain a valid version.",
213+
hub_content_name,
214+
)
215+
except Exception as e:
216+
logger.warning(
217+
"Failed to resolve hub_content_version for hub_content_name='%s' "
218+
"from SageMakerPublicHub. You may need to set it manually. Error: %s",
219+
hub_content_name,
220+
e,
221+
)
222+
223+
if recipe_missing:
224+
logger.warning(
225+
"recipe_name is missing (Unassigned) for hub_content_name='%s'. "
226+
"ModelBuilder will proceed without it. If a recipe is required, "
227+
"please set base_model.recipe_name manually before calling build().",
228+
hub_content_name,
229+
)
230+
231+
return base_model
140232
_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
141233
_NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID."
142234
_JS_SCOPE = "inference"
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for resolve_base_model_fields utility function."""
14+
from __future__ import absolute_import
15+
16+
import pytest
17+
from unittest.mock import patch, MagicMock
18+
19+
from sagemaker.core.utils.utils import Unassigned
20+
from sagemaker.serve.model_builder_utils import resolve_base_model_fields
21+
22+
23+
class FakeBaseModel:
24+
"""Fake BaseModel for testing."""
25+
26+
def __init__(self, hub_content_name=None, hub_content_version=None, recipe_name=None):
27+
self.hub_content_name = hub_content_name
28+
self.hub_content_version = hub_content_version
29+
self.recipe_name = recipe_name
30+
31+
32+
class FakeHubContent:
33+
"""Fake HubContent response."""
34+
35+
def __init__(self, hub_content_version=None):
36+
self.hub_content_version = hub_content_version
37+
38+
39+
class TestResolveBaseModelFields:
40+
"""Tests for resolve_base_model_fields."""
41+
42+
def test_resolve_with_none_base_model(self):
43+
"""Test that None base_model is returned unchanged."""
44+
result = resolve_base_model_fields(None)
45+
assert result is None
46+
47+
def test_resolve_with_no_hub_content_name_returns_unchanged(self):
48+
"""Test that base_model without hub_content_name is returned unchanged."""
49+
base_model = FakeBaseModel(
50+
hub_content_name=Unassigned(),
51+
hub_content_version=Unassigned(),
52+
recipe_name=Unassigned(),
53+
)
54+
result = resolve_base_model_fields(base_model)
55+
assert isinstance(result.hub_content_version, Unassigned)
56+
assert isinstance(result.recipe_name, Unassigned)
57+
58+
def test_resolve_with_none_hub_content_name_returns_unchanged(self):
59+
"""Test that base_model with None hub_content_name is returned unchanged."""
60+
base_model = FakeBaseModel(
61+
hub_content_name=None,
62+
hub_content_version=Unassigned(),
63+
recipe_name=Unassigned(),
64+
)
65+
result = resolve_base_model_fields(base_model)
66+
assert isinstance(result.hub_content_version, Unassigned)
67+
68+
def test_resolve_with_empty_hub_content_name_returns_unchanged(self):
69+
"""Test that base_model with empty hub_content_name is returned unchanged."""
70+
base_model = FakeBaseModel(
71+
hub_content_name="",
72+
hub_content_version=Unassigned(),
73+
recipe_name=Unassigned(),
74+
)
75+
result = resolve_base_model_fields(base_model)
76+
assert isinstance(result.hub_content_version, Unassigned)
77+
78+
def test_resolve_with_all_fields_present_no_api_call(self):
79+
"""Test that no API call is made when all fields are already present."""
80+
base_model = FakeBaseModel(
81+
hub_content_name="huggingface-model-abc",
82+
hub_content_version="1.0.0",
83+
recipe_name="my-recipe",
84+
)
85+
with patch("sagemaker.serve.model_builder_utils.HubContent", autospec=True) as mock_hc:
86+
# HubContent should NOT be imported/called
87+
result = resolve_base_model_fields(base_model)
88+
assert result.hub_content_version == "1.0.0"
89+
assert result.recipe_name == "my-recipe"
90+
91+
@patch("sagemaker.core.resources.HubContent")
92+
def test_resolve_missing_hub_content_version_resolves_from_hub(self, mock_hub_content_cls):
93+
"""Test that missing hub_content_version is resolved from SageMakerPublicHub."""
94+
fake_hc = FakeHubContent(hub_content_version="2.5.0")
95+
mock_hub_content_cls.get.return_value = fake_hc
96+
97+
base_model = FakeBaseModel(
98+
hub_content_name="huggingface-reasoning-qwen3-32b",
99+
hub_content_version=Unassigned(),
100+
recipe_name="some-recipe",
101+
)
102+
103+
with patch(
104+
"sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls
105+
):
106+
result = resolve_base_model_fields(base_model)
107+
108+
assert result.hub_content_version == "2.5.0"
109+
mock_hub_content_cls.get.assert_called_once_with(
110+
hub_content_type="Model",
111+
hub_name="SageMakerPublicHub",
112+
hub_content_name="huggingface-reasoning-qwen3-32b",
113+
)
114+
115+
@patch("sagemaker.core.resources.HubContent")
116+
def test_resolve_missing_recipe_name_logs_warning(self, mock_hub_content_cls):
117+
"""Test that missing recipe_name logs a warning but does not crash."""
118+
base_model = FakeBaseModel(
119+
hub_content_name="huggingface-reasoning-qwen3-32b",
120+
hub_content_version="1.0.0",
121+
recipe_name=Unassigned(),
122+
)
123+
124+
result = resolve_base_model_fields(base_model)
125+
# recipe_name should still be Unassigned (not resolved automatically)
126+
assert isinstance(result.recipe_name, Unassigned)
127+
# But the function should not crash
128+
assert result.hub_content_version == "1.0.0"
129+
130+
@patch("sagemaker.core.resources.HubContent")
131+
def test_resolve_hub_content_not_found_does_not_crash(self, mock_hub_content_cls):
132+
"""Test that HubContent.get() failure is handled gracefully."""
133+
mock_hub_content_cls.get.side_effect = Exception("HubContent not found")
134+
135+
base_model = FakeBaseModel(
136+
hub_content_name="nonexistent-model",
137+
hub_content_version=Unassigned(),
138+
recipe_name="some-recipe",
139+
)
140+
141+
with patch(
142+
"sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls
143+
):
144+
# Should not raise, just log a warning
145+
result = resolve_base_model_fields(base_model)
146+
147+
# hub_content_version should still be Unassigned since resolution failed
148+
assert isinstance(result.hub_content_version, Unassigned)
149+
150+
@patch("sagemaker.core.resources.HubContent")
151+
def test_resolve_both_version_and_recipe_missing(self, mock_hub_content_cls):
152+
"""Test resolution when both hub_content_version and recipe_name are missing."""
153+
fake_hc = FakeHubContent(hub_content_version="3.0.0")
154+
mock_hub_content_cls.get.return_value = fake_hc
155+
156+
base_model = FakeBaseModel(
157+
hub_content_name="huggingface-reasoning-qwen3-32b",
158+
hub_content_version=Unassigned(),
159+
recipe_name=Unassigned(),
160+
)
161+
162+
with patch(
163+
"sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls
164+
):
165+
result = resolve_base_model_fields(base_model)
166+
167+
# Version should be resolved
168+
assert result.hub_content_version == "3.0.0"
169+
# Recipe should still be Unassigned (with warning logged)
170+
assert isinstance(result.recipe_name, Unassigned)
171+
172+
@patch("sagemaker.core.resources.HubContent")
173+
def test_resolve_with_none_version_resolves(self, mock_hub_content_cls):
174+
"""Test that None hub_content_version (not just Unassigned) is also resolved."""
175+
fake_hc = FakeHubContent(hub_content_version="1.2.3")
176+
mock_hub_content_cls.get.return_value = fake_hc
177+
178+
base_model = FakeBaseModel(
179+
hub_content_name="huggingface-model-xyz",
180+
hub_content_version=None,
181+
recipe_name="my-recipe",
182+
)
183+
184+
with patch(
185+
"sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls
186+
):
187+
result = resolve_base_model_fields(base_model)
188+
189+
assert result.hub_content_version == "1.2.3"
190+
191+
@patch("sagemaker.core.resources.HubContent")
192+
def test_resolve_with_empty_string_version_resolves(self, mock_hub_content_cls):
193+
"""Test that empty string hub_content_version is also resolved."""
194+
fake_hc = FakeHubContent(hub_content_version="4.0.0")
195+
mock_hub_content_cls.get.return_value = fake_hc
196+
197+
base_model = FakeBaseModel(
198+
hub_content_name="huggingface-model-xyz",
199+
hub_content_version="",
200+
recipe_name="my-recipe",
201+
)
202+
203+
with patch(
204+
"sagemaker.serve.model_builder_utils.HubContent", mock_hub_content_cls
205+
):
206+
result = resolve_base_model_fields(base_model)
207+
208+
assert result.hub_content_version == "4.0.0"

0 commit comments

Comments
 (0)