@@ -99,7 +99,11 @@ def test_interpolates_env(self):
9999
100100class TestValidateGPUVendorAndImage :
101101 def prepare_conf (
102- self , * , image : Optional [str ] = None , gpu_spec : Optional [str ] = None
102+ self ,
103+ * ,
104+ image : Optional [str ] = None ,
105+ gpu_spec : Optional [str ] = None ,
106+ docker : Optional [bool ] = None ,
103107 ) -> BaseRunConfiguration :
104108 conf_dict = {
105109 "type" : "none" ,
@@ -110,6 +114,8 @@ def prepare_conf(
110114 conf_dict ["resources" ] = {
111115 "gpu" : gpu_spec ,
112116 }
117+ if docker is not None :
118+ conf_dict ["docker" ] = docker
113119 return BaseRunConfiguration .parse_obj (conf_dict )
114120
115121 def validate (self , conf : BaseRunConfiguration ) -> None :
@@ -199,6 +205,12 @@ def test_amd_vendor_declared_no_image(self):
199205 ):
200206 self .validate (conf )
201207
208+ @pytest .mark .parametrize ("gpu_spec" , ["AMD" , "MI300X" ])
209+ def test_amd_vendor_docker_true_no_image (self , gpu_spec ):
210+ conf = self .prepare_conf (gpu_spec = gpu_spec , docker = True )
211+ self .validate (conf )
212+ assert conf .resources .gpu .vendor == AcceleratorVendor .AMD
213+
202214 @pytest .mark .parametrize ("gpu_spec" , ["MI300X" , "MI300x" , "mi300x" ])
203215 def test_amd_vendor_inferred_no_image (self , gpu_spec ):
204216 conf = self .prepare_conf (gpu_spec = gpu_spec )
@@ -222,6 +234,12 @@ def test_two_vendors_including_amd_inferred_no_image(self, gpu_spec):
222234 ):
223235 self .validate (conf )
224236
237+ @pytest .mark .parametrize ("gpu_spec" , ["n150" , "n300" ])
238+ def test_tenstorrent_docker_true_no_image (self , gpu_spec ):
239+ conf = self .prepare_conf (gpu_spec = gpu_spec , docker = True )
240+ self .validate (conf )
241+ assert conf .resources .gpu .vendor == AcceleratorVendor .TENSTORRENT
242+
225243
226244class TestValidateCPUArchAndImage :
227245 def prepare_conf (
0 commit comments