@@ -276,7 +276,11 @@ def create_instance(
276276 image_id , username = self ._get_image_id_and_username (
277277 ec2_client = ec2_client ,
278278 region = instance_offer .region ,
279- cuda = len (instance_offer .instance .resources .gpus ) > 0 ,
279+ gpu_name = (
280+ instance_offer .instance .resources .gpus [0 ].name
281+ if len (instance_offer .instance .resources .gpus ) > 0
282+ else None
283+ ),
280284 instance_type = instance_offer .instance .name ,
281285 image_config = self .config .os_images ,
282286 )
@@ -882,11 +886,13 @@ def _get_image_id_and_username_cache_key(
882886 self ,
883887 ec2_client : botocore .client .BaseClient ,
884888 region : str ,
885- cuda : bool ,
889+ gpu_name : Optional [ str ] ,
886890 instance_type : str ,
887891 image_config : Optional [AWSOSImageConfig ] = None ,
888892 ) -> tuple :
889- return hashkey (region , cuda , instance_type , image_config .json () if image_config else None )
893+ return hashkey (
894+ region , gpu_name , instance_type , image_config .json () if image_config else None
895+ )
890896
891897 @cachedmethod (
892898 cache = lambda self : self ._get_image_id_and_username_cache ,
@@ -897,13 +903,13 @@ def _get_image_id_and_username(
897903 self ,
898904 ec2_client : botocore .client .BaseClient ,
899905 region : str ,
900- cuda : bool ,
906+ gpu_name : Optional [ str ] ,
901907 instance_type : str ,
902908 image_config : Optional [AWSOSImageConfig ] = None ,
903909 ) -> tuple [str , str ]:
904910 return aws_resources .get_image_id_and_username (
905911 ec2_client = ec2_client ,
906- cuda = cuda ,
912+ gpu_name = gpu_name ,
907913 instance_type = instance_type ,
908914 image_config = image_config ,
909915 )
0 commit comments