|
| 1 | +import json |
| 2 | +import yaml |
| 3 | +import argparse |
| 4 | + |
# Maps the sequence-length shorthand accepted by --seq-lens to its
# (isl, osl) pair; "1k" stands for 1024 and "8k" for 8192.
_K = 1024

seq_len_stoi = {
    "1k1k": (_K, _K),
    "1k8k": (_K, 8 * _K),
    "8k1k": (8 * _K, _K),
}
| 10 | + |
def generate_full_sweep(args, all_config_data):
    """Generate full sweep configurations based on model prefix and sequence lengths.

    Every config whose key starts with ``args.model_prefix`` and which has a
    ``seq-len-configs`` item matching the requested (isl, osl) pair is expanded
    into one entry per concurrency value. Concurrency starts at ``conc-start``
    and is multiplied by ``args.step_size`` until it reaches ``conc-end``; the
    last value is clamped so that ``conc-end`` itself is always included.

    Args:
        args: Parsed CLI namespace. Reads ``seq_lens`` (a key of
            ``seq_len_stoi``), ``model_prefix`` and ``step_size``.
        all_config_data: Mapping of config key to that config's settings dict.

    Returns:
        A list of dicts with the keys ``image``, ``model``, ``precision``,
        ``framework``, ``runner``, ``isl``, ``osl``, ``tp`` and ``conc``, plus
        ``ep`` / ``dp-attn`` when the bmk-space item defines them.

    Raises:
        AssertionError: If a matching config lacks a required field.
        ValueError: If the sweep cannot advance (for example a step size
            below 2 or a non-positive ``conc-start``). Previously this case
            looped forever.
    """
    isl, osl = seq_len_stoi[args.seq_lens]

    matrix_values = []
    for key, val in all_config_data.items():
        # Filter by model prefix
        if not key.startswith(args.model_prefix):
            continue

        seq_len_configs = val.get('seq-len-configs')
        assert seq_len_configs, f"Missing 'seq-len-configs' for key '{key}'"

        image = val.get('image')
        model = val.get('model')
        precision = val.get('precision')
        framework = val.get('framework')
        runner = val.get('runner')

        assert None not in (image, model, precision, framework, runner), \
            f"Missing required fields for key '{key}'"

        # Only the first seq-len config with the requested lengths is used.
        matching_seq_config = next(
            (
                slq for slq in seq_len_configs
                if slq.get('isl') == isl and slq.get('osl') == osl
            ),
            None,
        )
        if not matching_seq_config:
            continue  # Skip this config if no matching sequence length

        bmk_space = matching_seq_config.get('bmk-space')
        assert bmk_space, f"Missing 'bmk-space' in matching seq-len-config for key '{key}'"

        for bmk in bmk_space:
            tp = bmk.get('tp')
            conc_start = bmk.get('conc-start')
            conc_end = bmk.get('conc-end')
            ep = bmk.get('ep')
            dp_attn = bmk.get('dp-attn')

            assert None not in (tp, conc_start, conc_end), \
                f"Missing 'tp', 'conc-start', or 'conc-end' in bmk-space for key '{key}'"

            # Generate entries for each concurrency value in the range
            conc = conc_start
            while conc <= conc_end:
                entry = {
                    'image': image,
                    'model': model,
                    'precision': precision,
                    'framework': framework,
                    'runner': runner,
                    'isl': isl,
                    'osl': osl,
                    'tp': tp,
                    'conc': conc,
                }

                # Add optional fields if they exist
                if ep is not None:
                    entry['ep'] = ep
                if dp_attn is not None:
                    entry['dp-attn'] = dp_attn

                matrix_values.append(entry)

                if conc == conc_end:
                    break

                next_conc = conc * args.step_size
                if next_conc <= conc:
                    # Without this guard the loop would never reach conc_end.
                    raise ValueError(
                        f"Concurrency sweep for key '{key}' cannot advance from "
                        f"{conc} with step size {args.step_size}; 'conc-start' "
                        f"must be positive and the step size at least 2"
                    )
                # Clamp so the final entry lands exactly on conc_end.
                conc = min(next_conc, conc_end)

    return matrix_values
