@@ -94,9 +94,16 @@ def generate_test_config(args, all_config_data):
9494 f"Available keys: { available_keys } "
9595 )
9696
97- # Extract model code (everything before first hyphen)
97+ # Extract model code from config key
9898 model_code = args .key .split ('-' )[0 ]
99-
99+ # Extract GPU from config key
100+ config_gpu = args .key .split ('-' )[2 ]
101+ runner_gpu = args .runner_node .split ('-' )[0 ] if args .runner_node else None
102+
103+ # If user enters a runner not compatible with input GPU sku, error
104+ if runner_gpu and config_gpu != runner_gpu :
105+ raise ValueError (f"GPU '{ config_gpu } ' used in selected config '{ args .key } ' cannot run on selected runner node '{ args .runner_node } '." )
106+
100107 val = all_config_data [args .key ]
101108
102109 # Validate required fields
@@ -107,7 +114,8 @@ def generate_test_config(args, all_config_data):
107114 model = val .get ('model' )
108115 precision = val .get ('precision' )
109116 framework = val .get ('framework' )
110- runner = val .get ('runner' )
117+ # Use default runner or specific runner node if input by user
118+ runner = val .get ('runner' ) if not args .runner_node else args .runner_node
111119
112120 assert None not in (image , model , precision , framework , runner ), \
113121 f"Missing required fields (image, model, precision, framework, runner) for key '{ args .key } '"
@@ -287,6 +295,11 @@ def main():
287295 required = True ,
288296 help = 'Configuration key to use'
289297 )
298+ test_config_parser .add_argument (
299+ '--runner-node' ,
300+ required = False ,
301+ help = 'Specific runner node to use'
302+ )
290303 test_config_parser .add_argument (
291304 '--seq-lens' ,
292305 nargs = '+' ,
0 commit comments