From 34829ce5f1f15f1ec91720f5d8579a5dda94542e Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:14:23 -0700 Subject: [PATCH 1/2] fix: Model builder unable to (5667) --- .../src/sagemaker/serve/model_builder.py | 144 +++++++++- .../unit/test_resolve_base_model_fields.py | 251 ++++++++++++++++++ 2 files changed, 391 insertions(+), 4 deletions(-) create mode 100644 sagemaker-serve/tests/unit/test_resolve_base_model_fields.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 7c7af2defc..63b9d5268c 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -52,7 +52,7 @@ ModelCard, ModelPackageModelCard, ) -from sagemaker.core.utils.utils import logger +from sagemaker.core.utils.utils import logger, Unassigned from sagemaker.core.helper import session_helper from sagemaker.core.helper.session_helper import ( Session, @@ -680,18 +680,138 @@ def _infer_instance_type_from_jumpstart(self) -> str: raise ValueError(error_msg) + def _resolve_base_model_fields(self): + """Auto-resolve missing BaseModel fields (hub_content_version, recipe_name). + + When a ModelPackage's BaseModel has hub_content_name set but is missing + hub_content_version and/or recipe_name (returned as Unassigned from the + DescribeModelPackage API), this method resolves them automatically: + - hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub + - recipe_name: resolved from the first recipe in the hub document's RecipeCollection + """ + if hasattr(self, "_base_model_fields_resolved") and self._base_model_fields_resolved: + return + + model_package = self._fetch_model_package() + if not model_package: + self._base_model_fields_resolved = True + return + + inference_spec = getattr(model_package, "inference_specification", None) + if not inference_spec: + self._base_model_fields_resolved = True + return + + containers = getattr(inference_spec, "containers", None) + if not containers: + self._base_model_fields_resolved = True + return + + base_model = getattr(containers[0], "base_model", None) + if not base_model: + self._base_model_fields_resolved = True + return + + hub_content_name = getattr(base_model, "hub_content_name", None) + if not hub_content_name or isinstance(hub_content_name, Unassigned): + self._base_model_fields_resolved = True + return + + hub_content_version = getattr(base_model, "hub_content_version", None) + recipe_name = getattr(base_model, "recipe_name", None) + + # Resolve hub_content_version if missing + if not hub_content_version or isinstance(hub_content_version, Unassigned): + logger.info( + "hub_content_version is missing for hub content '%s'. " + "Resolving automatically from SageMakerPublicHub...", + hub_content_name, + ) + try: + hc = HubContent.get( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name=hub_content_name, + ) + base_model.hub_content_version = hc.hub_content_version + logger.info( + "Resolved hub_content_version to '%s' for hub content '%s'.", + hc.hub_content_version, + hub_content_name, + ) + except Exception as e: + logger.warning( + "Failed to resolve hub_content_version for hub content '%s': %s", + hub_content_name, + e, + ) + self._base_model_fields_resolved = True + return + + # Resolve recipe_name if missing + if not recipe_name or isinstance(recipe_name, Unassigned): + logger.info( + "recipe_name is missing for hub content '%s'. " + "Resolving automatically from hub document RecipeCollection...", + hub_content_name, + ) + try: + hub_content = HubContent.get( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name=base_model.hub_content_name, + hub_content_version=base_model.hub_content_version, + ) + hub_document = json.loads(hub_content.hub_content_document) + recipe_collection = hub_document.get("RecipeCollection", []) + if recipe_collection: + resolved_recipe = recipe_collection[0].get("Name", "") + if resolved_recipe: + base_model.recipe_name = resolved_recipe + logger.info( + "Resolved recipe_name to '%s' for hub content '%s'.", + resolved_recipe, + hub_content_name, + ) + else: + logger.warning( + "RecipeCollection found but first recipe has no Name for hub content '%s'.", + hub_content_name, + ) + else: + logger.warning( + "No RecipeCollection found in hub document for hub content '%s'. " + "recipe_name could not be auto-resolved.", + hub_content_name, + ) + except Exception as e: + logger.warning( + "Failed to resolve recipe_name for hub content '%s': %s", + hub_content_name, + e, + ) + + self._base_model_fields_resolved = True + def _fetch_hub_document_for_custom_model(self) -> dict: from sagemaker.core.shapes import BaseModel as CoreBaseModel + self._resolve_base_model_fields() + base_model: CoreBaseModel = ( self._fetch_model_package().inference_specification.containers[0].base_model ) - hub_content = HubContent.get( + + hub_content_version = getattr(base_model, "hub_content_version", None) + get_kwargs = dict( hub_content_type="Model", hub_name="SageMakerPublicHub", hub_content_name=base_model.hub_content_name, - hub_content_version=base_model.hub_content_version, ) + if hub_content_version and not isinstance(hub_content_version, Unassigned): + get_kwargs["hub_content_version"] = hub_content_version + + hub_content = HubContent.get(**get_kwargs) return json.loads(hub_content.hub_content_document) def _fetch_hosting_configs_for_custom_model(self) -> dict: @@ -937,9 +1057,20 @@ def _is_gpu_instance(self, instance_type: str) -> bool: def _fetch_and_cache_recipe_config(self): """Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build.""" + self._resolve_base_model_fields() hub_document = self._fetch_hub_document_for_custom_model() model_package = self._fetch_model_package() - recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name + recipe_name = getattr( + model_package.inference_specification.containers[0].base_model, "recipe_name", None + ) + if not recipe_name or isinstance(recipe_name, Unassigned): + raise ValueError( + "recipe_name is missing from the model package's BaseModel and could not be " + "auto-resolved from the hub document. Please ensure the model package has a " + "valid recipe_name set, or manually set it before calling build(). " + "Example: model_package.inference_specification.containers[0].base_model" + ".recipe_name = 'your-recipe-name'" + ) if not self.s3_upload_path: self.s3_upload_path = model_package.inference_specification.containers[ @@ -1060,7 +1191,11 @@ def _is_nova_model(self) -> bool: if not base_model: return False recipe_name = getattr(base_model, "recipe_name", "") or "" + if isinstance(recipe_name, Unassigned): + recipe_name = "" hub_content_name = getattr(base_model, "hub_content_name", "") or "" + if isinstance(hub_content_name, Unassigned): + hub_content_name = "" return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower() def _is_nova_model_for_telemetry(self) -> bool: @@ -1076,6 +1211,7 @@ def _get_nova_hosting_config(self, instance_type=None): Nova training recipes don't have hosting configs in the JumpStart hub document. This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs(). """ + self._resolve_base_model_fields() model_package = self._fetch_model_package() hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name diff --git a/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py new file mode 100644 index 0000000000..754968a959 --- /dev/null +++ b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py @@ -0,0 +1,251 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for _resolve_base_model_fields and related Unassigned handling.""" +from __future__ import absolute_import + +import json +import pytest +from unittest.mock import MagicMock, patch, PropertyMock + +from sagemaker.core.utils.utils import Unassigned + + +def _make_model_builder(**kwargs): + """Create a ModelBuilder instance with mocked session to avoid real AWS calls.""" + with patch("sagemaker.serve.model_builder.Session"): + with patch("sagemaker.serve.model_builder.get_execution_role", return_value="arn:aws:iam::123456789012:role/SageMakerRole"): + from sagemaker.serve.model_builder import ModelBuilder + defaults = dict( + model="dummy-model", + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + ) + defaults.update(kwargs) + mb = ModelBuilder(**defaults) + # Reset the resolution flag so tests can trigger it + mb._base_model_fields_resolved = False + return mb + + +def _make_base_model(hub_content_name=None, hub_content_version=None, recipe_name=None): + """Create a mock BaseModel with the given fields.""" + base_model = MagicMock() + base_model.hub_content_name = hub_content_name if hub_content_name is not None else Unassigned() + base_model.hub_content_version = hub_content_version if hub_content_version is not None else Unassigned() + base_model.recipe_name = recipe_name if recipe_name is not None else Unassigned() + return base_model + + +def _make_model_package(base_model): + """Create a mock ModelPackage with the given base_model.""" + container = MagicMock() + container.base_model = base_model + container.model_data_source = MagicMock() + container.model_data_source.s3_data_source = MagicMock() + container.model_data_source.s3_data_source.s3_uri = "s3://bucket/path" + + model_package = MagicMock() + model_package.inference_specification.containers = [container] + return model_package + + +def _make_hub_content(hub_content_version="1.0.0", hub_content_document=None): + """Create a mock HubContent object.""" + hc = MagicMock() + hc.hub_content_version = hub_content_version + if hub_content_document is None: + hub_content_document = json.dumps({ + "RecipeCollection": [ + {"Name": "auto-resolved-recipe", "HostingConfigs": []} + ], + "HostingConfigs": [], + }) + hc.hub_content_document = hub_content_document + return hc + + +class TestResolveBaseModelFields: + """Tests for _resolve_base_model_fields method.""" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_missing_hub_content_version(self, mock_hub_content_cls): + """When hub_content_version is Unassigned, it should be resolved from HubContent.get.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, # Will be Unassigned + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + mock_hc = _make_hub_content(hub_content_version="2.5.0") + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + + assert base_model.hub_content_version == "2.5.0" + # recipe_name should remain unchanged since it was already set + assert base_model.recipe_name == "some-recipe" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_missing_recipe_name(self, mock_hub_content_cls): + """When recipe_name is Unassigned, it should be resolved from RecipeCollection.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name=None, # Will be Unassigned + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + hub_doc = json.dumps({ + "RecipeCollection": [ + {"Name": "verl-grpo-rlaif-qwen-3-32b-lora", "HostingConfigs": []} + ], + }) + mock_hc = _make_hub_content(hub_content_version="1.0.0", hub_content_document=hub_doc) + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + + assert base_model.recipe_name == "verl-grpo-rlaif-qwen-3-32b-lora" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_noop_when_all_fields_present(self, mock_hub_content_cls): + """When all fields are present, HubContent.get should not be called.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + mb._resolve_base_model_fields() + + mock_hub_content_cls.get.assert_not_called() + assert base_model.hub_content_version == "1.0.0" + assert base_model.recipe_name == "some-recipe" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_both_version_and_recipe(self, mock_hub_content_cls): + """When both hub_content_version and recipe_name are Unassigned, both should be resolved.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + hub_doc = json.dumps({ + "RecipeCollection": [ + {"Name": "auto-resolved-recipe", "HostingConfigs": []} + ], + }) + mock_hc = _make_hub_content(hub_content_version="3.0.0", hub_content_document=hub_doc) + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + + assert base_model.hub_content_version == "3.0.0" + assert base_model.recipe_name == "auto-resolved-recipe" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_fetch_hub_document_works_after_resolution(self, mock_hub_content_cls): + """_fetch_hub_document_for_custom_model should work when hub_content_version was Unassigned.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + hub_doc = json.dumps({"HostingConfigs": [{"Profile": "Default"}]}) + mock_hc = _make_hub_content(hub_content_version="1.0.0", hub_content_document=hub_doc) + mock_hub_content_cls.get.return_value = mock_hc + + result = mb._fetch_hub_document_for_custom_model() + + assert result == {"HostingConfigs": [{"Profile": "Default"}]} + + @patch("sagemaker.serve.model_builder.HubContent") + def test_no_base_model_is_noop(self, mock_hub_content_cls): + """When containers[0] has no base_model, method should return without error.""" + mb = _make_model_builder() + container = MagicMock() + container.base_model = None + model_package = MagicMock() + model_package.inference_specification.containers = [container] + mb._fetch_model_package = MagicMock(return_value=model_package) + + mb._resolve_base_model_fields() + + mock_hub_content_cls.get.assert_not_called() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_no_hub_content_name_is_noop(self, mock_hub_content_cls): + """When hub_content_name is Unassigned, method should return without calling HubContent.get.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name=None, # Will be Unassigned + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + mb._resolve_base_model_fields() + + mock_hub_content_cls.get.assert_not_called() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_is_nova_model_with_unassigned_fields_does_not_crash(self, mock_hub_content_cls): + """_is_nova_model should return False without raising when fields are Unassigned.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name=None, # Unassigned + hub_content_version=None, + recipe_name=None, # Unassigned + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + result = mb._is_nova_model() + + assert result is False + + @patch("sagemaker.serve.model_builder.HubContent") + def test_fetch_and_cache_recipe_config_raises_when_recipe_unresolvable(self, mock_hub_content_cls): + """When recipe_name cannot be resolved, _fetch_and_cache_recipe_config should raise ValueError.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name=None, # Unassigned + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock(return_value=model_package) + + # Hub document with empty RecipeCollection - recipe cannot be resolved + hub_doc = json.dumps({"RecipeCollection": [], "HostingConfigs": []}) + mock_hc = _make_hub_content(hub_content_version="1.0.0", hub_content_document=hub_doc) + mock_hub_content_cls.get.return_value = mock_hc + + with pytest.raises(ValueError, match="recipe_name is missing"): + mb._fetch_and_cache_recipe_config() From 7a80bae8fa3162f9950ae6d69553c39458bc1008 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:27:40 -0700 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../src/sagemaker/serve/model_builder.py | 125 +++++-- .../unit/test_resolve_base_model_fields.py | 336 +++++++++++++++--- 2 files changed, 369 insertions(+), 92 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 63b9d5268c..5fbd923437 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -414,6 +414,7 @@ def __post_init__(self) -> None: if self.log_level is not None: logger.setLevel(self.log_level) + self._base_model_fields_resolved: bool = False self._warn_about_deprecated_parameters(warnings) self._initialize_compute_config() self._initialize_network_config() @@ -688,8 +689,11 @@ def _resolve_base_model_fields(self): DescribeModelPackage API), this method resolves them automatically: - hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub - recipe_name: resolved from the first recipe in the hub document's RecipeCollection + + Note: HubContent.get() supports being called without hub_content_version, + in which case it returns the latest version of the hub content. """ - if hasattr(self, "_base_model_fields_resolved") and self._base_model_fields_resolved: + if self._base_model_fields_resolved: return model_package = self._fetch_model_package() @@ -720,6 +724,9 @@ def _resolve_base_model_fields(self): hub_content_version = getattr(base_model, "hub_content_version", None) recipe_name = getattr(base_model, "recipe_name", None) + # Cache the HubContent response to avoid redundant API calls + cached_hub_content = None + # Resolve hub_content_version if missing if not hub_content_version or isinstance(hub_content_version, Unassigned): logger.info( @@ -728,20 +735,24 @@ def _resolve_base_model_fields(self): hub_content_name, ) try: - hc = HubContent.get( + cached_hub_content = HubContent.get( hub_content_type="Model", hub_name="SageMakerPublicHub", hub_content_name=hub_content_name, ) - base_model.hub_content_version = hc.hub_content_version + base_model.hub_content_version = ( + cached_hub_content.hub_content_version + ) logger.info( - "Resolved hub_content_version to '%s' for hub content '%s'.", - hc.hub_content_version, + "Resolved hub_content_version to '%s' " + "for hub content '%s'.", + cached_hub_content.hub_content_version, hub_content_name, ) - except Exception as e: + except (ClientError, ValueError) as e: logger.warning( - "Failed to resolve hub_content_version for hub content '%s': %s", + "Failed to resolve hub_content_version " + "for hub content '%s': %s", hub_content_name, e, ) @@ -752,41 +763,63 @@ def _resolve_base_model_fields(self): if not recipe_name or isinstance(recipe_name, Unassigned): logger.info( "recipe_name is missing for hub content '%s'. " - "Resolving automatically from hub document RecipeCollection...", + "Resolving from hub document RecipeCollection...", hub_content_name, ) try: - hub_content = HubContent.get( - hub_content_type="Model", - hub_name="SageMakerPublicHub", - hub_content_name=base_model.hub_content_name, - hub_content_version=base_model.hub_content_version, + # Reuse cached hub content if available and version matches + if ( + cached_hub_content is not None + and cached_hub_content.hub_content_version + == base_model.hub_content_version + ): + hub_content = cached_hub_content + else: + hub_content = HubContent.get( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name=base_model.hub_content_name, + hub_content_version=( + base_model.hub_content_version + ), + ) + hub_document = json.loads( + hub_content.hub_content_document + ) + recipe_collection = hub_document.get( + "RecipeCollection", [] ) - hub_document = json.loads(hub_content.hub_content_document) - recipe_collection = hub_document.get("RecipeCollection", []) if recipe_collection: - resolved_recipe = recipe_collection[0].get("Name", "") + resolved_recipe = recipe_collection[0].get( + "Name", "" + ) if resolved_recipe: base_model.recipe_name = resolved_recipe logger.info( - "Resolved recipe_name to '%s' for hub content '%s'.", + "Resolved recipe_name to '%s' " + "for hub content '%s'.", resolved_recipe, hub_content_name, ) else: logger.warning( - "RecipeCollection found but first recipe has no Name for hub content '%s'.", + "RecipeCollection found but first " + "recipe has no Name for hub " + "content '%s'.", hub_content_name, ) else: logger.warning( - "No RecipeCollection found in hub document for hub content '%s'. " - "recipe_name could not be auto-resolved.", + "No RecipeCollection found in hub " + "document for hub content '%s'. " + "recipe_name could not be " + "auto-resolved.", hub_content_name, ) - except Exception as e: + except (ClientError, ValueError) as e: logger.warning( - "Failed to resolve recipe_name for hub content '%s': %s", + "Failed to resolve recipe_name " + "for hub content '%s': %s", hub_content_name, e, ) @@ -794,21 +827,33 @@ def _resolve_base_model_fields(self): self._base_model_fields_resolved = True def _fetch_hub_document_for_custom_model(self) -> dict: + """Fetch the hub document for a custom (fine-tuned) model. + + Calls _resolve_base_model_fields() first to ensure hub_content_version + is populated. If hub_content_version is still Unassigned after + resolution (e.g. resolution failed), HubContent.get() is called + without a version parameter, which returns the latest version. + """ from sagemaker.core.shapes import BaseModel as CoreBaseModel self._resolve_base_model_fields() base_model: CoreBaseModel = ( - self._fetch_model_package().inference_specification.containers[0].base_model + self._fetch_model_package() + .inference_specification.containers[0].base_model ) - hub_content_version = getattr(base_model, "hub_content_version", None) + hub_content_version = getattr( + base_model, "hub_content_version", None + ) get_kwargs = dict( hub_content_type="Model", hub_name="SageMakerPublicHub", hub_content_name=base_model.hub_content_name, ) - if hub_content_version and not isinstance(hub_content_version, Unassigned): + if hub_content_version and not isinstance( + hub_content_version, Unassigned + ): get_kwargs["hub_content_version"] = hub_content_version hub_content = HubContent.get(**get_kwargs) @@ -1057,19 +1102,28 @@ def _is_gpu_instance(self, instance_type: str) -> bool: def _fetch_and_cache_recipe_config(self): """Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build.""" - self._resolve_base_model_fields() + # _fetch_hub_document_for_custom_model calls _resolve_base_model_fields + # internally, so no need to call it separately here. hub_document = self._fetch_hub_document_for_custom_model() model_package = self._fetch_model_package() - recipe_name = getattr( - model_package.inference_specification.containers[0].base_model, "recipe_name", None + base_model = ( + model_package.inference_specification + .containers[0].base_model + ) + hub_content_name = getattr( + base_model, "hub_content_name", "unknown" ) + recipe_name = getattr(base_model, "recipe_name", None) if not recipe_name or isinstance(recipe_name, Unassigned): raise ValueError( - "recipe_name is missing from the model package's BaseModel and could not be " - "auto-resolved from the hub document. Please ensure the model package has a " - "valid recipe_name set, or manually set it before calling build(). " - "Example: model_package.inference_specification.containers[0].base_model" - ".recipe_name = 'your-recipe-name'" + f"recipe_name is missing from the model package's " + f"BaseModel (hub_content_name='{hub_content_name}') " + f"and could not be auto-resolved from the hub " + f"document. Please ensure the model package has a " + f"valid recipe_name set, or manually set it before " + f"calling build(). Example: model_package." + f"inference_specification.containers[0].base_model" + f".recipe_name = 'your-recipe-name'" ) if not self.s3_upload_path: @@ -1213,7 +1267,10 @@ def _get_nova_hosting_config(self, instance_type=None): """ self._resolve_base_model_fields() model_package = self._fetch_model_package() - hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name + hub_content_name = ( + model_package.inference_specification + .containers[0].base_model.hub_content_name + ) configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name) if not configs: diff --git a/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py index 754968a959..203d389e4d 100644 --- a/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py +++ b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py @@ -11,23 +11,31 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Tests for _resolve_base_model_fields and related Unassigned handling.""" -from __future__ import absolute_import +from __future__ import annotations import json import pytest -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, patch +from botocore.exceptions import ClientError from sagemaker.core.utils.utils import Unassigned def _make_model_builder(**kwargs): - """Create a ModelBuilder instance with mocked session to avoid real AWS calls.""" + """Create a ModelBuilder with mocked session to avoid real AWS calls.""" with patch("sagemaker.serve.model_builder.Session"): - with patch("sagemaker.serve.model_builder.get_execution_role", return_value="arn:aws:iam::123456789012:role/SageMakerRole"): + with patch( + "sagemaker.serve.model_builder.get_execution_role", + return_value=( + "arn:aws:iam::123456789012:role/SageMakerRole" + ), + ): from sagemaker.serve.model_builder import ModelBuilder defaults = dict( model="dummy-model", - role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + role_arn=( + "arn:aws:iam::123456789012:role/SageMakerRole" + ), ) defaults.update(kwargs) mb = ModelBuilder(**defaults) @@ -36,12 +44,28 @@ def _make_model_builder(**kwargs): return mb -def _make_base_model(hub_content_name=None, hub_content_version=None, recipe_name=None): +def _make_base_model( + hub_content_name=None, + hub_content_version=None, + recipe_name=None, +): """Create a mock BaseModel with the given fields.""" base_model = MagicMock() - base_model.hub_content_name = hub_content_name if hub_content_name is not None else Unassigned() - base_model.hub_content_version = hub_content_version if hub_content_version is not None else Unassigned() - base_model.recipe_name = recipe_name if recipe_name is not None else Unassigned() + base_model.hub_content_name = ( + hub_content_name + if hub_content_name is not None + else Unassigned() + ) + base_model.hub_content_version = ( + hub_content_version + if hub_content_version is not None + else Unassigned() + ) + base_model.recipe_name = ( + recipe_name + if recipe_name is not None + else Unassigned() + ) return base_model @@ -51,21 +75,29 @@ def _make_model_package(base_model): container.base_model = base_model container.model_data_source = MagicMock() container.model_data_source.s3_data_source = MagicMock() - container.model_data_source.s3_data_source.s3_uri = "s3://bucket/path" + container.model_data_source.s3_data_source.s3_uri = ( + "s3://bucket/path" + ) model_package = MagicMock() model_package.inference_specification.containers = [container] return model_package -def _make_hub_content(hub_content_version="1.0.0", hub_content_document=None): +def _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=None, +): """Create a mock HubContent object.""" hc = MagicMock() hc.hub_content_version = hub_content_version if hub_content_document is None: hub_content_document = json.dumps({ "RecipeCollection": [ - {"Name": "auto-resolved-recipe", "HostingConfigs": []} + { + "Name": "auto-resolved-recipe", + "HostingConfigs": [], + } ], "HostingConfigs": [], }) @@ -77,16 +109,20 @@ class TestResolveBaseModelFields: """Tests for _resolve_base_model_fields method.""" @patch("sagemaker.serve.model_builder.HubContent") - def test_resolve_missing_hub_content_version(self, mock_hub_content_cls): - """When hub_content_version is Unassigned, it should be resolved from HubContent.get.""" + def test_resolve_missing_hub_content_version( + self, mock_hub_content_cls + ): + """hub_content_version Unassigned => resolved from HubContent.get.""" mb = _make_model_builder() base_model = _make_base_model( hub_content_name="huggingface-reasoning-qwen3-32b", - hub_content_version=None, # Will be Unassigned + hub_content_version=None, recipe_name="some-recipe", ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) mock_hc = _make_hub_content(hub_content_version="2.5.0") mock_hub_content_cls.get.return_value = mock_hc @@ -94,36 +130,49 @@ def test_resolve_missing_hub_content_version(self, mock_hub_content_cls): mb._resolve_base_model_fields() assert base_model.hub_content_version == "2.5.0" - # recipe_name should remain unchanged since it was already set assert base_model.recipe_name == "some-recipe" @patch("sagemaker.serve.model_builder.HubContent") - def test_resolve_missing_recipe_name(self, mock_hub_content_cls): - """When recipe_name is Unassigned, it should be resolved from RecipeCollection.""" + def test_resolve_missing_recipe_name( + self, mock_hub_content_cls + ): + """recipe_name Unassigned => resolved from RecipeCollection.""" mb = _make_model_builder() base_model = _make_base_model( hub_content_name="huggingface-reasoning-qwen3-32b", hub_content_version="1.0.0", - recipe_name=None, # Will be Unassigned + recipe_name=None, ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) hub_doc = json.dumps({ "RecipeCollection": [ - {"Name": "verl-grpo-rlaif-qwen-3-32b-lora", "HostingConfigs": []} + { + "Name": "verl-grpo-rlaif-qwen-3-32b-lora", + "HostingConfigs": [], + } ], }) - mock_hc = _make_hub_content(hub_content_version="1.0.0", hub_content_document=hub_doc) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) mock_hub_content_cls.get.return_value = mock_hc mb._resolve_base_model_fields() - assert base_model.recipe_name == "verl-grpo-rlaif-qwen-3-32b-lora" + assert base_model.recipe_name == ( + "verl-grpo-rlaif-qwen-3-32b-lora" + ) @patch("sagemaker.serve.model_builder.HubContent") - def test_noop_when_all_fields_present(self, mock_hub_content_cls): - """When all fields are present, HubContent.get should not be called.""" + def test_noop_when_all_fields_present( + self, mock_hub_content_cls + ): + """All fields present => HubContent.get not called.""" mb = _make_model_builder() base_model = _make_base_model( hub_content_name="huggingface-reasoning-qwen3-32b", @@ -131,7 +180,9 @@ def test_noop_when_all_fields_present(self, mock_hub_content_cls): recipe_name="some-recipe", ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) mb._resolve_base_model_fields() @@ -140,8 +191,10 @@ def test_noop_when_all_fields_present(self, mock_hub_content_cls): assert base_model.recipe_name == "some-recipe" @patch("sagemaker.serve.model_builder.HubContent") - def test_resolve_both_version_and_recipe(self, mock_hub_content_cls): - """When both hub_content_version and recipe_name are Unassigned, both should be resolved.""" + def test_resolve_both_version_and_recipe( + self, mock_hub_content_cls + ): + """Both Unassigned => both resolved.""" mb = _make_model_builder() base_model = _make_base_model( hub_content_name="huggingface-reasoning-qwen3-32b", @@ -149,14 +202,22 @@ def test_resolve_both_version_and_recipe(self, mock_hub_content_cls): recipe_name=None, ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) hub_doc = json.dumps({ "RecipeCollection": [ - {"Name": "auto-resolved-recipe", "HostingConfigs": []} + { + "Name": "auto-resolved-recipe", + "HostingConfigs": [], + } ], }) - mock_hc = _make_hub_content(hub_content_version="3.0.0", hub_content_document=hub_doc) + mock_hc = _make_hub_content( + hub_content_version="3.0.0", + hub_content_document=hub_doc, + ) mock_hub_content_cls.get.return_value = mock_hc mb._resolve_base_model_fields() @@ -165,8 +226,10 @@ def test_resolve_both_version_and_recipe(self, mock_hub_content_cls): assert base_model.recipe_name == "auto-resolved-recipe" @patch("sagemaker.serve.model_builder.HubContent") - def test_fetch_hub_document_works_after_resolution(self, mock_hub_content_cls): - """_fetch_hub_document_for_custom_model should work when hub_content_version was Unassigned.""" + def test_fetch_hub_document_works_after_resolution( + self, mock_hub_content_cls + ): + """_fetch_hub_document_for_custom_model works after resolution.""" mb = _make_model_builder() base_model = _make_base_model( hub_content_name="huggingface-reasoning-qwen3-32b", @@ -174,78 +237,235 @@ def test_fetch_hub_document_works_after_resolution(self, mock_hub_content_cls): recipe_name="some-recipe", ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) - hub_doc = json.dumps({"HostingConfigs": [{"Profile": "Default"}]}) - mock_hc = _make_hub_content(hub_content_version="1.0.0", hub_content_document=hub_doc) + hub_doc = json.dumps( + {"HostingConfigs": [{"Profile": "Default"}]} + ) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) mock_hub_content_cls.get.return_value = mock_hc result = mb._fetch_hub_document_for_custom_model() - assert result == {"HostingConfigs": [{"Profile": "Default"}]} + assert result == { + "HostingConfigs": [{"Profile": "Default"}] + } @patch("sagemaker.serve.model_builder.HubContent") - def test_no_base_model_is_noop(self, mock_hub_content_cls): - """When containers[0] has no base_model, method should return without error.""" + def test_no_base_model_is_noop( + self, mock_hub_content_cls + ): + """No base_model => method returns without error.""" mb = _make_model_builder() container = MagicMock() container.base_model = None model_package = MagicMock() - model_package.inference_specification.containers = [container] - mb._fetch_model_package = MagicMock(return_value=model_package) + model_package.inference_specification.containers = [ + container + ] + mb._fetch_model_package = MagicMock( + return_value=model_package + ) mb._resolve_base_model_fields() mock_hub_content_cls.get.assert_not_called() @patch("sagemaker.serve.model_builder.HubContent") - def test_no_hub_content_name_is_noop(self, mock_hub_content_cls): - """When hub_content_name is Unassigned, method should return without calling HubContent.get.""" + def test_no_hub_content_name_is_noop( + self, mock_hub_content_cls + ): + """hub_content_name Unassigned => no HubContent.get call.""" mb = _make_model_builder() base_model = _make_base_model( - hub_content_name=None, # Will be Unassigned + hub_content_name=None, hub_content_version=None, recipe_name=None, ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) mb._resolve_base_model_fields() mock_hub_content_cls.get.assert_not_called() @patch("sagemaker.serve.model_builder.HubContent") - def test_is_nova_model_with_unassigned_fields_does_not_crash(self, mock_hub_content_cls): - """_is_nova_model should return False without raising when fields are Unassigned.""" + def test_is_nova_model_with_unassigned_fields( + self, mock_hub_content_cls + ): + """_is_nova_model returns False when fields are Unassigned.""" mb = _make_model_builder() base_model = _make_base_model( - hub_content_name=None, # Unassigned + hub_content_name=None, hub_content_version=None, - recipe_name=None, # Unassigned + recipe_name=None, ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) result = mb._is_nova_model() assert result is False @patch("sagemaker.serve.model_builder.HubContent") - def test_fetch_and_cache_recipe_config_raises_when_recipe_unresolvable(self, mock_hub_content_cls): - """When recipe_name cannot be resolved, _fetch_and_cache_recipe_config should raise ValueError.""" + def test_fetch_and_cache_recipe_raises_when_unresolvable( + self, mock_hub_content_cls + ): + """recipe_name unresolvable => ValueError from _fetch_and_cache.""" mb = _make_model_builder() base_model = _make_base_model( hub_content_name="huggingface-reasoning-qwen3-32b", hub_content_version="1.0.0", - recipe_name=None, # Unassigned + recipe_name=None, ) model_package = _make_model_package(base_model) - mb._fetch_model_package = MagicMock(return_value=model_package) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) - # Hub document with empty RecipeCollection - recipe cannot be resolved - hub_doc = json.dumps({"RecipeCollection": [], "HostingConfigs": []}) - mock_hc = _make_hub_content(hub_content_version="1.0.0", hub_content_document=hub_doc) + hub_doc = json.dumps( + {"RecipeCollection": [], "HostingConfigs": []} + ) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) mock_hub_content_cls.get.return_value = mock_hc with pytest.raises(ValueError, match="recipe_name is missing"): mb._fetch_and_cache_recipe_config() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_graceful_on_hub_content_get_failure( + self, mock_hub_content_cls + ): + """When HubContent.get fails, resolution returns early. + + Downstream code (_fetch_and_cache_recipe_config) should still + raise the appropriate ValueError because recipe_name remains + Unassigned. + """ + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + # Simulate HubContent.get raising a ClientError + mock_hub_content_cls.get.side_effect = ClientError( + error_response={ + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Hub content not found", + } + }, + operation_name="DescribeHubContent", + ) + + # _resolve_base_model_fields should not raise + mb._resolve_base_model_fields() + + # Fields should still be Unassigned + assert isinstance( + base_model.hub_content_version, Unassigned + ) + assert isinstance(base_model.recipe_name, Unassigned) + # The flag should be set so it doesn't retry + assert mb._base_model_fields_resolved is True + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_failure_then_fetch_and_cache_raises( + self, mock_hub_content_cls + ): + """When resolution fails, _fetch_and_cache_recipe_config raises. + + This tests the full flow: resolution fails gracefully, then + downstream code raises ValueError because recipe_name is still + Unassigned. + """ + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + # First call (from _resolve_base_model_fields) fails + # Second call (from _fetch_hub_document_for_custom_model) + # succeeds but returns empty RecipeCollection + hub_doc = json.dumps( + {"RecipeCollection": [], "HostingConfigs": []} + ) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) + + call_count = [0] + + def side_effect(**kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise ClientError( + error_response={ + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Not found", + } + }, + operation_name="DescribeHubContent", + ) + return mock_hc + + mock_hub_content_cls.get.side_effect = side_effect + + with pytest.raises( + ValueError, match="recipe_name is missing" + ): + mb._fetch_and_cache_recipe_config() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_idempotent_second_call_is_noop( + self, mock_hub_content_cls + ): + """Second call to _resolve_base_model_fields is a no-op.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + mock_hc = _make_hub_content(hub_content_version="2.0.0") + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + assert base_model.hub_content_version == "2.0.0" + + # Reset mock to verify no additional calls + mock_hub_content_cls.get.reset_mock() + + mb._resolve_base_model_fields() + mock_hub_content_cls.get.assert_not_called()