| 86 | + |
def generate_test_config(args, all_config_data):
    """Generate test configurations for a specific key.

    Expands the config stored under ``args.key`` into one entry per
    (seq-len config, bmk-space item, concurrency value). In test mode only
    ``conc-start`` is emitted for each bmk-space item. Otherwise concurrency
    starts at ``conc-start`` and is multiplied by ``args.step_size`` until it
    reaches ``conc-end``, with the last value clamped to ``conc-end``.

    Args:
        args: Parsed CLI namespace. Reads ``key``, ``seq_lens`` (optional
            iterable of ``seq_len_stoi`` keys used as a filter),
            ``step_size`` and ``test_mode``.
        all_config_data: Mapping of config key to that config's settings dict.

    Returns:
        A list of dicts with the keys ``image``, ``model``, ``model-code``,
        ``precision``, ``framework``, ``runner``, ``isl``, ``osl``, ``tp``,
        ``conc`` and ``max-model-len`` (``isl + osl``), plus ``ep`` /
        ``dp-attn`` when the bmk-space item defines them.

    Raises:
        ValueError: If ``args.key`` is not a known config key, or if the
            concurrency sweep cannot advance (for example a step size below 2
            or a non-positive ``conc-start``). The latter previously looped
            forever.
        AssertionError: If the config lacks a required field.
    """
    # Check if the key exists
    if args.key not in all_config_data:
        available_keys = ', '.join(sorted(all_config_data.keys()))
        raise ValueError(
            f"Key '{args.key}' not found in configuration files. "
            f"Available keys: {available_keys}"
        )

    # Extract model code (everything before first hyphen)
    model_code = args.key.split('-')[0]

    val = all_config_data[args.key]

    # Validate required fields
    seq_len_configs = val.get('seq-len-configs')
    assert seq_len_configs, f"Missing 'seq-len-configs' for key '{args.key}'"

    image = val.get('image')
    model = val.get('model')
    precision = val.get('precision')
    framework = val.get('framework')
    runner = val.get('runner')

    assert None not in (image, model, precision, framework, runner), \
        f"Missing required fields (image, model, precision, framework, runner) for key '{args.key}'"

    # Convert seq-lens to set of (isl, osl) tuples for filtering
    seq_lens_filter = None
    if args.seq_lens:
        seq_lens_filter = {seq_len_stoi[sl] for sl in args.seq_lens}

    matrix_values = []

    # Process each sequence length configuration
    for seq_config in seq_len_configs:
        isl = seq_config.get('isl')
        osl = seq_config.get('osl')

        assert None not in (isl, osl), \
            f"Missing 'isl' or 'osl' in seq-len-config for key '{args.key}'"

        # Filter by sequence lengths if specified
        if seq_lens_filter and (isl, osl) not in seq_lens_filter:
            continue

        bmk_space = seq_config.get('bmk-space')
        assert bmk_space, f"Missing 'bmk-space' in seq-len-config for key '{args.key}'"

        for bmk in bmk_space:
            tp = bmk.get('tp')
            conc_start = bmk.get('conc-start')
            conc_end = bmk.get('conc-end')
            ep = bmk.get('ep')
            dp_attn = bmk.get('dp-attn')

            assert None not in (tp, conc_start, conc_end), \
                f"Missing 'tp', 'conc-start', or 'conc-end' in bmk-space for key '{args.key}'"

            if args.test_mode:
                # In test mode, only use the lowest concurrency (conc_start);
                # the step size is not consulted at all.
                conc_values = [conc_start]
            else:
                # Collect every concurrency value in the range
                conc_values = []
                conc = conc_start
                while conc <= conc_end:
                    conc_values.append(conc)
                    if conc == conc_end:
                        break

                    next_conc = conc * args.step_size
                    if next_conc <= conc:
                        # Without this guard the loop would never reach conc_end.
                        raise ValueError(
                            f"Concurrency sweep for key '{args.key}' cannot advance "
                            f"from {conc} with step size {args.step_size}; "
                            f"'conc-start' must be positive and the step size at least 2"
                        )
                    # Clamp so the final entry lands exactly on conc_end.
                    conc = min(next_conc, conc_end)

            for conc in conc_values:
                entry = {
                    'image': image,
                    'model': model,
                    'model-code': model_code,
                    'precision': precision,
                    'framework': framework,
                    'runner': runner,
                    'isl': isl,
                    'osl': osl,
                    'tp': tp,
                    'conc': conc,
                    'max-model-len': isl + osl,
                }

                # Add optional fields if they exist
                if ep is not None:
                    entry['ep'] = ep
                if dp_attn is not None:
                    entry['dp-attn'] = dp_attn

                matrix_values.append(entry)

    return matrix_values
