@@ -670,7 +670,7 @@ def test_retrieve_huggingface(config_for_framework):
670670 )
671671 assert (
672672 "564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
673- "1.6-transformers4.2-gpu-py37-cu110-ubuntu18.04" == pt_uri_mv
673+ "1.6-transformers4.2.1 -gpu-py37-cu110-ubuntu18.04" == pt_uri_mv
674674 )
675675
676676 pt_uri = image_uris .retrieve (
@@ -715,7 +715,7 @@ def test_retrieve_huggingface(config_for_framework):
715715 )
716716 assert (
717717 "564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
718- "1.6.0-transformers4.3 .1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
718+ "1.6.0-transformers4.2 .1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
719719 )
720720
721721
@@ -787,6 +787,88 @@ def test_get_latest_version_function_with_no_framework(config_for_framework):
787787 assert "No framework config for framework" in str (e .exception )
788788
789789
790+ def _get_huggingface_alias_test_cases ():
791+ """Build parametrized test cases for every HuggingFace version alias."""
792+ config = image_uris .config_for_framework ("huggingface" )
793+ cases = []
794+ for scope in ("training" , "inference" ):
795+ section = config [scope ]
796+ aliases = section .get ("version_aliases" , {})
797+ for alias , resolved in aliases .items ():
798+ ver_cfg = section ["versions" ][resolved ]
799+ base_fws = [k for k in ver_cfg if k != "version_aliases" ]
800+ base_fw = base_fws [0 ]
801+ py_ver = ver_cfg [base_fw ]["py_versions" ][0 ]
802+ inst = "ml.p3.2xlarge" if scope == "training" else "ml.c5.xlarge"
803+ cases .append (
804+ pytest .param (scope , alias , resolved , base_fw , py_ver , inst ,
805+ id = f"{ scope } -{ alias } ->{ resolved } " )
806+ )
807+ return cases
808+
809+
810+ def _get_huggingface_full_version_test_cases ():
811+ """Build parametrized test cases for every non-aliased HuggingFace version."""
812+ config = image_uris .config_for_framework ("huggingface" )
813+ cases = []
814+ for scope in ("training" , "inference" ):
815+ section = config [scope ]
816+ for full_ver , ver_cfg in section ["versions" ].items ():
817+ base_fws = [k for k in ver_cfg if k != "version_aliases" ]
818+ base_fw = base_fws [0 ]
819+ py_ver = ver_cfg [base_fw ]["py_versions" ][0 ]
820+ inst = "ml.p3.2xlarge" if scope == "training" else "ml.c5.xlarge"
821+ cases .append (
822+ pytest .param (scope , full_ver , base_fw , py_ver , inst ,
823+ id = f"{ scope } -{ full_ver } " )
824+ )
825+ return cases
826+
827+
828+ @pytest .mark .parametrize (
829+ "scope,alias,resolved,base_fw,py_ver,instance_type" ,
830+ _get_huggingface_alias_test_cases (),
831+ )
832+ def test_huggingface_version_alias_resolves_in_tag (
833+ scope , alias , resolved , base_fw , py_ver , instance_type
834+ ):
835+ """Version aliases must be resolved to full versions in image URI tags."""
836+ uri = image_uris .retrieve (
837+ framework = "huggingface" ,
838+ region = "us-east-1" ,
839+ version = alias ,
840+ py_version = py_ver ,
841+ image_scope = scope ,
842+ base_framework_version = base_fw ,
843+ instance_type = instance_type ,
844+ )
845+ assert f"transformers{ resolved } -" in uri , (
846+ f"Expected resolved version 'transformers{ resolved } -' in URI, got: { uri } "
847+ )
848+
849+
850+ @pytest .mark .parametrize (
851+ "scope,full_version,base_fw,py_ver,instance_type" ,
852+ _get_huggingface_full_version_test_cases (),
853+ )
854+ def test_huggingface_full_version_in_tag (
855+ scope , full_version , base_fw , py_ver , instance_type
856+ ):
857+ """Full (non-aliased) versions must appear unchanged in image URI tags."""
858+ uri = image_uris .retrieve (
859+ framework = "huggingface" ,
860+ region = "us-east-1" ,
861+ version = full_version ,
862+ py_version = py_ver ,
863+ image_scope = scope ,
864+ base_framework_version = base_fw ,
865+ instance_type = instance_type ,
866+ )
867+ assert f"transformers{ full_version } -" in uri , (
868+ f"Expected full version 'transformers{ full_version } -' in URI, got: { uri } "
869+ )
870+
871+
790872@pytest .mark .parametrize (
791873 "framework" ,
792874 [
0 commit comments