@@ -100,12 +100,15 @@ def get_compute(compute: Optional[Compute] = None) -> Compute:
100100 volume_size_in_gb = DEFAULT_VOLUME_SIZE ,
101101 )
102102 logger .info (f"Compute not provided. Using default:\n { compute } " )
103- if compute .instance_type is None :
104- compute .instance_type = DEFAULT_INSTANCE_TYPE
105- logger .info (f"Instance type not provided. Using default:\n { DEFAULT_INSTANCE_TYPE } " )
106- if compute .instance_count is None :
107- compute .instance_count = DEFAULT_INSTANCE_COUNT
108- logger .info (f"Instance count not provided. Using default:\n { compute .instance_count } " )
103+ if not compute .instance_groups :
104+ if compute .instance_type is None :
105+ compute .instance_type = DEFAULT_INSTANCE_TYPE
106+ logger .info (f"Instance type not provided. Using default:\n { DEFAULT_INSTANCE_TYPE } " )
107+ if compute .instance_count is None :
108+ compute .instance_count = DEFAULT_INSTANCE_COUNT
109+ logger .info (
110+ f"Instance count not provided. Using default:\n { compute .instance_count } "
111+ )
109112 if compute .volume_size_in_gb is None :
110113 compute .volume_size_in_gb = DEFAULT_VOLUME_SIZE
111114 logger .info (f"Volume size not provided. Using default:\n { compute .volume_size_in_gb } " )
@@ -225,21 +228,28 @@ def get_compute(
225228 ),
226229 )
227230 logger .info (f"Compute not provided. Using default compute:\n { compute } " )
228- if compute .instance_type is None and training_components_model .DefaultTrainingInstanceType :
229- compute .instance_type = training_components_model .DefaultTrainingInstanceType
230- logger .info (
231- f"Instance type not provided. Using default instance type:\n { compute .instance_type } "
232- )
231+ if not compute .instance_groups :
232+ if (
233+ compute .instance_type is None
234+ and training_components_model .DefaultTrainingInstanceType
235+ ):
236+ compute .instance_type = training_components_model .DefaultTrainingInstanceType
237+ logger .info (
238+ f"Instance type not provided. Using default instance type:"
239+ f"\n { compute .instance_type } "
240+ )
241+ if compute .instance_count is None :
242+ compute .instance_count = DEFAULT_INSTANCE_COUNT
243+ logger .info (
244+ f"Instance count not provided. Using default instance count:\n { compute } "
245+ )
233246 if compute .volume_size_in_gb is None :
234247 compute .volume_size_in_gb = (
235248 training_components_model .TrainingVolumeSize or DEFAULT_VOLUME_SIZE
236249 )
237250 logger .info (
238251 f"Volume size not provided. Using default volume size:\n { compute .volume_size_in_gb } "
239252 )
240- if compute .instance_count is None :
241- compute .instance_count = DEFAULT_INSTANCE_COUNT
242- logger .info (f"Instance count not provided. Using default instance count:\n { compute } " )
243253 return compute
244254
245255 def get_networking (
0 commit comments