Skip to content

Commit 6149d4a

Browse files
committed
test: add comprehensive version alias tests for HuggingFace image URIs
- Fix existing test_retrieve_huggingface assertions to expect resolved versions (4.2.1) instead of raw aliases (4.2) in image tags - Add parametrized test covering all 25 version aliases across training and inference scopes to verify aliases resolve correctly in tags - Add parametrized test covering all 26 non-aliased (full) versions to confirm no regression when passing complete version strings directly Made-with: Cursor
1 parent 4aa6abd commit 6149d4a

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

sagemaker-core/tests/unit/image_uris/test_retrieve.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def test_retrieve_huggingface(config_for_framework):
670670
)
671671
assert (
672672
"564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
673-
"1.6-transformers4.2-gpu-py37-cu110-ubuntu18.04" == pt_uri_mv
673+
"1.6-transformers4.2.1-gpu-py37-cu110-ubuntu18.04" == pt_uri_mv
674674
)
675675

676676
pt_uri = image_uris.retrieve(
@@ -715,7 +715,7 @@ def test_retrieve_huggingface(config_for_framework):
715715
)
716716
assert (
717717
"564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
718-
"1.6.0-transformers4.3.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
718+
"1.6.0-transformers4.2.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
719719
)
720720

721721

@@ -787,6 +787,88 @@ def test_get_latest_version_function_with_no_framework(config_for_framework):
787787
assert "No framework config for framework" in str(e.exception)
788788

789789

790+
def _get_huggingface_alias_test_cases():
791+
"""Build parametrized test cases for every HuggingFace version alias."""
792+
config = image_uris.config_for_framework("huggingface")
793+
cases = []
794+
for scope in ("training", "inference"):
795+
section = config[scope]
796+
aliases = section.get("version_aliases", {})
797+
for alias, resolved in aliases.items():
798+
ver_cfg = section["versions"][resolved]
799+
base_fws = [k for k in ver_cfg if k != "version_aliases"]
800+
base_fw = base_fws[0]
801+
py_ver = ver_cfg[base_fw]["py_versions"][0]
802+
inst = "ml.p3.2xlarge" if scope == "training" else "ml.c5.xlarge"
803+
cases.append(
804+
pytest.param(scope, alias, resolved, base_fw, py_ver, inst,
805+
id=f"{scope}-{alias}->{resolved}")
806+
)
807+
return cases
808+
809+
810+
def _get_huggingface_full_version_test_cases():
811+
"""Build parametrized test cases for every non-aliased HuggingFace version."""
812+
config = image_uris.config_for_framework("huggingface")
813+
cases = []
814+
for scope in ("training", "inference"):
815+
section = config[scope]
816+
for full_ver, ver_cfg in section["versions"].items():
817+
base_fws = [k for k in ver_cfg if k != "version_aliases"]
818+
base_fw = base_fws[0]
819+
py_ver = ver_cfg[base_fw]["py_versions"][0]
820+
inst = "ml.p3.2xlarge" if scope == "training" else "ml.c5.xlarge"
821+
cases.append(
822+
pytest.param(scope, full_ver, base_fw, py_ver, inst,
823+
id=f"{scope}-{full_ver}")
824+
)
825+
return cases
826+
827+
828+
@pytest.mark.parametrize(
829+
"scope,alias,resolved,base_fw,py_ver,instance_type",
830+
_get_huggingface_alias_test_cases(),
831+
)
832+
def test_huggingface_version_alias_resolves_in_tag(
833+
scope, alias, resolved, base_fw, py_ver, instance_type
834+
):
835+
"""Version aliases must be resolved to full versions in image URI tags."""
836+
uri = image_uris.retrieve(
837+
framework="huggingface",
838+
region="us-east-1",
839+
version=alias,
840+
py_version=py_ver,
841+
image_scope=scope,
842+
base_framework_version=base_fw,
843+
instance_type=instance_type,
844+
)
845+
assert f"transformers{resolved}-" in uri, (
846+
f"Expected resolved version 'transformers{resolved}-' in URI, got: {uri}"
847+
)
848+
849+
850+
@pytest.mark.parametrize(
851+
"scope,full_version,base_fw,py_ver,instance_type",
852+
_get_huggingface_full_version_test_cases(),
853+
)
854+
def test_huggingface_full_version_in_tag(
855+
scope, full_version, base_fw, py_ver, instance_type
856+
):
857+
"""Full (non-aliased) versions must appear unchanged in image URI tags."""
858+
uri = image_uris.retrieve(
859+
framework="huggingface",
860+
region="us-east-1",
861+
version=full_version,
862+
py_version=py_ver,
863+
image_scope=scope,
864+
base_framework_version=base_fw,
865+
instance_type=instance_type,
866+
)
867+
assert f"transformers{full_version}-" in uri, (
868+
f"Expected full version 'transformers{full_version}-' in URI, got: {uri}"
869+
)
870+
871+
790872
@pytest.mark.parametrize(
791873
"framework",
792874
[

0 commit comments

Comments
 (0)