1+ import json
2+ import yaml
3+ import sys
4+ import argparse
5+
# Short sequence-length labels accepted by --seq-lens, each mapped to the
# (isl, osl) pair it selects when filtering seq-len-configs.
seq_len_stoi = dict(
    zip(
        ("1k1k", "1k8k", "8k1k"),
        ((1024, 1024), (1024, 8192), (8192, 1024)),
    )
)
11+
def _load_config_files(config_files):
    """Load every YAML file and merge their top-level mappings into one dict.

    Args:
        config_files: Paths of the YAML files to read, in order.

    Returns:
        A dict holding the union of all top-level keys.

    Raises:
        ValueError: If a file does not exist, if its top level is not a
            mapping, or if it defines a key an earlier file already defined.
    """
    merged = {}
    for config_file in config_files:
        # Keep the try body down to the only statements that can raise
        # FileNotFoundError, so unrelated errors are not relabelled.
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except FileNotFoundError as err:
            raise ValueError(f"Input file '{config_file}' does not exist.") from err

        # An empty file loads as None, which is rejected here as well.
        if not isinstance(config_data, dict):
            raise ValueError(f"Config file '{config_file}' must contain a dictionary")

        duplicate_keys = merged.keys() & config_data.keys()
        if duplicate_keys:
            raise ValueError(
                f"Duplicate configuration keys found in '{config_file}': "
                f"{', '.join(sorted(map(str, duplicate_keys)))}"
            )

        merged.update(config_data)
    return merged


def _concurrency_values(start, end, step):
    """Yield start, start*step, start*step**2, ... and always finish on end.

    Values that would overshoot ``end`` are clamped to it, so ``end`` is the
    last value whenever ``start <= end``. Nothing is yielded if ``start > end``.
    Callers must ensure ``start >= 1`` and ``step >= 2``; otherwise the
    sequence would never reach ``end``.
    """
    conc = start
    while conc < end:
        yield conc
        conc *= step
    if start <= end:
        yield end


def main(argv=None):
    """Build the benchmark matrix for one configuration key and print it as JSON.

    The selected key must map to a dict with 'image', 'model', 'precision',
    'framework', 'runner' and a non-empty 'seq-len-configs' list. Each
    seq-len-config needs 'isl', 'osl' and a non-empty 'bmk-space' list; each
    bmk-space item needs 'tp', 'conc-start' and 'conc-end', and may carry
    'ep' and 'dp-attn'. One matrix entry is produced per concurrency value
    between 'conc-start' and 'conc-end', multiplying by --step-size each time.

    Args:
        argv: Optional list of command-line arguments. When None (the
            default), argparse reads them from ``sys.argv``.

    Returns:
        The list of matrix entries (dicts) that was printed to stdout.

    Raises:
        ValueError: If a config file is missing or malformed, the key is not
            found, or a required field is missing or invalid.
        SystemExit: Raised by argparse on invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description='Generate benchmark matrix from a specific configuration key'
    )
    parser.add_argument(
        '--config-files',
        nargs='+',
        required=True,
        help='One or more configuration files (YAML format)'
    )
    parser.add_argument(
        '--key',
        required=True,
        help='Configuration key to use'
    )
    parser.add_argument(
        '--seq-lens',
        nargs='+',
        choices=list(seq_len_stoi.keys()),
        required=False,
        help=f"Sequence length configurations to include: {', '.join(seq_len_stoi.keys())}. If not specified, all sequence lengths are included."
    )
    parser.add_argument(
        '--step-size',
        type=int,
        default=2,
        help='Step size for concurrency values (default: 2)'
    )

    args = parser.parse_args(argv)

    # A multiplier below 2 never moves the concurrency towards conc-end,
    # which would make the generation loop spin forever.
    if args.step_size < 2:
        parser.error('--step-size must be an integer >= 2')

    # Convert seq-lens to a set of (isl, osl) tuples for filtering
    seq_lens_filter = None
    if args.seq_lens:
        seq_lens_filter = {seq_len_stoi[sl] for sl in args.seq_lens}

    all_config_data = _load_config_files(args.config_files)

    if args.key not in all_config_data:
        available_keys = ', '.join(sorted(map(str, all_config_data)))
        raise ValueError(
            f"Key '{args.key}' not found in configuration files. "
            f"Available keys: {available_keys}"
        )

    val = all_config_data[args.key]
    if not isinstance(val, dict):
        raise ValueError(f"Configuration for key '{args.key}' must be a dictionary")

    # Explicit raises instead of assert: assertions vanish under `python -O`.
    seq_len_configs = val.get('seq-len-configs')
    if not seq_len_configs:
        raise ValueError(f"Missing 'seq-len-configs' for key '{args.key}'")

    image = val.get('image')
    model = val.get('model')
    precision = val.get('precision')
    framework = val.get('framework')
    runner = val.get('runner')

    if None in (image, model, precision, framework, runner):
        raise ValueError(
            f"Missing required fields (image, model, precision, framework, runner) for key '{args.key}'"
        )

    matrix_values = []

    for seq_config in seq_len_configs:
        isl = seq_config.get('isl')
        osl = seq_config.get('osl')

        if None in (isl, osl):
            raise ValueError(f"Missing 'isl' or 'osl' in seq-len-config for key '{args.key}'")

        # Skip sequence lengths the caller did not ask for
        if seq_lens_filter is not None and (isl, osl) not in seq_lens_filter:
            continue

        bmk_space = seq_config.get('bmk-space')
        if not bmk_space:
            raise ValueError(f"Missing 'bmk-space' in seq-len-config for key '{args.key}'")

        for bmk in bmk_space:
            tp = bmk.get('tp')
            conc_start = bmk.get('conc-start')
            conc_end = bmk.get('conc-end')
            ep = bmk.get('ep')
            dp_attn = bmk.get('dp-attn')

            if None in (tp, conc_start, conc_end):
                raise ValueError(
                    f"Missing 'tp', 'conc-start', or 'conc-end' in bmk-space for key '{args.key}'"
                )

            # A zero or negative start can never grow to conc-end by
            # multiplication, so reject it instead of looping forever.
            if not (isinstance(conc_start, int) and isinstance(conc_end, int)) or conc_start < 1:
                raise ValueError(
                    f"'conc-start' and 'conc-end' must be integers with 'conc-start' >= 1 "
                    f"in bmk-space for key '{args.key}'"
                )

            for conc in _concurrency_values(conc_start, conc_end, args.step_size):
                entry = {
                    'image': image,
                    'model': model,
                    'precision': precision,
                    'framework': framework,
                    'runner': runner,
                    'isl': isl,
                    'osl': osl,
                    'tp': tp,
                    'conc': conc,
                    'max-model-len': isl + osl,
                }

                # Optional fields are emitted only when the config sets them
                if ep is not None:
                    entry['ep'] = ep
                if dp_attn is not None:
                    entry['dp-attn'] = dp_attn

                matrix_values.append(entry)

    print(json.dumps(matrix_values))
    return matrix_values
149+
# Run the matrix generator only when this file is executed as a script;
# main() prints the resulting matrix as JSON to stdout.
if __name__ == "__main__":
    main()
0 commit comments