Skip to content

Commit b84338b

Browse files
authored
fix(serve): repair HuggingFace -> JumpStart redirect in ModelBuilder (#5958)
Two bugs prevented `ModelBuilder(model="<HF_id_with_JS_mirror>").build()` from working with any bare HuggingFace id (gpt2, Qwen/Qwen3-0.6B, microsoft/phi-2, etc.). The failure is reproducible end-to-end:

ModelBuilder.build()
  -> _build_single_modelbuilder (model_builder.py:2579)
  -> _use_jumpstart_equivalent (model_builder_utils.py)
  -> _hf_schema_builder_init raises TaskNotFoundException
  -> redirect aborts before the HF -> JS id swap

1. `sagemaker/serve/schema/task.json` was missing from the published wheel. setuptools needs an explicit `[tool.setuptools.package-data]` declaration; `include-package-data = true` alone is a no-op without a MANIFEST.in or version-controlled file list. Without the JSON in the wheel, `retrieve_local_schemas("text-generation")` raised FileNotFoundError, the caller wrapped it in TaskNotFoundException, and the redirect aborted before swapping to the JumpStart id.

2. `_use_jumpstart_equivalent` called `_build_for_jumpstart()` itself, then `_build_single_modelbuilder` called it again on return. The second call hit `_prepare_for_mode` with consumed state and raised FileNotFoundError. Removed the inner call; the caller already handles it.

Also wraps the HF schema bootstrap in a try/except so that a missing local task or a flaky remote schema fetch can't abort the redirect — the JumpStart build path supplies its own schema downstream.

Verified end-to-end: ModelBuilder(model="gpt2").build() and ModelBuilder(model="Qwen/Qwen3-0.6B").build() both succeed (model created with TGI / DJL_SERVING respectively). A pure-boto3 control with the same JumpStart djl-inference image + S3 artifacts was already accepted by CreateModel/CreateEndpointConfig/CreateEndpoint, so the AWS contract and JumpStart artifacts were never the issue.
1 parent 089bf3c commit b84338b

3 files changed

Lines changed: 61 additions & 8 deletions

File tree

sagemaker-serve/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ where = ["src"]
6161
include = ["sagemaker*"]
6262
namespaces = true
6363

64+
[tool.setuptools.package-data]
65+
"sagemaker.serve" = ["schema/*.json"]
66+
6467
[tool.setuptools.dynamic]
6568
version = { file = "VERSION"}
6669

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -942,17 +942,20 @@ def _use_jumpstart_equivalent(self) -> bool:
942942
if not model_task:
943943
model_task = hf_model_md.get("pipeline_tag")
944944
if model_task:
945-
self._hf_schema_builder_init(model_task)
945+
try:
946+
self._hf_schema_builder_init(model_task)
947+
except (TaskNotFoundException, FileNotFoundError, OSError) as e:
948+
logger.warning(
949+
"Could not initialize HF schema builder for task %r "
950+
"(%s: %s); falling back to the JumpStart-supplied schema.",
951+
model_task, type(e).__name__, e,
952+
)
946953

947954
huggingface_model_id = self.model
948955
jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"]
949956
self.model = jumpstart_model_id
950957
merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at")
951958

952-
# Call _build_for_jumpstart if method exists
953-
if hasattr(self, "_build_for_jumpstart"):
954-
self._build_for_jumpstart()
955-
956959
compare_model_diff_message = (
957960
"If you want to identify the differences between the two, "
958961
"please use model_uris.retrieve() to retrieve the model "

sagemaker-serve/tests/unit/test_model_builder_utils_additional_gaps.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from sagemaker.serve.model_builder_utils import _ModelBuilderUtils
1212
from sagemaker.serve.constants import Framework
13+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
1314
from sagemaker.serve.utils.types import ModelServer
1415
from sagemaker.train import ModelTrainer
1516

@@ -189,10 +190,56 @@ def test_auto_detect_image_uri_object_model(self, mock_detect_obj):
189190
class TestUseJumpStartEquivalent(unittest.TestCase):
190191
"""Test _use_jumpstart_equivalent method."""
191192

193+
def _make_utils(self):
    """Build a ``_ModelBuilderUtils`` primed for the HF -> JumpStart redirect tests.

    The returned instance targets the ``gpt2`` HuggingFace id, has its
    optional inputs (image URI, env vars, schema builder, model metadata)
    set to ``None``, reports the model as non-gated, and carries a mocked
    ``_build_for_jumpstart`` so callers can assert whether it was invoked.
    """
    builder = _ModelBuilderUtils()
    builder.model = "gpt2"
    # Optional inputs the redirect tests leave unset.
    for cleared_attr in ("image_uri", "env_vars", "schema_builder", "model_metadata"):
        setattr(builder, cleared_attr, None)
    builder._is_gated_model = Mock(return_value=False)
    builder._build_for_jumpstart = Mock()
    return builder
203+
204+
@patch.object(_ModelBuilderUtils, 'get_huggingface_model_metadata')
205+
@patch.object(_ModelBuilderUtils, '_hf_schema_builder_init')
206+
@patch.object(_ModelBuilderUtils, '_retrieve_hugging_face_model_mapping')
207+
def test_use_jumpstart_equivalent_redirects_hf_to_js(
    self, mock_retrieve, mock_schema_init, mock_md
):
    """A HuggingFace id that has a JumpStart mirror is swapped for its JumpStart id.

    Also checks that the schema builder is initialised once with the task
    taken from the mocked metadata's ``pipeline_tag``, and that
    ``_build_for_jumpstart`` is not invoked by the method under test.
    """
    expected_js_id = "huggingface-textgeneration-gpt2"
    utils = self._make_utils()
    mock_retrieve.return_value = {
        "gpt2": {
            "jumpstart-model-id": expected_js_id,
            "merged-at": "2024-01-01",
        }
    }
    mock_md.return_value = {"pipeline_tag": "text-generation"}

    redirected = utils._use_jumpstart_equivalent()

    self.assertTrue(redirected)
    self.assertEqual(utils.model, expected_js_id)
    utils._build_for_jumpstart.assert_not_called()
    mock_schema_init.assert_called_once_with("text-generation")
223+
224+
@patch.object(_ModelBuilderUtils, 'get_huggingface_model_metadata')
225+
@patch.object(_ModelBuilderUtils, '_hf_schema_builder_init')
192226
@patch.object(_ModelBuilderUtils, '_retrieve_hugging_face_model_mapping')
193-
def test_use_jumpstart_equivalent_no_image_uri(self, mock_retrieve):
194-
"""Test using JumpStart equivalent without image_uri - skipped (complex schema builder init)."""
195-
pass
227+
def test_use_jumpstart_equivalent_swallows_schema_failures(
    self, mock_retrieve, mock_schema_init, mock_md
):
    """A failing schema bootstrap does not stop the swap to the JumpStart id.

    Each exception type is raised from the mocked ``_hf_schema_builder_init``
    in its own sub-test; the redirect must still return a truthy result and
    rewrite ``model`` to the JumpStart id.
    """
    mock_retrieve.return_value = {
        "gpt2": {"jumpstart-model-id": "huggingface-textgeneration-gpt2"}
    }
    mock_md.return_value = {"pipeline_tag": "text-generation"}

    tolerated_errors = [
        TaskNotFoundException("x"),
        FileNotFoundError("x"),
        OSError("x"),
    ]
    for failure in tolerated_errors:
        with self.subTest(exception=type(failure).__name__):
            mock_schema_init.side_effect = failure
            # Fresh instance per sub-test: the previous run rewrote ``model``.
            utils = self._make_utils()
            self.assertTrue(utils._use_jumpstart_equivalent())
            self.assertEqual(utils.model, "huggingface-textgeneration-gpt2")
196243

197244
def test_use_jumpstart_equivalent_with_image_uri(self):
198245
"""Test using JumpStart equivalent with image_uri provided."""

0 commit comments

Comments
 (0)