Skip to content

Commit 9595e3d

Browse files
authored
Merge pull request #170 from foundation-model-stack/refactor_get_prompts
Refactor get valid prompts - for memory optimization
2 parents 914dd6c + e308e67 commit 9595e3d

1 file changed

Lines changed: 97 additions & 92 deletions

File tree

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 97 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -518,112 +518,111 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
518518
for v in program_map.values():
519519
random.Random(42).shuffle(v)
520520

521+
521522
# select prompts that fit the batch size criteria
522-
valid_prompts = []
523-
if custom_shape:
524-
for program_criteria_seq, valid_prompt_shapes in program_map.items():
525-
for valid_prompt_shape in valid_prompt_shapes:
526-
if valid_prompt_shape == custom_shape:
527-
enforce_sizes = [valid_prompt_shape[1]]
528-
input_ids, extra_kwargs, sample_key = __prepare_inputs(
529-
valid_prompt_shape[0],
530-
valid_prompt_shape[1],
531-
tokenizer,
532-
enforce_sizes=enforce_sizes,
533-
)
534-
valid_prompts = [
535-
(
523+
def get_program_prompt_list():
524+
if custom_shape:
525+
prompt_found = 0
526+
for program_criteria_seq, valid_prompt_shapes in program_map.items():
527+
for valid_prompt_shape in valid_prompt_shapes:
528+
if valid_prompt_shape == custom_shape:
529+
enforce_sizes = [valid_prompt_shape[1]]
530+
input_ids, extra_kwargs, sample_key = __prepare_inputs(
531+
valid_prompt_shape[0],
532+
valid_prompt_shape[1],
533+
tokenizer,
534+
enforce_sizes=enforce_sizes,
535+
)
536+
prompt_found = 1
537+
yield (
536538
program_criteria_seq[0].program_id,
537539
custom_shape,
538540
input_ids,
539541
extra_kwargs,
540542
sample_key,
541543
)
542-
]
544+
break
545+
if prompt_found:
543546
break
544-
if len(valid_prompts) > 0:
545-
break
546-
else:
547-
for program_info in programs:
548-
program_id = program_info.program_id
549-
batch_size_limit = program_info.batch_size_limit
550-
batch_size_limit_type = program_info.batch_size_limit_type
551-
prompt_length_limit = program_info.prompt_length_limit
552-
prompt_length_limit_type = program_info.prompt_length_limit_type
553-
554-
filtered_program_map = program_map
555-
if program_id.isnumeric():
556-
filtered_program_map = {
557-
k: v
558-
for k, v in program_map.items()
559-
if k[0] == program_criteria_list[int(program_id)]
560-
}
561-
used_keys = set()
562-
# for each program, we need to check if we have a shape that satisfies the --programs request
563-
for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
564-
# if ? or numeric => we need to check if we have found at least one valid key to stop
565-
if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0:
566-
break
567-
# if * => we need to see if we have found the first key to see if we should skip
568-
elif program_id == "*" and program_seq_key[0] in used_keys:
569-
continue
570-
571-
for valid_prompt_shape in valid_prompt_shapes:
572-
# make sure the criteria for the batch limit and the prompt limit are satisfied
573-
# eval is safe here because the allowed limit types and limit values were restricted earlier
574-
575-
batch_check = eval(
576-
f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}"
577-
)
578-
prompt_check = eval(
579-
f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}"
580-
)
581-
if batch_check and prompt_check:
582-
# when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length
583-
# if there are not enough sequence sizes within this range, we will cycle back to the beginning
584-
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
585-
enforce_sizes = [valid_prompt_shape[1]]
586-
if args.enforce_homogeneous_prompt_programs:
587-
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
588-
tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
589-
possible_seq_lengths = [
590-
_ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64)
591-
]
592-
# favor sequences that are close to the valid prompt length
593-
possible_seq_lengths.reverse()
594-
enforce_sizes = enforce_sizes + list(
595-
itertools.islice(
596-
itertools.cycle(possible_seq_lengths),
597-
valid_prompt_shape[0] - 1,
547+
else:
548+
for program_info in programs:
549+
program_id = program_info.program_id
550+
batch_size_limit = program_info.batch_size_limit
551+
batch_size_limit_type = program_info.batch_size_limit_type
552+
prompt_length_limit = program_info.prompt_length_limit
553+
prompt_length_limit_type = program_info.prompt_length_limit_type
554+
555+
filtered_program_map = program_map
556+
if program_id.isnumeric():
557+
filtered_program_map = {
558+
k: v
559+
for k, v in program_map.items()
560+
if k[0] == program_criteria_list[int(program_id)]
561+
}
562+
used_keys = set()
563+
# for each program, we need to check if we have a shape that satisfies the --programs request
564+
for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
565+
# if ? or numeric => we need to check if we have found at least one valid key to stop
566+
if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0:
567+
break
568+
# if * => we need to see if we have found the first key to see if we should skip
569+
elif program_id == "*" and program_seq_key[0] in used_keys:
570+
continue
571+
572+
for valid_prompt_shape in valid_prompt_shapes:
573+
# make sure the criteria for the batch limit and the prompt limit are satisfied
574+
# eval is safe here because the allowed limit types and limit values were restricted earlier
575+
576+
batch_check = eval(
577+
f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}"
578+
)
579+
prompt_check = eval(
580+
f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}"
581+
)
582+
if batch_check and prompt_check:
583+
# when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length
584+
# if there are not enough sequence sizes within this range, we will cycle back to the beginning
585+
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
586+
enforce_sizes = [valid_prompt_shape[1]]
587+
if args.enforce_homogeneous_prompt_programs:
588+
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
589+
tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
590+
possible_seq_lengths = [
591+
_ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64)
592+
]
593+
# favor sequences that are close to the valid prompt length
594+
possible_seq_lengths.reverse()
595+
enforce_sizes = enforce_sizes + list(
596+
itertools.islice(
597+
itertools.cycle(possible_seq_lengths),
598+
valid_prompt_shape[0] - 1,
599+
)
600+
)
601+
try:
602+
input_ids, extra_kwargs, sample_key = __prepare_inputs(
603+
valid_prompt_shape[0],
604+
valid_prompt_shape[1],
605+
tokenizer,
606+
enforce_sizes=enforce_sizes,
598607
)
599-
)
600-
try:
601-
input_ids, extra_kwargs, sample_key = __prepare_inputs(
602-
valid_prompt_shape[0],
603-
valid_prompt_shape[1],
604-
tokenizer,
605-
enforce_sizes=enforce_sizes,
606-
)
607-
valid_prompts.append(
608-
(
608+
used_keys.add(program_seq_key[0])
609+
yield (
609610
program_seq_key[0],
610611
valid_prompt_shape,
611612
input_ids,
612613
extra_kwargs,
613614
sample_key,
614615
)
615-
)
616-
used_keys.add(program_seq_key[0])
617-
break
618-
except ValueError:
619-
dprint(
620-
f"No valid sample exists in dataset for this input shape {valid_prompt_shape}"
621-
)
622-
623-
if len(used_keys) == 0 and local_rank == 0:
624-
dprint(
625-
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
626-
)
616+
break
617+
except ValueError:
618+
dprint(
619+
f"No valid sample exists in dataset for this input shape {valid_prompt_shape}"
620+
)
621+
622+
if len(used_keys) == 0 and local_rank == 0:
623+
dprint(
624+
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
625+
)
627626

628627

629628
# metric calculator based on the cross-entropy and mean diff for each decode step
@@ -642,7 +641,13 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
642641

643642
failed_cases = []
644643
# for each program and valid prompt (batch size, sequence length)
645-
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
644+
for (
645+
program_id,
646+
valid_prompt,
647+
input_ids,
648+
extra_kwargs,
649+
sample_key,
650+
) in get_program_prompt_list():
646651
extra_kwargs["attn_name"] = ATTN_NAME
647652
if (
648653
"granite-3.3-8b-instruct" in model_variant

0 commit comments

Comments
 (0)