Skip to content

Commit 34829ce

Browse files
committed
fix: Model builder unable to (5667)
1 parent daf19b0 commit 34829ce

File tree

2 files changed

+391
-4
lines changed

2 files changed

+391
-4
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 140 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
ModelCard,
5353
ModelPackageModelCard,
5454
)
55-
from sagemaker.core.utils.utils import logger
55+
from sagemaker.core.utils.utils import logger, Unassigned
5656
from sagemaker.core.helper import session_helper
5757
from sagemaker.core.helper.session_helper import (
5858
Session,
@@ -680,18 +680,138 @@ def _infer_instance_type_from_jumpstart(self) -> str:
680680

681681
raise ValueError(error_msg)
682682

683+
def _resolve_base_model_fields(self):
    """Auto-resolve missing BaseModel fields (hub_content_version, recipe_name).

    Looks at the ``base_model`` of the first container in the model package's
    inference specification. When ``hub_content_name`` is set but
    ``hub_content_version`` and/or ``recipe_name`` are empty or ``Unassigned``,
    the missing values are filled in on the ``base_model`` object in place:

    - hub_content_version: taken from ``HubContent.get`` on SageMakerPublicHub
      called without a version.
    - recipe_name: taken from the ``Name`` of the first entry of the hub
      document's ``RecipeCollection``.

    Resolution is best-effort: lookup failures are logged as warnings and
    never raised. The attempt is made at most once per instance; the outcome
    (including "nothing to resolve" and logged failures) is recorded in
    ``self._base_model_fields_resolved``. If the version lookup fails, the
    recipe lookup is not attempted.
    """
    if getattr(self, "_base_model_fields_resolved", False):
        return

    model_package = self._fetch_model_package()

    # Every path below is final (resolved, nothing to do, or a logged
    # failure), so record the attempt once instead of before each return.
    # The flag is deliberately set only after the fetch above succeeded.
    self._base_model_fields_resolved = True

    if not model_package:
        return

    inference_spec = getattr(model_package, "inference_specification", None)
    if not inference_spec:
        return

    containers = getattr(inference_spec, "containers", None)
    if not containers:
        return

    base_model = getattr(containers[0], "base_model", None)
    if not base_model:
        return

    hub_content_name = getattr(base_model, "hub_content_name", None)
    if not hub_content_name or isinstance(hub_content_name, Unassigned):
        return

    hub_content_version = getattr(base_model, "hub_content_version", None)
    recipe_name = getattr(base_model, "recipe_name", None)

    # Hub content fetched during version resolution (if any); reused for the
    # recipe lookup so the same content is not described twice.
    hub_content = None

    # Resolve hub_content_version if missing
    if not hub_content_version or isinstance(hub_content_version, Unassigned):
        logger.info(
            "hub_content_version is missing for hub content '%s'. "
            "Resolving automatically from SageMakerPublicHub...",
            hub_content_name,
        )
        try:
            hub_content = HubContent.get(
                hub_content_type="Model",
                hub_name="SageMakerPublicHub",
                hub_content_name=hub_content_name,
            )
            base_model.hub_content_version = hub_content.hub_content_version
            logger.info(
                "Resolved hub_content_version to '%s' for hub content '%s'.",
                hub_content.hub_content_version,
                hub_content_name,
            )
        except Exception as e:  # best-effort: never fail the build here
            logger.warning(
                "Failed to resolve hub_content_version for hub content '%s': %s",
                hub_content_name,
                e,
            )
            # Without a version the recipe lookup below is skipped as well.
            return

    # Resolve recipe_name if missing
    if not recipe_name or isinstance(recipe_name, Unassigned):
        logger.info(
            "recipe_name is missing for hub content '%s'. "
            "Resolving automatically from hub document RecipeCollection...",
            hub_content_name,
        )
        try:
            if hub_content is None:
                # Version was already present on the model package, so no
                # hub content has been fetched yet.
                hub_content = HubContent.get(
                    hub_content_type="Model",
                    hub_name="SageMakerPublicHub",
                    hub_content_name=base_model.hub_content_name,
                    hub_content_version=base_model.hub_content_version,
                )
            hub_document = json.loads(hub_content.hub_content_document)
            recipe_collection = hub_document.get("RecipeCollection", [])
            if not recipe_collection:
                logger.warning(
                    "No RecipeCollection found in hub document for hub content '%s'. "
                    "recipe_name could not be auto-resolved.",
                    hub_content_name,
                )
            else:
                resolved_recipe = recipe_collection[0].get("Name", "")
                if resolved_recipe:
                    base_model.recipe_name = resolved_recipe
                    logger.info(
                        "Resolved recipe_name to '%s' for hub content '%s'.",
                        resolved_recipe,
                        hub_content_name,
                    )
                else:
                    logger.warning(
                        "RecipeCollection found but first recipe has no Name for hub content '%s'.",
                        hub_content_name,
                    )
        except Exception as e:  # best-effort: never fail the build here
            logger.warning(
                "Failed to resolve recipe_name for hub content '%s': %s",
                hub_content_name,
                e,
            )
795+
683796
def _fetch_hub_document_for_custom_model(self) -> dict:
    """Return the parsed hub content document for the model package's base model.

    Runs ``_resolve_base_model_fields()`` first, then reads the ``base_model``
    of the first container in the model package's inference specification and
    fetches the matching "Model" hub content from SageMakerPublicHub.

    ``hub_content_version`` is passed to ``HubContent.get`` only when it is
    set and not ``Unassigned``; otherwise the argument is omitted entirely.
    NOTE(review): presumably the service then returns the latest version --
    confirm against the DescribeHubContent documentation.

    Returns:
        The hub content document decoded from JSON.
    """
    from sagemaker.core.shapes import BaseModel as CoreBaseModel

    self._resolve_base_model_fields()

    base_model: CoreBaseModel = (
        self._fetch_model_package().inference_specification.containers[0].base_model
    )

    get_kwargs = {
        "hub_content_type": "Model",
        "hub_name": "SageMakerPublicHub",
        "hub_content_name": base_model.hub_content_name,
    }

    # Never forward an empty/Unassigned version: leave the argument out instead.
    hub_content_version = getattr(base_model, "hub_content_version", None)
    if hub_content_version and not isinstance(hub_content_version, Unassigned):
        get_kwargs["hub_content_version"] = hub_content_version

    hub_content = HubContent.get(**get_kwargs)
    return json.loads(hub_content.hub_content_document)
696816

697817
def _fetch_hosting_configs_for_custom_model(self) -> dict:
@@ -937,9 +1057,20 @@ def _is_gpu_instance(self, instance_type: str) -> bool:
9371057

9381058
def _fetch_and_cache_recipe_config(self):
9391059
"""Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build."""
1060+
self._resolve_base_model_fields()
9401061
hub_document = self._fetch_hub_document_for_custom_model()
9411062
model_package = self._fetch_model_package()
942-
recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name
1063+
recipe_name = getattr(
1064+
model_package.inference_specification.containers[0].base_model, "recipe_name", None
1065+
)
1066+
if not recipe_name or isinstance(recipe_name, Unassigned):
1067+
raise ValueError(
1068+
"recipe_name is missing from the model package's BaseModel and could not be "
1069+
"auto-resolved from the hub document. Please ensure the model package has a "
1070+
"valid recipe_name set, or manually set it before calling build(). "
1071+
"Example: model_package.inference_specification.containers[0].base_model"
1072+
".recipe_name = 'your-recipe-name'"
1073+
)
9431074

9441075
if not self.s3_upload_path:
9451076
self.s3_upload_path = model_package.inference_specification.containers[
@@ -1060,7 +1191,11 @@ def _is_nova_model(self) -> bool:
10601191
if not base_model:
10611192
return False
10621193
recipe_name = getattr(base_model, "recipe_name", "") or ""
1194+
if isinstance(recipe_name, Unassigned):
1195+
recipe_name = ""
10631196
hub_content_name = getattr(base_model, "hub_content_name", "") or ""
1197+
if isinstance(hub_content_name, Unassigned):
1198+
hub_content_name = ""
10641199
return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower()
10651200

10661201
def _is_nova_model_for_telemetry(self) -> bool:
@@ -1076,6 +1211,7 @@ def _get_nova_hosting_config(self, instance_type=None):
10761211
Nova training recipes don't have hosting configs in the JumpStart hub document.
10771212
This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs().
10781213
"""
1214+
self._resolve_base_model_fields()
10791215
model_package = self._fetch_model_package()
10801216
hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name
10811217

0 commit comments

Comments
 (0)