|
52 | 52 | ModelCard, |
53 | 53 | ModelPackageModelCard, |
54 | 54 | ) |
55 | | -from sagemaker.core.utils.utils import logger |
| 55 | +from sagemaker.core.utils.utils import logger, Unassigned |
56 | 56 | from sagemaker.core.helper import session_helper |
57 | 57 | from sagemaker.core.helper.session_helper import ( |
58 | 58 | Session, |
@@ -380,6 +380,7 @@ class ModelBuilder(_InferenceRecommenderMixin, _ModelBuilderServers, _ModelBuild |
380 | 380 | _tags: Optional[Tags] = field(default=None, init=False) |
381 | 381 | _optimizing: bool = field(default=False, init=False) |
382 | 382 | _deployment_config: Optional[Dict[str, Any]] = field(default=None, init=False) |
| 383 | + _base_model_fields_resolved: bool = field(default=False, init=False) |
383 | 384 |
|
384 | 385 | shared_libs: List[str] = field( |
385 | 386 | default_factory=list, |
@@ -680,18 +681,198 @@ def _infer_instance_type_from_jumpstart(self) -> str: |
680 | 681 |
|
681 | 682 | raise ValueError(error_msg) |
682 | 683 |
|
@staticmethod
def _normalize_field(value: object, default: str = "") -> str:
    """Return ``value`` unless it is missing, in which case return ``default``.

    A value counts as missing when it is falsy or is an ``Unassigned``
    placeholder. Any other value is handed back unchanged.

    Args:
        value: The raw field value to inspect.
        default: What to return when ``value`` is missing.

    Returns:
        ``default`` for missing values, otherwise ``value`` itself.
    """
    is_missing = not value or isinstance(value, Unassigned)
    return default if is_missing else value
| 690 | + |
def _get_base_model_from_package(self) -> object:
    """Extract the base_model from the model package, or return None.

    Walks model package -> inference_specification -> containers[0] ->
    base_model. Each level is treated as missing when it is falsy *or* an
    ``Unassigned`` placeholder; the explicit isinstance check mirrors the
    other helpers in this class instead of relying on the placeholder's
    truthiness.

    Returns:
        The base_model object if available, or None if the model package
        does not have a base_model (e.g., no inference_specification,
        no containers, or no base_model on the first container).
    """

    def _usable(candidate: object) -> bool:
        # A level is usable only if it is truthy and not a placeholder.
        return bool(candidate) and not isinstance(candidate, Unassigned)

    model_package = self._fetch_model_package()
    if not _usable(model_package):
        return None

    inference_spec = getattr(model_package, "inference_specification", None)
    if not _usable(inference_spec):
        return None

    containers = getattr(inference_spec, "containers", None)
    if not _usable(containers):
        # Guarding here keeps a placeholder from reaching containers[0].
        return None

    base_model = getattr(containers[0], "base_model", None)
    return base_model if _usable(base_model) else None
| 709 | + |
def _resolve_base_model_fields(self) -> None:
    """Auto-resolve missing BaseModel fields (hub_content_version, recipe_name).

    When a ModelPackage's BaseModel has hub_content_name set but is missing
    hub_content_version and/or recipe_name (falsy or ``Unassigned``), this
    method fills them in on the base_model object:
      - hub_content_version: taken from HubContent.get on SageMakerPublicHub,
        called without a version.
      - recipe_name: taken from the "Name" of the first entry in the hub
        document's "RecipeCollection".

    Resolution is best-effort. A ResourceNotFoundException, an unparsable
    hub document, or a document without a usable recipe entry is logged as
    a warning and leaves the field unresolved. Any other ClientError
    (e.g. a permission error) propagates to the caller.

    The outcome is remembered in ``_base_model_fields_resolved`` so later
    calls return immediately. The flag is not set when an exception
    propagates, so a later call will retry.

    Raises:
        ClientError: If HubContent.get fails with any error code other than
            ResourceNotFoundException.
    """
    if self._base_model_fields_resolved:
        return

    base_model = self._get_base_model_from_package()
    hub_content_name = (
        getattr(base_model, "hub_content_name", None) if base_model else None
    )
    if not hub_content_name or isinstance(hub_content_name, Unassigned):
        # Nothing to look up without a hub content name.
        self._base_model_fields_resolved = True
        return

    hub_content_version = getattr(base_model, "hub_content_version", None)
    recipe_name = getattr(base_model, "recipe_name", None)

    # Hub content fetched during version resolution is reused for recipe
    # resolution to avoid a second API call.
    hub_content = None

    # Resolve hub_content_version if missing
    if not hub_content_version or isinstance(hub_content_version, Unassigned):
        logger.info(
            "hub_content_version is missing for hub content '%s'. "
            "Resolving automatically from SageMakerPublicHub...",
            hub_content_name,
        )
        try:
            hub_content = HubContent.get(
                hub_content_type="Model",
                hub_name="SageMakerPublicHub",
                hub_content_name=hub_content_name,
            )
        except ClientError as e:
            # Only swallow ResourceNotFoundException; re-raise permission
            # errors (AccessDeniedException, etc.) so they aren't masked.
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                raise
            logger.warning(
                "Failed to resolve hub_content_version "
                "for hub content '%s': %s",
                hub_content_name,
                e,
            )
            self._base_model_fields_resolved = True
            return
        base_model.hub_content_version = hub_content.hub_content_version
        logger.info(
            "Resolved hub_content_version to '%s' "
            "for hub content '%s'.",
            hub_content.hub_content_version,
            hub_content_name,
        )

    # Resolve recipe_name if missing
    if not recipe_name or isinstance(recipe_name, Unassigned):
        logger.info(
            "recipe_name is missing for hub content '%s'. "
            "Resolving from hub document RecipeCollection...",
            hub_content_name,
        )
        if hub_content is None:
            # Reaching here means the version was already present on the
            # base model, so fetch that exact version.
            try:
                hub_content = HubContent.get(
                    hub_content_type="Model",
                    hub_name="SageMakerPublicHub",
                    hub_content_name=base_model.hub_content_name,
                    hub_content_version=base_model.hub_content_version,
                )
            except ClientError as e:
                if e.response["Error"]["Code"] != "ResourceNotFoundException":
                    raise
                logger.warning(
                    "Failed to resolve recipe_name "
                    "for hub content '%s': %s",
                    hub_content_name,
                    e,
                )

        if hub_content is not None:
            try:
                hub_document = json.loads(hub_content.hub_content_document)
            except (TypeError, ValueError) as e:
                # TypeError: document is not str/bytes (e.g. a placeholder);
                # ValueError covers json.JSONDecodeError.
                logger.warning(
                    "Failed to parse hub document for hub content '%s': %s",
                    hub_content_name,
                    e,
                )
            else:
                recipe_collection = (
                    hub_document.get("RecipeCollection", [])
                    if isinstance(hub_document, dict)
                    else []
                )
                if not isinstance(recipe_collection, list) or not recipe_collection:
                    logger.warning(
                        "No RecipeCollection found in hub "
                        "document for hub content '%s'. "
                        "recipe_name could not be "
                        "auto-resolved.",
                        hub_content_name,
                    )
                else:
                    first_recipe = recipe_collection[0]
                    resolved_recipe = (
                        first_recipe.get("Name", "")
                        if isinstance(first_recipe, dict)
                        else ""
                    )
                    if resolved_recipe:
                        base_model.recipe_name = resolved_recipe
                        logger.info(
                            "Resolved recipe_name to '%s' "
                            "for hub content '%s'.",
                            resolved_recipe,
                            hub_content_name,
                        )
                    else:
                        logger.warning(
                            "RecipeCollection found but first "
                            "recipe has no Name for hub "
                            "content '%s'.",
                            hub_content_name,
                        )

    self._base_model_fields_resolved = True
| 844 | + |
def _fetch_hub_document_for_custom_model(self) -> dict:
    """Return the parsed hub document for the model package's base model.

    ``_resolve_base_model_fields()`` runs first so that a missing
    hub_content_version has a chance to be filled in. The document is then
    requested from SageMakerPublicHub by hub_content_name. The version is
    passed only when it is truthy and not ``Unassigned``; otherwise
    HubContent.get() is called without a version (noted in the original
    docstring as returning the latest version).

    Returns:
        The hub content document decoded from JSON.
    """
    from sagemaker.core.shapes import BaseModel as CoreBaseModel

    self._resolve_base_model_fields()

    first_container = self._fetch_model_package().inference_specification.containers[0]
    base_model: CoreBaseModel = first_container.base_model

    version = getattr(base_model, "hub_content_version", None)
    request = {
        "hub_content_type": "Model",
        "hub_name": "SageMakerPublicHub",
        "hub_content_name": base_model.hub_content_name,
    }
    version_missing = not version or isinstance(version, Unassigned)
    if not version_missing:
        request["hub_content_version"] = version

    return json.loads(HubContent.get(**request).hub_content_document)
696 | 877 |
|
697 | 878 | def _fetch_hosting_configs_for_custom_model(self) -> dict: |
@@ -937,9 +1118,26 @@ def _is_gpu_instance(self, instance_type: str) -> bool: |
937 | 1118 |
|
938 | 1119 | def _fetch_and_cache_recipe_config(self): |
939 | 1120 | """Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build.""" |
| 1121 | + # _fetch_hub_document_for_custom_model calls _resolve_base_model_fields |
| 1122 | + # internally, so no need to call it separately here. |
940 | 1123 | hub_document = self._fetch_hub_document_for_custom_model() |
941 | 1124 | model_package = self._fetch_model_package() |
942 | | - recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name |
| 1125 | + base_model = ( |
| 1126 | + model_package.inference_specification |
| 1127 | + .containers[0].base_model |
| 1128 | + ) |
| 1129 | + hub_content_name = getattr( |
| 1130 | + base_model, "hub_content_name", "unknown" |
| 1131 | + ) |
| 1132 | + recipe_name = getattr(base_model, "recipe_name", None) |
| 1133 | + if not recipe_name or isinstance(recipe_name, Unassigned): |
| 1134 | + raise ValueError( |
| 1135 | + f"recipe_name is missing from the model package's BaseModel " |
| 1136 | + f"(hub_content_name='{hub_content_name}') and could not be " |
| 1137 | + f"auto-resolved from the hub document. Please ensure the model " |
| 1138 | + f"package has a valid recipe_name set, or manually set it before " |
| 1139 | + f"calling build()." |
| 1140 | + ) |
943 | 1141 |
|
944 | 1142 | if not self.s3_upload_path: |
945 | 1143 | self.s3_upload_path = model_package.inference_specification.containers[ |
@@ -1050,27 +1248,39 @@ def _fetch_and_cache_recipe_config(self): |
1050 | 1248 |
|
def _is_nova_model(self) -> bool:
    """Report whether the package's base model looks like a Nova model.

    Base-model fields are resolved first. The model counts as Nova when
    "nova" appears, case-insensitively, in either its recipe_name or its
    hub_content_name. Missing or ``Unassigned`` fields are treated as empty
    strings.

    Returns:
        True if either name contains "nova"; False otherwise, including
        when the model package has no base model.
    """
    self._resolve_base_model_fields()

    base_model = self._get_base_model_from_package()
    if not base_model:
        return False

    names = (
        self._normalize_field(getattr(base_model, "recipe_name", ""), default=""),
        self._normalize_field(getattr(base_model, "hub_content_name", ""), default=""),
    )
    return any("nova" in name.lower() for name in names)
1065 | 1262 |
|
def _is_nova_model_for_telemetry(self) -> bool:
    """Check if the model is a Nova model for telemetry tracking.

    Wraps ``_is_nova_model()`` so that any ``Exception`` it raises is
    swallowed and reported as False. The failure is logged at debug level
    with its traceback so it stays diagnosable.

    Returns:
        The result of ``_is_nova_model()``, or False if that call raised.
    """
    try:
        return self._is_nova_model()
    except Exception:
        # Deliberate best-effort: this check is presumably only used to tag
        # telemetry, so a failure here is recorded but not propagated.
        logger.debug(
            "Nova model detection failed; treating model as non-Nova "
            "for telemetry.",
            exc_info=True,
        )
        return False
| 1269 | + |
1066 | 1270 | def _get_nova_hosting_config(self, instance_type=None): |
1067 | 1271 | """Get Nova hosting config (image URI, env vars, instance type). |
1068 | 1272 |
|
1069 | 1273 | Nova training recipes don't have hosting configs in the JumpStart hub document. |
1070 | 1274 | This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs(). |
| 1275 | +
|
| 1276 | + Note: _resolve_base_model_fields() is already called by _is_nova_model(), |
| 1277 | + which gates all calls to this method. |
1071 | 1278 | """ |
1072 | 1279 | model_package = self._fetch_model_package() |
1073 | | - hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name |
| 1280 | + hub_content_name = ( |
| 1281 | + model_package.inference_specification |
| 1282 | + .containers[0].base_model.hub_content_name |
| 1283 | + ) |
1074 | 1284 |
|
1075 | 1285 | configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name) |
1076 | 1286 | if not configs: |
|
0 commit comments