Skip to content

Commit 5a5f0bd

Browse files
committed
fix: address review comments (iteration #2)
1 parent 7cdb047 commit 5a5f0bd

2 files changed

Lines changed: 757 additions & 160 deletions

File tree

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 225 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
ModelCard,
5353
ModelPackageModelCard,
5454
)
55-
from sagemaker.core.utils.utils import logger
55+
from sagemaker.core.utils.utils import logger, Unassigned
5656
from sagemaker.core.helper import session_helper
5757
from sagemaker.core.helper.session_helper import (
5858
Session,
@@ -380,6 +380,7 @@ class ModelBuilder(_InferenceRecommenderMixin, _ModelBuilderServers, _ModelBuild
380380
_tags: Optional[Tags] = field(default=None, init=False)
381381
_optimizing: bool = field(default=False, init=False)
382382
_deployment_config: Optional[Dict[str, Any]] = field(default=None, init=False)
383+
_base_model_fields_resolved: bool = field(default=False, init=False)
383384

384385
shared_libs: List[str] = field(
385386
default_factory=list,
@@ -680,18 +681,198 @@ def _infer_instance_type_from_jumpstart(self) -> str:
680681

681682
raise ValueError(error_msg)
682683

684+
@staticmethod
def _normalize_field(value: object, default: str = "") -> str:
    """Return ``value`` unless it is unset, in which case return ``default``.

    A value counts as unset when it is falsy (``None``, empty string, ...)
    or when it is an ``Unassigned`` instance.

    Args:
        value: The raw field value to inspect.
        default: Replacement returned for unset values.

    Returns:
        ``default`` for unset values; otherwise ``value`` itself, returned
        unchanged (no string conversion is performed).
    """
    is_unset = (not value) or isinstance(value, Unassigned)
    return default if is_unset else value
690+
691+
def _get_base_model_from_package(self) -> object:
    """Return the ``base_model`` of the model package's first container.

    Walks model package -> ``inference_specification`` -> ``containers``
    and stops at the first link that is missing or falsy.

    Returns:
        The ``base_model`` attribute of the first container, or None when
        the model package, its inference specification, its containers,
        or the ``base_model`` attribute itself is unavailable.
    """
    package = self._fetch_model_package()
    if not package:
        return None

    spec = getattr(package, "inference_specification", None)
    container_list = getattr(spec, "containers", None) if spec else None
    if not container_list:
        return None

    # Only the first container is consulted.
    return getattr(container_list[0], "base_model", None)
709+
710+
def _resolve_base_model_fields(self) -> None:
    """Auto-resolve missing BaseModel fields (hub_content_version, recipe_name).

    When the model package's BaseModel has ``hub_content_name`` set but
    ``hub_content_version`` and/or ``recipe_name`` are unset (falsy or
    ``Unassigned``), this method fills them in on the BaseModel object:

    - ``hub_content_version``: taken from ``HubContent.get`` on
      SageMakerPublicHub, called without a version.
    - ``recipe_name``: taken from the ``Name`` of the first entry in the
      hub document's ``RecipeCollection``.

    The method runs at most once per instance: after a completed attempt
    ``_base_model_fields_resolved`` is set and later calls return
    immediately. The flag is NOT set when an unexpected ``ClientError``
    propagates, so a later call will retry.

    Note: HubContent.get() supports being called without
    hub_content_version, in which case it returns the latest version of
    the hub content.

    Raises:
        ClientError: For any hub lookup failure other than
            ``ResourceNotFoundException`` (e.g. access denied), so that
            permission problems are not masked.
    """
    if self._base_model_fields_resolved:
        return

    base_model = self._get_base_model_from_package()
    hub_content_name = (
        getattr(base_model, "hub_content_name", None) if base_model else None
    )
    if self._is_field_unset(hub_content_name):
        # Without a hub content name there is nothing to look up.
        self._base_model_fields_resolved = True
        return

    # Read both fields up front, before anything is mutated.
    version_unset = self._is_field_unset(
        getattr(base_model, "hub_content_version", None)
    )
    recipe_unset = self._is_field_unset(getattr(base_model, "recipe_name", None))

    # Hub content fetched while resolving the version is reused for the
    # recipe lookup to avoid a second API call.
    hub_content = None
    if version_unset:
        hub_content = self._resolve_hub_content_version(
            base_model, hub_content_name
        )
        if hub_content is None:
            # Hub content was not found; the recipe lookup would hit the
            # same missing resource, so stop here.
            self._base_model_fields_resolved = True
            return

    if recipe_unset:
        self._resolve_recipe_name(base_model, hub_content_name, hub_content)

    self._base_model_fields_resolved = True

@staticmethod
def _is_field_unset(value: object) -> bool:
    """Return True when ``value`` is falsy or an ``Unassigned`` instance."""
    return not value or isinstance(value, Unassigned)

@staticmethod
def _is_resource_not_found(error) -> bool:
    """Return True when a ClientError carries code ResourceNotFoundException.

    Uses ``.get`` lookups so an unexpectedly shaped error response cannot
    raise ``KeyError`` from inside an ``except`` handler.
    """
    response = getattr(error, "response", None) or {}
    return response.get("Error", {}).get("Code") == "ResourceNotFoundException"

def _resolve_hub_content_version(self, base_model, hub_content_name):
    """Set ``base_model.hub_content_version`` from SageMakerPublicHub.

    Args:
        base_model: BaseModel object whose ``hub_content_version`` is unset.
        hub_content_name: Name of the hub content to look up.

    Returns:
        The fetched hub content, or None when the hub content does not
        exist (``ResourceNotFoundException``).

    Raises:
        ClientError: For any error other than ``ResourceNotFoundException``.
    """
    logger.info(
        "hub_content_version is missing for hub content '%s'. "
        "Resolving automatically from SageMakerPublicHub...",
        hub_content_name,
    )
    try:
        hub_content = HubContent.get(
            hub_content_type="Model",
            hub_name="SageMakerPublicHub",
            hub_content_name=hub_content_name,
        )
    except ClientError as e:
        # Only swallow ResourceNotFoundException; re-raise permission
        # errors (AccessDeniedException, etc.) so they aren't masked.
        if not self._is_resource_not_found(e):
            raise
        logger.warning(
            "Failed to resolve hub_content_version "
            "for hub content '%s': %s",
            hub_content_name,
            e,
        )
        return None

    base_model.hub_content_version = hub_content.hub_content_version
    logger.info(
        "Resolved hub_content_version to '%s' "
        "for hub content '%s'.",
        hub_content.hub_content_version,
        hub_content_name,
    )
    return hub_content

def _resolve_recipe_name(self, base_model, hub_content_name, hub_content=None):
    """Set ``base_model.recipe_name`` from the hub document's first recipe.

    Best-effort: when the hub content is not found, or the document is
    unparseable or has no usable recipe, a warning is logged and
    ``recipe_name`` is left untouched.

    Args:
        base_model: BaseModel object whose ``recipe_name`` is unset.
        hub_content_name: Name of the hub content (used for log messages).
        hub_content: Already-fetched hub content to reuse; when None the
            hub content is fetched using the BaseModel's name and version.

    Raises:
        ClientError: For any error other than ``ResourceNotFoundException``.
    """
    logger.info(
        "recipe_name is missing for hub content '%s'. "
        "Resolving from hub document RecipeCollection...",
        hub_content_name,
    )
    if hub_content is None:
        try:
            hub_content = HubContent.get(
                hub_content_type="Model",
                hub_name="SageMakerPublicHub",
                hub_content_name=base_model.hub_content_name,
                hub_content_version=base_model.hub_content_version,
            )
        except ClientError as e:
            if not self._is_resource_not_found(e):
                raise
            logger.warning(
                "Failed to resolve recipe_name "
                "for hub content '%s': %s",
                hub_content_name,
                e,
            )
            return

    try:
        hub_document = json.loads(hub_content.hub_content_document)
    except (TypeError, ValueError) as e:
        # TypeError: document is not str/bytes (e.g. unset);
        # ValueError: covers json.JSONDecodeError for malformed JSON.
        logger.warning(
            "Could not parse hub document for hub content '%s': %s. "
            "recipe_name could not be auto-resolved.",
            hub_content_name,
            e,
        )
        return

    recipe_collection = (
        hub_document.get("RecipeCollection", [])
        if isinstance(hub_document, dict)
        else []
    )
    if not recipe_collection:
        logger.warning(
            "No RecipeCollection found in hub "
            "document for hub content '%s'. "
            "recipe_name could not be "
            "auto-resolved.",
            hub_content_name,
        )
        return

    first_recipe = recipe_collection[0]
    resolved_recipe = (
        first_recipe.get("Name", "") if isinstance(first_recipe, dict) else ""
    )
    if not resolved_recipe:
        logger.warning(
            "RecipeCollection found but first "
            "recipe has no Name for hub "
            "content '%s'.",
            hub_content_name,
        )
        return

    base_model.recipe_name = resolved_recipe
    logger.info(
        "Resolved recipe_name to '%s' "
        "for hub content '%s'.",
        resolved_recipe,
        hub_content_name,
    )
844+
683845
def _fetch_hub_document_for_custom_model(self) -> dict:
    """Return the parsed hub document for a custom (fine-tuned) model.

    ``_resolve_base_model_fields()`` runs first so that
    ``hub_content_version`` is populated where possible. If the version is
    still unset afterwards (falsy or ``Unassigned``), ``HubContent.get()``
    is called without a version argument, which returns the latest version.

    Returns:
        The hub content document of the model's base hub content on
        SageMakerPublicHub, decoded from JSON.
    """
    from sagemaker.core.shapes import BaseModel as CoreBaseModel

    self._resolve_base_model_fields()

    package = self._fetch_model_package()
    base_model: CoreBaseModel = (
        package.inference_specification.containers[0].base_model
    )

    request = {
        "hub_content_type": "Model",
        "hub_name": "SageMakerPublicHub",
        "hub_content_name": base_model.hub_content_name,
    }
    # Pass a version only when one is actually known.
    version = getattr(base_model, "hub_content_version", None)
    if version and not isinstance(version, Unassigned):
        request["hub_content_version"] = version

    fetched = HubContent.get(**request)
    return json.loads(fetched.hub_content_document)
696877

697878
def _fetch_hosting_configs_for_custom_model(self) -> dict:
@@ -937,9 +1118,26 @@ def _is_gpu_instance(self, instance_type: str) -> bool:
9371118

9381119
def _fetch_and_cache_recipe_config(self):
9391120
"""Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build."""
1121+
# _fetch_hub_document_for_custom_model calls _resolve_base_model_fields
1122+
# internally, so no need to call it separately here.
9401123
hub_document = self._fetch_hub_document_for_custom_model()
9411124
model_package = self._fetch_model_package()
942-
recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name
1125+
base_model = (
1126+
model_package.inference_specification
1127+
.containers[0].base_model
1128+
)
1129+
hub_content_name = getattr(
1130+
base_model, "hub_content_name", "unknown"
1131+
)
1132+
recipe_name = getattr(base_model, "recipe_name", None)
1133+
if not recipe_name or isinstance(recipe_name, Unassigned):
1134+
raise ValueError(
1135+
f"recipe_name is missing from the model package's BaseModel "
1136+
f"(hub_content_name='{hub_content_name}') and could not be "
1137+
f"auto-resolved from the hub document. Please ensure the model "
1138+
f"package has a valid recipe_name set, or manually set it before "
1139+
f"calling build()."
1140+
)
9431141

9441142
if not self.s3_upload_path:
9451143
self.s3_upload_path = model_package.inference_specification.containers[
@@ -1050,27 +1248,39 @@ def _fetch_and_cache_recipe_config(self):
10501248

10511249
def _is_nova_model(self) -> bool:
    """Report whether the model package's base model is a Nova model.

    Missing BaseModel fields are auto-resolved first. The model counts as
    Nova when "nova" appears (case-insensitively) in either the
    ``recipe_name`` or the ``hub_content_name`` of the base model.

    Returns:
        True for a Nova model; False otherwise, including when the model
        package has no base model.
    """
    self._resolve_base_model_fields()

    base_model = self._get_base_model_from_package()
    if not base_model:
        return False

    # Unset (falsy / Unassigned) values are normalized to "" before matching.
    names = tuple(
        self._normalize_field(getattr(base_model, attr, ""), default="")
        for attr in ("recipe_name", "hub_content_name")
    )
    return any("nova" in name.lower() for name in names)
10651262

1263+
def _is_nova_model_for_telemetry(self) -> bool:
    """Check if the model is a Nova model for telemetry tracking.

    Wraps ``_is_nova_model()`` so that any exception raised during
    detection is reported as "not Nova" instead of propagating.

    Returns:
        The result of ``_is_nova_model()``, or False if it raised.
    """
    try:
        return self._is_nova_model()
    except Exception:
        # Deliberately best-effort (presumably so telemetry never breaks
        # the caller), but leave a trace rather than failing silently.
        logger.debug(
            "Nova model detection failed for telemetry; "
            "treating model as non-Nova.",
            exc_info=True,
        )
        return False
1269+
10661270
def _get_nova_hosting_config(self, instance_type=None):
10671271
"""Get Nova hosting config (image URI, env vars, instance type).
10681272
10691273
Nova training recipes don't have hosting configs in the JumpStart hub document.
10701274
This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs().
1275+
1276+
Note: _resolve_base_model_fields() is already called by _is_nova_model(),
1277+
which gates all calls to this method.
10711278
"""
10721279
model_package = self._fetch_model_package()
1073-
hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name
1280+
hub_content_name = (
1281+
model_package.inference_specification
1282+
.containers[0].base_model.hub_content_name
1283+
)
10741284

10751285
configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name)
10761286
if not configs:

0 commit comments

Comments
 (0)