@@ -370,6 +370,128 @@ def get_lowest_conc(search_space_entry):
370370 return matrix_values
371371
372372
373+ def generate_test_config_sweep (args , all_config_data ):
374+ """Generate full sweep for specific config keys.
375+
376+ Validates that all specified config keys exist before generating.
377+ Expands all configs fully without any filtering.
378+ """
379+ # Validate all config keys exist
380+ missing_keys = [key for key in args .config_keys if key not in all_config_data ]
381+ if missing_keys :
382+ available_keys = sorted (all_config_data .keys ())
383+ raise ValueError (
384+ f"Config key(s) not found: { ', ' .join (missing_keys )} .\n "
385+ f"Available keys: { ', ' .join (available_keys )} "
386+ )
387+
388+ matrix_values = []
389+
390+ for key in args .config_keys :
391+ val = all_config_data [key ]
392+ is_multinode = val .get (Fields .MULTINODE .value , False )
393+
394+ image = val [Fields .IMAGE .value ]
395+ model = val [Fields .MODEL .value ]
396+ model_code = val [Fields .MODEL_PREFIX .value ]
397+ precision = val [Fields .PRECISION .value ]
398+ framework = val [Fields .FRAMEWORK .value ]
399+ runner = val [Fields .RUNNER .value ]
400+ disagg = val .get (Fields .DISAGG .value , False )
401+
402+ for seq_len_config in val [Fields .SEQ_LEN_CONFIGS .value ]:
403+ isl = seq_len_config [Fields .ISL .value ]
404+ osl = seq_len_config [Fields .OSL .value ]
405+ seq_len_str = seq_len_to_str (isl , osl )
406+
407+ for bmk in seq_len_config [Fields .SEARCH_SPACE .value ]:
408+ if is_multinode :
409+ # Multinode config
410+ spec_decoding = bmk .get (Fields .SPEC_DECODING .value , "none" )
411+ prefill = bmk [Fields .PREFILL .value ]
412+ decode = bmk [Fields .DECODE .value ]
413+
414+ # Get concurrency values
415+ if Fields .CONC_LIST .value in bmk :
416+ conc_values = bmk [Fields .CONC_LIST .value ]
417+ else :
418+ conc_start = bmk [Fields .CONC_START .value ]
419+ conc_end = bmk [Fields .CONC_END .value ]
420+ conc_values = []
421+ conc = conc_start
422+ while conc <= conc_end :
423+ conc_values .append (conc )
424+ if conc == conc_end :
425+ break
426+ conc *= 2
427+ if conc > conc_end :
428+ conc = conc_end
429+
430+ entry = {
431+ Fields .IMAGE .value : image ,
432+ Fields .MODEL .value : model ,
433+ Fields .MODEL_PREFIX .value : model_code ,
434+ Fields .PRECISION .value : precision ,
435+ Fields .FRAMEWORK .value : framework ,
436+ Fields .RUNNER .value : runner ,
437+ Fields .ISL .value : isl ,
438+ Fields .OSL .value : osl ,
439+ Fields .SPEC_DECODING .value : spec_decoding ,
440+ Fields .PREFILL .value : prefill ,
441+ Fields .DECODE .value : decode ,
442+ Fields .CONC .value : conc_values ,
443+ Fields .MAX_MODEL_LEN .value : isl + osl + 200 ,
444+ Fields .EXP_NAME .value : f"{ model_code } _{ seq_len_str } " ,
445+ Fields .DISAGG .value : disagg ,
446+ }
447+ matrix_values .append (validate_matrix_entry (entry , is_multinode = True ))
448+ else :
449+ # Single-node config
450+ tp = bmk [Fields .TP .value ]
451+ ep = bmk .get (Fields .EP .value )
452+ dp_attn = bmk .get (Fields .DP_ATTN .value )
453+ spec_decoding = bmk .get (Fields .SPEC_DECODING .value , "none" )
454+
455+ # Get concurrency values
456+ if Fields .CONC_LIST .value in bmk :
457+ conc_values = bmk [Fields .CONC_LIST .value ]
458+ else :
459+ conc_start = bmk [Fields .CONC_START .value ]
460+ conc_end = bmk [Fields .CONC_END .value ]
461+ conc_values = []
462+ conc = conc_start
463+ while conc <= conc_end :
464+ conc_values .append (conc )
465+ if conc == conc_end :
466+ break
467+ conc *= 2
468+ if conc > conc_end :
469+ conc = conc_end
470+
471+ for conc in conc_values :
472+ entry = {
473+ Fields .IMAGE .value : image ,
474+ Fields .MODEL .value : model ,
475+ Fields .MODEL_PREFIX .value : model_code ,
476+ Fields .PRECISION .value : precision ,
477+ Fields .FRAMEWORK .value : framework ,
478+ Fields .RUNNER .value : runner ,
479+ Fields .ISL .value : isl ,
480+ Fields .OSL .value : osl ,
481+ Fields .TP .value : tp ,
482+ Fields .CONC .value : conc ,
483+ Fields .MAX_MODEL_LEN .value : isl + osl + 200 ,
484+ Fields .EP .value : ep if ep is not None else 1 ,
485+ Fields .DP_ATTN .value : dp_attn if dp_attn is not None else False ,
486+ Fields .SPEC_DECODING .value : spec_decoding ,
487+ Fields .EXP_NAME .value : f"{ model_code } _{ seq_len_str } " ,
488+ Fields .DISAGG .value : disagg ,
489+ }
490+ matrix_values .append (validate_matrix_entry (entry , is_multinode = False ))
491+
492+ return matrix_values
493+
494+
373495def main ():
374496 # Create parent parser with common arguments
375497 parent_parser = argparse .ArgumentParser (add_help = False )
@@ -511,6 +633,25 @@ def main():
511633 help = 'Show this help message and exit'
512634 )
513635
636+ # Subcommand: test-config
637+ test_config_keys_parser = subparsers .add_parser (
638+ 'test-config' ,
639+ parents = [parent_parser ],
640+ add_help = False ,
641+ help = 'Generate full sweep for specific config keys. Validates that all specified keys exist before generating.'
642+ )
643+ test_config_keys_parser .add_argument (
644+ '--config-keys' ,
645+ nargs = '+' ,
646+ required = True ,
647+ help = 'One or more config keys to generate sweep for (e.g., dsr1-fp4-b200-sglang dsr1-fp8-h200-trt)'
648+ )
649+ test_config_keys_parser .add_argument (
650+ '-h' , '--help' ,
651+ action = 'help' ,
652+ help = 'Show this help message and exit'
653+ )
654+
514655 args = parser .parse_args ()
515656
516657 # Load and validate configuration files (validation happens by default in load functions)
@@ -523,6 +664,8 @@ def main():
523664 elif args .command == 'runner-model-sweep' :
524665 matrix_values = generate_runner_model_sweep_config (
525666 args , all_config_data , runner_data )
667+ elif args .command == 'test-config' :
668+ matrix_values = generate_test_config_sweep (args , all_config_data )
526669 else :
527670 parser .error (f"Unknown command: { args .command } " )
528671
0 commit comments