Skip to content

Commit d3dbc28

Browse files
Fix model registration with a model card (#5611)
* Add docker-compose path * Check for MacOS * Fix model registration with a model card * Account for both ModelCard and ModelPackageModelCard objects * Add unit tests for model card during model registration
1 parent 0906526 commit d3dbc28

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

sagemaker-core/src/sagemaker/core/model_registry.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
1313
MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
1414
)
15+
from sagemaker.core.resources import ModelPackageModelCard
1516
from botocore.exceptions import ClientError
1617
import logging
1718

@@ -100,12 +101,12 @@ def get_model_package_args(
100101
if model_life_cycle is not None:
101102
model_package_args["model_life_cycle"] = model_life_cycle._to_request_dict()
102103
if model_card is not None:
103-
original_req = model_card._create_request_args()
104-
if original_req.get("ModelCardName") is not None:
105-
del original_req["ModelCardName"]
106-
if original_req.get("Content") is not None:
107-
original_req["ModelCardContent"] = original_req["Content"]
108-
del original_req["Content"]
104+
original_req = {}
105+
if isinstance(model_card, ModelPackageModelCard):
106+
original_req["ModelCardContent"] = model_card.model_card_content
107+
else:
108+
original_req["ModelCardContent"] = model_card.content
109+
original_req["ModelCardStatus"] = model_card.model_card_status
109110
model_package_args["model_card"] = original_req
110111
return model_package_args
111112

sagemaker-core/tests/unit/test_model_registry.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,38 @@ def test_get_model_package_args_with_all_params(self):
126126
assert args["skip_model_validation"] == "All"
127127
assert args["source_uri"] == "s3://bucket/source"
128128

129+
def test_get_model_package_args_model_card(self):
130+
from sagemaker.core.shapes import ModelCard
131+
132+
model_card = ModelCard()
133+
model_card.content = '{"model_details": {"name": "test"}}'
134+
model_card.model_card_status = "Approved"
135+
136+
args = get_model_package_args(
137+
model_card=model_card,
138+
)
139+
140+
assert args["model_card"] == {
141+
"ModelCardContent": '{"model_details": {"name": "test"}}',
142+
"ModelCardStatus": "Approved",
143+
}
144+
145+
def test_get_model_package_args_model_package_model_card(self):
146+
from sagemaker.core.shapes import ModelPackageModelCard
147+
148+
model_card = ModelPackageModelCard()
149+
model_card.model_card_content = '{"model_details": {"name": "test"}}'
150+
model_card.model_card_status = "Approved"
151+
152+
args = get_model_package_args(
153+
model_card=model_card,
154+
)
155+
156+
assert args["model_card"] == {
157+
"ModelCardContent": '{"model_details": {"name": "test"}}',
158+
"ModelCardStatus": "Approved",
159+
}
160+
129161
def test_get_create_model_package_request_minimal(self):
130162
"""Test get_create_model_package_request with minimal parameters"""
131163
request = get_create_model_package_request(

0 commit comments

Comments
 (0)