Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 199 additions & 6 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
ModelCard,
ModelPackageModelCard,
)
from sagemaker.core.utils.utils import logger
from sagemaker.core.utils.utils import logger, Unassigned
from sagemaker.core.helper import session_helper
from sagemaker.core.helper.session_helper import (
Session,
Expand Down Expand Up @@ -414,6 +414,7 @@ def __post_init__(self) -> None:
if self.log_level is not None:
logger.setLevel(self.log_level)

self._base_model_fields_resolved: bool = False
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This instance attribute is set in __post_init__ but lacks a type annotation on the class body. Per SDK conventions (PEP 484, Pydantic BaseModel), all class attributes should have type annotations. Consider declaring this as a class-level field:

_base_model_fields_resolved: bool = False

in the class body (or as a PrivateAttr if this is a Pydantic model), rather than setting it in __post_init__.

self._warn_about_deprecated_parameters(warnings)
self._initialize_compute_config()
self._initialize_network_config()
Expand Down Expand Up @@ -680,18 +681,182 @@ def _infer_instance_type_from_jumpstart(self) -> str:

raise ValueError(error_msg)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing type annotation on the return type. Per SDK conventions, all public and non-trivial private methods should have type annotations:

def _resolve_base_model_fields(self) -> None:

def _resolve_base_model_fields(self):
"""Auto-resolve missing BaseModel fields (hub_content_version, recipe_name).

When a ModelPackage's BaseModel has hub_content_name set but is missing
hub_content_version and/or recipe_name (returned as Unassigned from the
DescribeModelPackage API), this method resolves them automatically:
- hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub
- recipe_name: resolved from the first recipe in the hub document's RecipeCollection

Note: HubContent.get() supports being called without hub_content_version,
in which case it returns the latest version of the hub content.
"""
if self._base_model_fields_resolved:
return

model_package = self._fetch_model_package()
if not model_package:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The early-return pattern with self._base_model_fields_resolved = True is repeated 4 times before the main logic. Consider restructuring to reduce duplication — e.g., extract the chain of getattr checks into a helper that returns the base_model or None, then have a single early return:

base_model = self._get_base_model_from_package()
if not base_model:
    self._base_model_fields_resolved = True
    return

This would also make the method significantly shorter and more readable.

self._base_model_fields_resolved = True
return

inference_spec = getattr(model_package, "inference_specification", None)
if not inference_spec:
self._base_model_fields_resolved = True
return

containers = getattr(inference_spec, "containers", None)
if not containers:
self._base_model_fields_resolved = True
return

base_model = getattr(containers[0], "base_model", None)
if not base_model:
self._base_model_fields_resolved = True
return

hub_content_name = getattr(base_model, "hub_content_name", None)
if not hub_content_name or isinstance(hub_content_name, Unassigned):
self._base_model_fields_resolved = True
return

hub_content_version = getattr(base_model, "hub_content_version", None)
recipe_name = getattr(base_model, "recipe_name", None)

# Cache the HubContent response to avoid redundant API calls
cached_hub_content = None

# Resolve hub_content_version if missing
if not hub_content_version or isinstance(hub_content_version, Unassigned):
logger.info(
"hub_content_version is missing for hub content '%s'. "
"Resolving automatically from SageMakerPublicHub...",
hub_content_name,
)
try:
cached_hub_content = HubContent.get(
hub_content_type="Model",
hub_name="SageMakerPublicHub",
hub_content_name=hub_content_name,
)
base_model.hub_content_version = (
cached_hub_content.hub_content_version
)
logger.info(
"Resolved hub_content_version to '%s' "
"for hub content '%s'.",
cached_hub_content.hub_content_version,
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching ValueError here is quite broad. What ValueError would HubContent.get() raise? If this is for input validation errors from sagemaker-core, that's fine, but consider documenting why ValueError is caught alongside ClientError. Also, if the ClientError is something other than ResourceNotFoundException (e.g., AccessDeniedException), silently swallowing it and returning early could mask real permission issues. Consider only catching specific error codes:

except ClientError as e:
    if e.response['Error']['Code'] == 'ResourceNotFoundException':
        logger.warning(...)
    else:
        raise

hub_content_name,
)
except (ClientError, ValueError) as e:
logger.warning(
"Failed to resolve hub_content_version "
"for hub content '%s': %s",
hub_content_name,
e,
)
self._base_model_fields_resolved = True
return

# Resolve recipe_name if missing
if not recipe_name or isinstance(recipe_name, Unassigned):
logger.info(
"recipe_name is missing for hub content '%s'. "
"Resolving from hub document RecipeCollection...",
hub_content_name,
)
try:
# Reuse cached hub content if available and version matches
if (
cached_hub_content is not None
and cached_hub_content.hub_content_version
== base_model.hub_content_version
):
hub_content = cached_hub_content
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The version comparison cached_hub_content.hub_content_version == base_model.hub_content_version — at this point base_model.hub_content_version was just set to cached_hub_content.hub_content_version a few lines above (line 741), so this condition will always be True when cached_hub_content is not None. The else branch (re-fetching with a specific version) is dead code in the current flow. Consider simplifying by always reusing the cached hub content when available, or adding a comment explaining when the else branch would be hit.

else:
hub_content = HubContent.get(
hub_content_type="Model",
hub_name="SageMakerPublicHub",
hub_content_name=base_model.hub_content_name,
hub_content_version=(
base_model.hub_content_version
),
)
hub_document = json.loads(
hub_content.hub_content_document
)
recipe_collection = hub_document.get(
"RecipeCollection", []
)
if recipe_collection:
resolved_recipe = recipe_collection[0].get(
"Name", ""
)
if resolved_recipe:
base_model.recipe_name = resolved_recipe
logger.info(
"Resolved recipe_name to '%s' "
"for hub content '%s'.",
resolved_recipe,
hub_content_name,
)
else:
logger.warning(
"RecipeCollection found but first "
"recipe has no Name for hub "
"content '%s'.",
hub_content_name,
)
else:
logger.warning(
"No RecipeCollection found in hub "
"document for hub content '%s'. "
"recipe_name could not be "
"auto-resolved.",
hub_content_name,
)
except (ClientError, ValueError) as e:
logger.warning(
"Failed to resolve recipe_name "
"for hub content '%s': %s",
hub_content_name,
e,
)

self._base_model_fields_resolved = True

def _fetch_hub_document_for_custom_model(self) -> dict:
"""Fetch the hub document for a custom (fine-tuned) model.

Calls _resolve_base_model_fields() first to ensure hub_content_version
is populated. If hub_content_version is still Unassigned after
resolution (e.g. resolution failed), HubContent.get() is called
without a version parameter, which returns the latest version.
"""
from sagemaker.core.shapes import BaseModel as CoreBaseModel

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line break in the middle of the attribute chain is unusual and reduces readability:

        base_model: CoreBaseModel = (
            self._fetch_model_package()
            .inference_specification.containers[0].base_model
        )

Consider breaking it more naturally:

        model_package = self._fetch_model_package()
        base_model: CoreBaseModel = (
            model_package.inference_specification.containers[0].base_model
        )

This also avoids the long chained call and is consistent with how it's done in _fetch_and_cache_recipe_config.

self._resolve_base_model_fields()

base_model: CoreBaseModel = (
self._fetch_model_package().inference_specification.containers[0].base_model
self._fetch_model_package()
.inference_specification.containers[0].base_model
)

hub_content_version = getattr(
base_model, "hub_content_version", None
)
hub_content = HubContent.get(
get_kwargs = dict(
hub_content_type="Model",
hub_name="SageMakerPublicHub",
hub_content_name=base_model.hub_content_name,
hub_content_version=base_model.hub_content_version,
)
if hub_content_version and not isinstance(
hub_content_version, Unassigned
):
get_kwargs["hub_content_version"] = hub_content_version

hub_content = HubContent.get(**get_kwargs)
return json.loads(hub_content.hub_content_document)

def _fetch_hosting_configs_for_custom_model(self) -> dict:
Expand Down Expand Up @@ -937,9 +1102,29 @@ def _is_gpu_instance(self, instance_type: str) -> bool:

def _fetch_and_cache_recipe_config(self):
"""Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build."""
# _fetch_hub_document_for_custom_model calls _resolve_base_model_fields
# internally, so no need to call it separately here.
hub_document = self._fetch_hub_document_for_custom_model()
model_package = self._fetch_model_package()
recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name
base_model = (
model_package.inference_specification
.containers[0].base_model
)
hub_content_name = getattr(
base_model, "hub_content_name", "unknown"
)
recipe_name = getattr(base_model, "recipe_name", None)
if not recipe_name or isinstance(recipe_name, Unassigned):
raise ValueError(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message uses multiple f-string concatenations (f"..." f"..."). While this works, it's harder to read. Consider using a single f-string with line continuation or a regular multi-line string:

raise ValueError(
    f"recipe_name is missing from the model package's BaseModel "
    f"(hub_content_name='{hub_content_name}') and could not be "
    f"auto-resolved from the hub document. Please ensure the model "
    f"package has a valid recipe_name set, or manually set it before "
    f"calling build()."
)

Also, the example in the error message (model_package.inference_specification.containers[0].base_model.recipe_name = 'your-recipe-name') references internal SDK objects — users may not have access to the model_package object. Consider suggesting a user-facing API instead.

f"recipe_name is missing from the model package's "
f"BaseModel (hub_content_name='{hub_content_name}') "
f"and could not be auto-resolved from the hub "
f"document. Please ensure the model package has a "
f"valid recipe_name set, or manually set it before "
f"calling build(). Example: model_package."
f"inference_specification.containers[0].base_model"
f".recipe_name = 'your-recipe-name'"
)

if not self.s3_upload_path:
self.s3_upload_path = model_package.inference_specification.containers[
Expand Down Expand Up @@ -1060,7 +1245,11 @@ def _is_nova_model(self) -> bool:
if not base_model:
return False
recipe_name = getattr(base_model, "recipe_name", "") or ""
if isinstance(recipe_name, Unassigned):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good defensive fix for _is_nova_model. However, the pattern recipe_name = getattr(base_model, "recipe_name", "") or "" followed by if isinstance(recipe_name, Unassigned): recipe_name = "" could be simplified. Since Unassigned is truthy (it's an object), the or "" won't catch it. Consider a utility function to normalize Unassigned values:

def _normalize_field(value, default=""):
    if not value or isinstance(value, Unassigned):
        return default
    return value

recipe_name = ""
hub_content_name = getattr(base_model, "hub_content_name", "") or ""
if isinstance(hub_content_name, Unassigned):
hub_content_name = ""
return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower()

def _is_nova_model_for_telemetry(self) -> bool:
Expand All @@ -1076,8 +1265,12 @@ def _get_nova_hosting_config(self, instance_type=None):
Nova training recipes don't have hosting configs in the JumpStart hub document.
This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs().
"""
self._resolve_base_model_fields()
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _resolve_base_model_fields() call was added here in _get_nova_hosting_config, but _is_nova_model() (which gates the call to this method) does NOT call _resolve_base_model_fields(). This means _is_nova_model() could return False for a Nova model if hub_content_name is Unassigned, and _get_nova_hosting_config would never be reached. Should _is_nova_model() also call _resolve_base_model_fields() first?

model_package = self._fetch_model_package()
hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name
hub_content_name = (
model_package.inference_specification
.containers[0].base_model.hub_content_name
)

configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name)
if not configs:
Expand Down
Loading
Loading