Skip to content

Commit 0226fc5

Browse files
committed
adding more workflows
1 parent 6fec99e commit 0226fc5

1 file changed

Lines changed: 331 additions & 0 deletions

File tree

Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
import json
2+
import yaml
3+
import argparse
4+
5+
seq_len_stoi = {
6+
"1k1k": (1024, 1024),
7+
"1k8k": (1024, 8192),
8+
"8k1k": (8192, 1024)
9+
}
10+
11+
def generate_full_sweep(args, all_config_data):
12+
"""Generate full sweep configurations based on model prefix and sequence lengths."""
13+
isl, osl = seq_len_stoi[args.seq_lens]
14+
15+
matrix_values = []
16+
for key, val in all_config_data.items():
17+
# Filter by model prefix
18+
if not key.startswith(args.model_prefix):
19+
continue
20+
21+
seq_len_configs = val.get('seq-len-configs')
22+
assert seq_len_configs, f"Missing 'seq-len-configs' for key '{key}'"
23+
24+
image = val.get('image')
25+
model = val.get('model')
26+
precision = val.get('precision')
27+
framework = val.get('framework')
28+
runner = val.get('runner')
29+
30+
assert None not in (image, model, precision, framework, runner), \
31+
f"Missing required fields for key '{key}'"
32+
33+
# Check if this config has matching sequence lengths
34+
matching_seq_config = None
35+
for slq in seq_len_configs:
36+
if slq.get('isl') == isl and slq.get('osl') == osl:
37+
matching_seq_config = slq
38+
break
39+
40+
if not matching_seq_config:
41+
continue # Skip this config if no matching sequence length
42+
43+
bmk_space = matching_seq_config.get('bmk-space')
44+
assert bmk_space, f"Missing 'bmk-space' in matching seq-len-config for key '{key}'"
45+
46+
for bmk in bmk_space:
47+
tp = bmk.get('tp')
48+
conc_start = bmk.get('conc-start')
49+
conc_end = bmk.get('conc-end')
50+
ep = bmk.get('ep')
51+
dp_attn = bmk.get('dp-attn')
52+
53+
assert None not in (tp, conc_start, conc_end), \
54+
f"Missing 'tp', 'conc-start', or 'conc-end' in bmk-space for key '{key}'"
55+
56+
# Generate entries for each concurrency value in the range
57+
conc = conc_start
58+
while conc <= conc_end:
59+
entry = {
60+
'image': image,
61+
'model': model,
62+
'precision': precision,
63+
'framework': framework,
64+
'runner': runner,
65+
'isl': isl,
66+
'osl': osl,
67+
'tp': tp,
68+
'conc': conc
69+
}
70+
71+
# Add optional fields if they exist
72+
if ep is not None:
73+
entry['ep'] = ep
74+
if dp_attn is not None:
75+
entry['dp-attn'] = dp_attn
76+
77+
matrix_values.append(entry)
78+
79+
if conc == conc_end:
80+
break
81+
conc *= args.step_size
82+
if conc > conc_end:
83+
conc = conc_end
84+
85+
return matrix_values
86+
87+
def generate_test_config(args, all_config_data):
88+
"""Generate test configurations for a specific key."""
89+
# Check if the key exists
90+
if args.key not in all_config_data:
91+
available_keys = ', '.join(sorted(all_config_data.keys()))
92+
raise ValueError(
93+
f"Key '{args.key}' not found in configuration files. "
94+
f"Available keys: {available_keys}"
95+
)
96+
97+
# Extract model code (everything before first hyphen)
98+
model_code = args.key.split('-')[0]
99+
100+
val = all_config_data[args.key]
101+
102+
# Validate required fields
103+
seq_len_configs = val.get('seq-len-configs')
104+
assert seq_len_configs, f"Missing 'seq-len-configs' for key '{args.key}'"
105+
106+
image = val.get('image')
107+
model = val.get('model')
108+
precision = val.get('precision')
109+
framework = val.get('framework')
110+
runner = val.get('runner')
111+
112+
assert None not in (image, model, precision, framework, runner), \
113+
f"Missing required fields (image, model, precision, framework, runner) for key '{args.key}'"
114+
115+
# Convert seq-lens to set of (isl, osl) tuples for filtering
116+
seq_lens_filter = None
117+
if args.seq_lens:
118+
seq_lens_filter = {seq_len_stoi[sl] for sl in args.seq_lens}
119+
120+
matrix_values = []
121+
122+
# Process each sequence length configuration
123+
for seq_config in seq_len_configs:
124+
isl = seq_config.get('isl')
125+
osl = seq_config.get('osl')
126+
127+
assert None not in (isl, osl), \
128+
f"Missing 'isl' or 'osl' in seq-len-config for key '{args.key}'"
129+
130+
# Filter by sequence lengths if specified
131+
if seq_lens_filter and (isl, osl) not in seq_lens_filter:
132+
continue
133+
134+
bmk_space = seq_config.get('bmk-space')
135+
assert bmk_space, f"Missing 'bmk-space' in seq-len-config for key '{args.key}'"
136+
137+
for bmk in bmk_space:
138+
tp = bmk.get('tp')
139+
conc_start = bmk.get('conc-start')
140+
conc_end = bmk.get('conc-end')
141+
ep = bmk.get('ep')
142+
dp_attn = bmk.get('dp-attn')
143+
144+
assert None not in (tp, conc_start, conc_end), \
145+
f"Missing 'tp', 'conc-start', or 'conc-end' in bmk-space for key '{args.key}'"
146+
147+
# In test mode, only use the lowest concurrency (conc_start)
148+
if args.test_mode:
149+
entry = {
150+
'image': image,
151+
'model': model,
152+
'model-code': model_code,
153+
'precision': precision,
154+
'framework': framework,
155+
'runner': runner,
156+
'isl': isl,
157+
'osl': osl,
158+
'tp': tp,
159+
'conc': conc_start,
160+
'max-model-len': isl + osl,
161+
}
162+
163+
# Add optional fields if they exist
164+
if ep is not None:
165+
entry['ep'] = ep
166+
if dp_attn is not None:
167+
entry['dp-attn'] = dp_attn
168+
169+
matrix_values.append(entry)
170+
else:
171+
# Generate entries for each concurrency value in the range
172+
conc = conc_start
173+
while conc <= conc_end:
174+
entry = {
175+
'image': image,
176+
'model': model,
177+
'model-code': model_code,
178+
'precision': precision,
179+
'framework': framework,
180+
'runner': runner,
181+
'isl': isl,
182+
'osl': osl,
183+
'tp': tp,
184+
'conc': conc,
185+
'max-model-len': isl + osl,
186+
}
187+
188+
# Add optional fields if they exist
189+
if ep is not None:
190+
entry['ep'] = ep
191+
if dp_attn is not None:
192+
entry['dp-attn'] = dp_attn
193+
194+
matrix_values.append(entry)
195+
196+
if conc == conc_end:
197+
break
198+
conc *= args.step_size
199+
if conc > conc_end:
200+
conc = conc_end
201+
202+
return matrix_values
203+
204+
def load_config_files(config_files):
205+
"""Load and merge configuration files."""
206+
all_config_data = {}
207+
for config_file in config_files:
208+
try:
209+
with open(config_file, 'r') as f:
210+
config_data = yaml.safe_load(f)
211+
assert isinstance(config_data, dict), f"Config file '{config_file}' must contain a dictionary"
212+
213+
# Check for duplicate keys
214+
duplicate_keys = set(all_config_data.keys()) & set(config_data.keys())
215+
if duplicate_keys:
216+
raise ValueError(
217+
f"Duplicate configuration keys found in '{config_file}': {', '.join(sorted(duplicate_keys))}"
218+
)
219+
220+
all_config_data.update(config_data)
221+
except FileNotFoundError:
222+
raise ValueError(f"Input file '{config_file}' does not exist.")
223+
224+
return all_config_data
225+
226+
def main():
227+
# Create parent parser with common arguments
228+
parent_parser = argparse.ArgumentParser(add_help=False)
229+
parent_parser.add_argument(
230+
'--config-files',
231+
nargs='+',
232+
required=True,
233+
help='One or more configuration files (YAML format)'
234+
)
235+
236+
# Create main parser
237+
parser = argparse.ArgumentParser(
238+
description='Generate benchmark configurations from YAML config files'
239+
)
240+
241+
# Create subparsers for subcommands
242+
subparsers = parser.add_subparsers(
243+
dest='command',
244+
required=True,
245+
help='Available commands'
246+
)
247+
248+
# Subcommand: full-sweep
249+
full_sweep_parser = subparsers.add_parser(
250+
'full-sweep',
251+
parents=[parent_parser],
252+
add_help=False,
253+
help='Generate full sweep configurations based on model prefix'
254+
)
255+
full_sweep_parser.add_argument(
256+
'--seq-lens',
257+
choices=list(seq_len_stoi.keys()),
258+
required=True,
259+
help=f"Sequence length configuration: {', '.join(seq_len_stoi.keys())}"
260+
)
261+
full_sweep_parser.add_argument(
262+
'--model-prefix',
263+
required=True,
264+
help='Model prefix to filter configurations'
265+
)
266+
full_sweep_parser.add_argument(
267+
'--step-size',
268+
type=int,
269+
default=2,
270+
help='Step size for concurrency values (default: 2)'
271+
)
272+
full_sweep_parser.add_argument(
273+
'-h', '--help',
274+
action='help',
275+
help='Show this help message and exit'
276+
)
277+
278+
# Subcommand: test-config
279+
test_config_parser = subparsers.add_parser(
280+
'test-config',
281+
parents=[parent_parser],
282+
add_help=False,
283+
help='Generate test configurations for a specific key'
284+
)
285+
test_config_parser.add_argument(
286+
'--key',
287+
required=True,
288+
help='Configuration key to use'
289+
)
290+
test_config_parser.add_argument(
291+
'--seq-lens',
292+
nargs='+',
293+
choices=list(seq_len_stoi.keys()),
294+
required=False,
295+
help=f"Sequence length configurations to include: {', '.join(seq_len_stoi.keys())}. If not specified, all sequence lengths are included."
296+
)
297+
test_config_parser.add_argument(
298+
'--step-size',
299+
type=int,
300+
default=2,
301+
help='Step size for concurrency values (default: 2)'
302+
)
303+
test_config_parser.add_argument(
304+
'--test-mode',
305+
action='store_true',
306+
help='Generate only the lowest concurrency value for each TP level'
307+
)
308+
test_config_parser.add_argument(
309+
'-h', '--help',
310+
action='help',
311+
help='Show this help message and exit'
312+
)
313+
314+
args = parser.parse_args()
315+
316+
# Load configuration files
317+
all_config_data = load_config_files(args.config_files)
318+
319+
# Route to appropriate function based on subcommand
320+
if args.command == 'full-sweep':
321+
matrix_values = generate_full_sweep(args, all_config_data)
322+
elif args.command == 'test-config':
323+
matrix_values = generate_test_config(args, all_config_data)
324+
else:
325+
parser.error(f"Unknown command: {args.command}")
326+
327+
print(json.dumps(matrix_values))
328+
return matrix_values
329+
330+
if __name__ == "__main__":
331+
main()

0 commit comments

Comments
 (0)