Skip to content

Commit 7a80bae

Browse files
committed
fix: address review comments (iteration #1)
1 parent 34829ce commit 7a80bae

File tree

2 files changed

+369
-92
lines changed

2 files changed

+369
-92
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 91 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -414,6 +414,7 @@ def __post_init__(self) -> None:
414414
if self.log_level is not None:
415415
logger.setLevel(self.log_level)
416416

417+
self._base_model_fields_resolved: bool = False
417418
self._warn_about_deprecated_parameters(warnings)
418419
self._initialize_compute_config()
419420
self._initialize_network_config()
@@ -688,8 +689,11 @@ def _resolve_base_model_fields(self):
688689
DescribeModelPackage API), this method resolves them automatically:
689690
- hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub
690691
- recipe_name: resolved from the first recipe in the hub document's RecipeCollection
692+
693+
Note: HubContent.get() supports being called without hub_content_version,
694+
in which case it returns the latest version of the hub content.
691695
"""
692-
if hasattr(self, "_base_model_fields_resolved") and self._base_model_fields_resolved:
696+
if self._base_model_fields_resolved:
693697
return
694698

695699
model_package = self._fetch_model_package()
@@ -720,6 +724,9 @@ def _resolve_base_model_fields(self):
720724
hub_content_version = getattr(base_model, "hub_content_version", None)
721725
recipe_name = getattr(base_model, "recipe_name", None)
722726

727+
# Cache the HubContent response to avoid redundant API calls
728+
cached_hub_content = None
729+
723730
# Resolve hub_content_version if missing
724731
if not hub_content_version or isinstance(hub_content_version, Unassigned):
725732
logger.info(
@@ -728,20 +735,24 @@ def _resolve_base_model_fields(self):
728735
hub_content_name,
729736
)
730737
try:
731-
hc = HubContent.get(
738+
cached_hub_content = HubContent.get(
732739
hub_content_type="Model",
733740
hub_name="SageMakerPublicHub",
734741
hub_content_name=hub_content_name,
735742
)
736-
base_model.hub_content_version = hc.hub_content_version
743+
base_model.hub_content_version = (
744+
cached_hub_content.hub_content_version
745+
)
737746
logger.info(
738-
"Resolved hub_content_version to '%s' for hub content '%s'.",
739-
hc.hub_content_version,
747+
"Resolved hub_content_version to '%s' "
748+
"for hub content '%s'.",
749+
cached_hub_content.hub_content_version,
740750
hub_content_name,
741751
)
742-
except Exception as e:
752+
except (ClientError, ValueError) as e:
743753
logger.warning(
744-
"Failed to resolve hub_content_version for hub content '%s': %s",
754+
"Failed to resolve hub_content_version "
755+
"for hub content '%s': %s",
745756
hub_content_name,
746757
e,
747758
)
@@ -752,63 +763,97 @@ def _resolve_base_model_fields(self):
752763
if not recipe_name or isinstance(recipe_name, Unassigned):
753764
logger.info(
754765
"recipe_name is missing for hub content '%s'. "
755-
"Resolving automatically from hub document RecipeCollection...",
766+
"Resolving from hub document RecipeCollection...",
756767
hub_content_name,
757768
)
758769
try:
759-
hub_content = HubContent.get(
760-
hub_content_type="Model",
761-
hub_name="SageMakerPublicHub",
762-
hub_content_name=base_model.hub_content_name,
763-
hub_content_version=base_model.hub_content_version,
770+
# Reuse cached hub content if available and version matches
771+
if (
772+
cached_hub_content is not None
773+
and cached_hub_content.hub_content_version
774+
== base_model.hub_content_version
775+
):
776+
hub_content = cached_hub_content
777+
else:
778+
hub_content = HubContent.get(
779+
hub_content_type="Model",
780+
hub_name="SageMakerPublicHub",
781+
hub_content_name=base_model.hub_content_name,
782+
hub_content_version=(
783+
base_model.hub_content_version
784+
),
785+
)
786+
hub_document = json.loads(
787+
hub_content.hub_content_document
788+
)
789+
recipe_collection = hub_document.get(
790+
"RecipeCollection", []
764791
)
765-
hub_document = json.loads(hub_content.hub_content_document)
766-
recipe_collection = hub_document.get("RecipeCollection", [])
767792
if recipe_collection:
768-
resolved_recipe = recipe_collection[0].get("Name", "")
793+
resolved_recipe = recipe_collection[0].get(
794+
"Name", ""
795+
)
769796
if resolved_recipe:
770797
base_model.recipe_name = resolved_recipe
771798
logger.info(
772-
"Resolved recipe_name to '%s' for hub content '%s'.",
799+
"Resolved recipe_name to '%s' "
800+
"for hub content '%s'.",
773801
resolved_recipe,
774802
hub_content_name,
775803
)
776804
else:
777805
logger.warning(
778-
"RecipeCollection found but first recipe has no Name for hub content '%s'.",
806+
"RecipeCollection found but first "
807+
"recipe has no Name for hub "
808+
"content '%s'.",
779809
hub_content_name,
780810
)
781811
else:
782812
logger.warning(
783-
"No RecipeCollection found in hub document for hub content '%s'. "
784-
"recipe_name could not be auto-resolved.",
813+
"No RecipeCollection found in hub "
814+
"document for hub content '%s'. "
815+
"recipe_name could not be "
816+
"auto-resolved.",
785817
hub_content_name,
786818
)
787-
except Exception as e:
819+
except (ClientError, ValueError) as e:
788820
logger.warning(
789-
"Failed to resolve recipe_name for hub content '%s': %s",
821+
"Failed to resolve recipe_name "
822+
"for hub content '%s': %s",
790823
hub_content_name,
791824
e,
792825
)
793826

794827
self._base_model_fields_resolved = True
795828

796829
def _fetch_hub_document_for_custom_model(self) -> dict:
830+
"""Fetch the hub document for a custom (fine-tuned) model.
831+
832+
Calls _resolve_base_model_fields() first to ensure hub_content_version
833+
is populated. If hub_content_version is still Unassigned after
834+
resolution (e.g. resolution failed), HubContent.get() is called
835+
without a version parameter, which returns the latest version.
836+
"""
797837
from sagemaker.core.shapes import BaseModel as CoreBaseModel
798838

799839
self._resolve_base_model_fields()
800840

801841
base_model: CoreBaseModel = (
802-
self._fetch_model_package().inference_specification.containers[0].base_model
842+
self._fetch_model_package()
843+
.inference_specification.containers[0].base_model
803844
)
804845

805-
hub_content_version = getattr(base_model, "hub_content_version", None)
846+
hub_content_version = getattr(
847+
base_model, "hub_content_version", None
848+
)
806849
get_kwargs = dict(
807850
hub_content_type="Model",
808851
hub_name="SageMakerPublicHub",
809852
hub_content_name=base_model.hub_content_name,
810853
)
811-
if hub_content_version and not isinstance(hub_content_version, Unassigned):
854+
if hub_content_version and not isinstance(
855+
hub_content_version, Unassigned
856+
):
812857
get_kwargs["hub_content_version"] = hub_content_version
813858

814859
hub_content = HubContent.get(**get_kwargs)
@@ -1057,19 +1102,28 @@ def _is_gpu_instance(self, instance_type: str) -> bool:
10571102

10581103
def _fetch_and_cache_recipe_config(self):
10591104
"""Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build."""
1060-
self._resolve_base_model_fields()
1105+
# _fetch_hub_document_for_custom_model calls _resolve_base_model_fields
1106+
# internally, so no need to call it separately here.
10611107
hub_document = self._fetch_hub_document_for_custom_model()
10621108
model_package = self._fetch_model_package()
1063-
recipe_name = getattr(
1064-
model_package.inference_specification.containers[0].base_model, "recipe_name", None
1109+
base_model = (
1110+
model_package.inference_specification
1111+
.containers[0].base_model
1112+
)
1113+
hub_content_name = getattr(
1114+
base_model, "hub_content_name", "unknown"
10651115
)
1116+
recipe_name = getattr(base_model, "recipe_name", None)
10661117
if not recipe_name or isinstance(recipe_name, Unassigned):
10671118
raise ValueError(
1068-
"recipe_name is missing from the model package's BaseModel and could not be "
1069-
"auto-resolved from the hub document. Please ensure the model package has a "
1070-
"valid recipe_name set, or manually set it before calling build(). "
1071-
"Example: model_package.inference_specification.containers[0].base_model"
1072-
".recipe_name = 'your-recipe-name'"
1119+
f"recipe_name is missing from the model package's "
1120+
f"BaseModel (hub_content_name='{hub_content_name}') "
1121+
f"and could not be auto-resolved from the hub "
1122+
f"document. Please ensure the model package has a "
1123+
f"valid recipe_name set, or manually set it before "
1124+
f"calling build(). Example: model_package."
1125+
f"inference_specification.containers[0].base_model"
1126+
f".recipe_name = 'your-recipe-name'"
10731127
)
10741128

10751129
if not self.s3_upload_path:
@@ -1213,7 +1267,10 @@ def _get_nova_hosting_config(self, instance_type=None):
12131267
"""
12141268
self._resolve_base_model_fields()
12151269
model_package = self._fetch_model_package()
1216-
hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name
1270+
hub_content_name = (
1271+
model_package.inference_specification
1272+
.containers[0].base_model.hub_content_name
1273+
)
12171274

12181275
configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name)
12191276
if not configs:

0 commit comments

Comments
 (0)