| 203 | + |
def load_config_files(config_files):
    """Load and merge configuration files.

    Each file is parsed with ``yaml.safe_load`` and must contain a mapping at
    the top level. The mappings are merged into a single dict in the order the
    files are given.

    Args:
        config_files: Iterable of paths to YAML files.

    Returns:
        A dict holding the top-level keys of every file.

    Raises:
        ValueError: If a file does not exist (chained to the original
            ``FileNotFoundError``), or if a top-level key appears in more
            than one file.
        AssertionError: If a file's top-level content is not a dictionary.
    """
    all_config_data = {}
    for config_file in config_files:
        # Keep the try body to the open/parse step: that is the only place a
        # FileNotFoundError is expected to come from.
        try:
            # Read as UTF-8 explicitly instead of the platform default encoding.
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except FileNotFoundError as err:
            raise ValueError(f"Input file '{config_file}' does not exist.") from err

        assert isinstance(config_data, dict), f"Config file '{config_file}' must contain a dictionary"

        # Check for duplicate keys
        duplicate_keys = set(all_config_data.keys()) & set(config_data.keys())
        if duplicate_keys:
            raise ValueError(
                f"Duplicate configuration keys found in '{config_file}': {', '.join(sorted(duplicate_keys))}"
            )

        all_config_data.update(config_data)

    return all_config_data
| 225 | + |
def _build_arg_parser():
    """Assemble the CLI parser with its 'full-sweep' and 'test-config' subcommands."""
    seq_len_names = list(seq_len_stoi.keys())
    seq_len_names_text = ', '.join(seq_len_names)

    # Options shared by every subcommand live on a help-less parent parser.
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument(
        '--config-files',
        nargs='+',
        required=True,
        help='One or more configuration files (YAML format)',
    )

    parser = argparse.ArgumentParser(
        description='Generate benchmark configurations from YAML config files',
    )
    commands = parser.add_subparsers(
        dest='command',
        required=True,
        help='Available commands',
    )

    # full-sweep: every config matching a model prefix, for one sequence length.
    sweep = commands.add_parser(
        'full-sweep',
        parents=[shared],
        add_help=False,
        help='Generate full sweep configurations based on model prefix',
    )
    sweep.add_argument(
        '--seq-lens',
        choices=seq_len_names,
        required=True,
        help=f"Sequence length configuration: {seq_len_names_text}",
    )
    sweep.add_argument(
        '--model-prefix',
        required=True,
        help='Model prefix to filter configurations',
    )
    sweep.add_argument(
        '--step-size',
        type=int,
        default=2,
        help='Step size for concurrency values (default: 2)',
    )
    sweep.add_argument(
        '-h', '--help',
        action='help',
        help='Show this help message and exit',
    )

    # test-config: a single config key, optionally filtered by sequence length.
    single = commands.add_parser(
        'test-config',
        parents=[shared],
        add_help=False,
        help='Generate test configurations for a specific key',
    )
    single.add_argument(
        '--key',
        required=True,
        help='Configuration key to use',
    )
    single.add_argument(
        '--seq-lens',
        nargs='+',
        choices=seq_len_names,
        required=False,
        help=f"Sequence length configurations to include: {seq_len_names_text}. If not specified, all sequence lengths are included.",
    )
    single.add_argument(
        '--step-size',
        type=int,
        default=2,
        help='Step size for concurrency values (default: 2)',
    )
    single.add_argument(
        '--test-mode',
        action='store_true',
        help='Generate only the lowest concurrency value for each TP level',
    )
    single.add_argument(
        '-h', '--help',
        action='help',
        help='Show this help message and exit',
    )

    return parser


def main():
    """Parse the command line, build the benchmark matrix and print it as JSON.

    Returns:
        The list of matrix entries that was printed.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    all_config_data = load_config_files(args.config_files)

    # Route to the generator that implements the chosen subcommand.
    generators = {
        'full-sweep': generate_full_sweep,
        'test-config': generate_test_config,
    }
    generate = generators.get(args.command)
    if generate is None:
        # parser.error() prints the message and exits the process.
        parser.error(f"Unknown command: {args.command}")

    matrix_values = generate(args, all_config_data)

    print(json.dumps(matrix_values))
    return matrix_values


if __name__ == "__main__":
    main()
0 commit comments