@@ -414,6 +414,7 @@ def __post_init__(self) -> None:
414414 if self .log_level is not None :
415415 logger .setLevel (self .log_level )
416416
417+ self ._base_model_fields_resolved : bool = False
417418 self ._warn_about_deprecated_parameters (warnings )
418419 self ._initialize_compute_config ()
419420 self ._initialize_network_config ()
@@ -688,8 +689,11 @@ def _resolve_base_model_fields(self):
688689 DescribeModelPackage API), this method resolves them automatically:
689690 - hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub
690691 - recipe_name: resolved from the first recipe in the hub document's RecipeCollection
692+
693+ Note: HubContent.get() supports being called without hub_content_version,
694+ in which case it returns the latest version of the hub content.
691695 """
692- if hasattr ( self , "_base_model_fields_resolved" ) and self ._base_model_fields_resolved :
696+ if self ._base_model_fields_resolved :
693697 return
694698
695699 model_package = self ._fetch_model_package ()
@@ -720,6 +724,9 @@ def _resolve_base_model_fields(self):
720724 hub_content_version = getattr (base_model , "hub_content_version" , None )
721725 recipe_name = getattr (base_model , "recipe_name" , None )
722726
727+ # Cache the HubContent response to avoid redundant API calls
728+ cached_hub_content = None
729+
723730 # Resolve hub_content_version if missing
724731 if not hub_content_version or isinstance (hub_content_version , Unassigned ):
725732 logger .info (
@@ -728,20 +735,24 @@ def _resolve_base_model_fields(self):
728735 hub_content_name ,
729736 )
730737 try :
731- hc = HubContent .get (
738+ cached_hub_content = HubContent .get (
732739 hub_content_type = "Model" ,
733740 hub_name = "SageMakerPublicHub" ,
734741 hub_content_name = hub_content_name ,
735742 )
736- base_model .hub_content_version = hc .hub_content_version
743+ base_model .hub_content_version = (
744+ cached_hub_content .hub_content_version
745+ )
737746 logger .info (
738- "Resolved hub_content_version to '%s' for hub content '%s'." ,
739- hc .hub_content_version ,
747+ "Resolved hub_content_version to '%s' "
748+ "for hub content '%s'." ,
749+ cached_hub_content .hub_content_version ,
740750 hub_content_name ,
741751 )
742- except Exception as e :
752+ except ( ClientError , ValueError ) as e :
743753 logger .warning (
744- "Failed to resolve hub_content_version for hub content '%s': %s" ,
754+ "Failed to resolve hub_content_version "
755+ "for hub content '%s': %s" ,
745756 hub_content_name ,
746757 e ,
747758 )
@@ -752,63 +763,97 @@ def _resolve_base_model_fields(self):
752763 if not recipe_name or isinstance (recipe_name , Unassigned ):
753764 logger .info (
754765 "recipe_name is missing for hub content '%s'. "
755- "Resolving automatically from hub document RecipeCollection..." ,
766+ "Resolving from hub document RecipeCollection..." ,
756767 hub_content_name ,
757768 )
758769 try :
759- hub_content = HubContent .get (
760- hub_content_type = "Model" ,
761- hub_name = "SageMakerPublicHub" ,
762- hub_content_name = base_model .hub_content_name ,
763- hub_content_version = base_model .hub_content_version ,
770+ # Reuse cached hub content if available and version matches
771+ if (
772+ cached_hub_content is not None
773+ and cached_hub_content .hub_content_version
774+ == base_model .hub_content_version
775+ ):
776+ hub_content = cached_hub_content
777+ else :
778+ hub_content = HubContent .get (
779+ hub_content_type = "Model" ,
780+ hub_name = "SageMakerPublicHub" ,
781+ hub_content_name = base_model .hub_content_name ,
782+ hub_content_version = (
783+ base_model .hub_content_version
784+ ),
785+ )
786+ hub_document = json .loads (
787+ hub_content .hub_content_document
788+ )
789+ recipe_collection = hub_document .get (
790+ "RecipeCollection" , []
764791 )
765- hub_document = json .loads (hub_content .hub_content_document )
766- recipe_collection = hub_document .get ("RecipeCollection" , [])
767792 if recipe_collection :
768- resolved_recipe = recipe_collection [0 ].get ("Name" , "" )
793+ resolved_recipe = recipe_collection [0 ].get (
794+ "Name" , ""
795+ )
769796 if resolved_recipe :
770797 base_model .recipe_name = resolved_recipe
771798 logger .info (
772- "Resolved recipe_name to '%s' for hub content '%s'." ,
799+ "Resolved recipe_name to '%s' "
800+ "for hub content '%s'." ,
773801 resolved_recipe ,
774802 hub_content_name ,
775803 )
776804 else :
777805 logger .warning (
778- "RecipeCollection found but first recipe has no Name for hub content '%s'." ,
806+ "RecipeCollection found but first "
807+ "recipe has no Name for hub "
808+ "content '%s'." ,
779809 hub_content_name ,
780810 )
781811 else :
782812 logger .warning (
783- "No RecipeCollection found in hub document for hub content '%s'. "
784- "recipe_name could not be auto-resolved." ,
813+ "No RecipeCollection found in hub "
814+ "document for hub content '%s'. "
815+ "recipe_name could not be "
816+ "auto-resolved." ,
785817 hub_content_name ,
786818 )
787- except Exception as e :
819+ except ( ClientError , ValueError ) as e :
788820 logger .warning (
789- "Failed to resolve recipe_name for hub content '%s': %s" ,
821+ "Failed to resolve recipe_name "
822+ "for hub content '%s': %s" ,
790823 hub_content_name ,
791824 e ,
792825 )
793826
794827 self ._base_model_fields_resolved = True
795828
796829 def _fetch_hub_document_for_custom_model (self ) -> dict :
830+ """Fetch the hub document for a custom (fine-tuned) model.
831+
832+ Calls _resolve_base_model_fields() first to ensure hub_content_version
833+ is populated. If hub_content_version is still Unassigned after
834+ resolution (e.g. resolution failed), HubContent.get() is called
835+ without a version parameter, which returns the latest version.
836+ """
797837 from sagemaker .core .shapes import BaseModel as CoreBaseModel
798838
799839 self ._resolve_base_model_fields ()
800840
801841 base_model : CoreBaseModel = (
802- self ._fetch_model_package ().inference_specification .containers [0 ].base_model
842+ self ._fetch_model_package ()
843+ .inference_specification .containers [0 ].base_model
803844 )
804845
805- hub_content_version = getattr (base_model , "hub_content_version" , None )
846+ hub_content_version = getattr (
847+ base_model , "hub_content_version" , None
848+ )
806849 get_kwargs = dict (
807850 hub_content_type = "Model" ,
808851 hub_name = "SageMakerPublicHub" ,
809852 hub_content_name = base_model .hub_content_name ,
810853 )
811- if hub_content_version and not isinstance (hub_content_version , Unassigned ):
854+ if hub_content_version and not isinstance (
855+ hub_content_version , Unassigned
856+ ):
812857 get_kwargs ["hub_content_version" ] = hub_content_version
813858
814859 hub_content = HubContent .get (** get_kwargs )
@@ -1057,19 +1102,28 @@ def _is_gpu_instance(self, instance_type: str) -> bool:
10571102
10581103 def _fetch_and_cache_recipe_config (self ):
10591104 """Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build."""
1060- self ._resolve_base_model_fields ()
1105+ # _fetch_hub_document_for_custom_model calls _resolve_base_model_fields
1106+ # internally, so no need to call it separately here.
10611107 hub_document = self ._fetch_hub_document_for_custom_model ()
10621108 model_package = self ._fetch_model_package ()
1063- recipe_name = getattr (
1064- model_package .inference_specification .containers [0 ].base_model , "recipe_name" , None
1109+ base_model = (
1110+ model_package .inference_specification
1111+ .containers [0 ].base_model
1112+ )
1113+ hub_content_name = getattr (
1114+ base_model , "hub_content_name" , "unknown"
10651115 )
1116+ recipe_name = getattr (base_model , "recipe_name" , None )
10661117 if not recipe_name or isinstance (recipe_name , Unassigned ):
10671118 raise ValueError (
1068- "recipe_name is missing from the model package's BaseModel and could not be "
1069- "auto-resolved from the hub document. Please ensure the model package has a "
1070- "valid recipe_name set, or manually set it before calling build(). "
1071- "Example: model_package.inference_specification.containers[0].base_model"
1072- ".recipe_name = 'your-recipe-name'"
1119+ f"recipe_name is missing from the model package's "
1120+ f"BaseModel (hub_content_name='{ hub_content_name } ') "
1121+ f"and could not be auto-resolved from the hub "
1122+ f"document. Please ensure the model package has a "
1123+ f"valid recipe_name set, or manually set it before "
1124+ f"calling build(). Example: model_package."
1125+ f"inference_specification.containers[0].base_model"
1126+ f".recipe_name = 'your-recipe-name'"
10731127 )
10741128
10751129 if not self .s3_upload_path :
@@ -1213,7 +1267,10 @@ def _get_nova_hosting_config(self, instance_type=None):
12131267 """
12141268 self ._resolve_base_model_fields ()
12151269 model_package = self ._fetch_model_package ()
1216- hub_content_name = model_package .inference_specification .containers [0 ].base_model .hub_content_name
1270+ hub_content_name = (
1271+ model_package .inference_specification
1272+ .containers [0 ].base_model .hub_content_name
1273+ )
12171274
12181275 configs = self ._NOVA_HOSTING_CONFIGS .get (hub_content_name )
12191276 if not configs :
0 commit comments