diff --git a/sagemaker-core/src/sagemaker/core/model_registry.py b/sagemaker-core/src/sagemaker/core/model_registry.py index 15f1300059..5576580ea3 100644 --- a/sagemaker-core/src/sagemaker/core/model_registry.py +++ b/sagemaker-core/src/sagemaker/core/model_registry.py @@ -12,6 +12,7 @@ MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH, MODEL_PACKAGE_VALIDATION_PROFILES_PATH, ) +from sagemaker.core.resources import ModelPackageModelCard from botocore.exceptions import ClientError import logging @@ -100,12 +101,12 @@ def get_model_package_args( if model_life_cycle is not None: model_package_args["model_life_cycle"] = model_life_cycle._to_request_dict() if model_card is not None: - original_req = model_card._create_request_args() - if original_req.get("ModelCardName") is not None: - del original_req["ModelCardName"] - if original_req.get("Content") is not None: - original_req["ModelCardContent"] = original_req["Content"] - del original_req["Content"] + original_req = {} + if isinstance(model_card, ModelPackageModelCard): + original_req["ModelCardContent"] = model_card.model_card_content + else: + original_req["ModelCardContent"] = model_card.content + original_req["ModelCardStatus"] = model_card.model_card_status model_package_args["model_card"] = original_req return model_package_args diff --git a/sagemaker-core/tests/unit/test_model_registry.py b/sagemaker-core/tests/unit/test_model_registry.py index 06e7a9ad1b..eedf667613 100644 --- a/sagemaker-core/tests/unit/test_model_registry.py +++ b/sagemaker-core/tests/unit/test_model_registry.py @@ -126,6 +126,38 @@ def test_get_model_package_args_with_all_params(self): assert args["skip_model_validation"] == "All" assert args["source_uri"] == "s3://bucket/source" + def test_get_model_package_args_model_card(self): + from sagemaker.core.shapes import ModelCard + + model_card = ModelCard() + model_card.content = '{"model_details": {"name": "test"}}' + model_card.model_card_status = "Approved" + + args = get_model_package_args( + model_card=model_card, + ) + + assert args["model_card"] == { + "ModelCardContent": '{"model_details": {"name": "test"}}', + "ModelCardStatus": "Approved", + } + + def test_get_model_package_args_model_package_model_card(self): + from sagemaker.core.shapes import ModelPackageModelCard + + model_card = ModelPackageModelCard() + model_card.model_card_content = '{"model_details": {"name": "test"}}' + model_card.model_card_status = "Approved" + + args = get_model_package_args( + model_card=model_card, + ) + + assert args["model_card"] == { + "ModelCardContent": '{"model_details": {"name": "test"}}', + "ModelCardStatus": "Approved", + } + def test_get_create_model_package_request_minimal(self): """Test get_create_model_package_request with minimal parameters""" request = get_create_model_package_request(