@@ -518,112 +518,111 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
518518for v in program_map .values ():
519519 random .Random (42 ).shuffle (v )
520520
521+
521522# select prompts that fit the batch size criteria
522- valid_prompts = []
523- if custom_shape :
524- for program_criteria_seq , valid_prompt_shapes in program_map .items ():
525- for valid_prompt_shape in valid_prompt_shapes :
526- if valid_prompt_shape == custom_shape :
527- enforce_sizes = [valid_prompt_shape [1 ]]
528- input_ids , extra_kwargs , sample_key = __prepare_inputs (
529- valid_prompt_shape [0 ],
530- valid_prompt_shape [1 ],
531- tokenizer ,
532- enforce_sizes = enforce_sizes ,
533- )
534- valid_prompts = [
535- (
def get_program_prompt_list():
    """Lazily yield the prompts to run, one tuple per selected prompt shape.

    This is a generator: inputs are prepared only when the consumer asks for
    the next item, so the diagnostics emitted via ``dprint`` are also produced
    lazily, interleaved with the consumer's work.

    The function takes no arguments and reads these names from the enclosing
    scope: ``custom_shape``, ``program_map``, ``programs``,
    ``program_criteria_list``, ``tokenizer``, ``args``, ``local_rank``,
    ``dprint``, ``itertools`` and ``__prepare_inputs``.

    Yields:
        A 5-tuple ``(program, prompt_shape, input_ids, extra_kwargs,
        sample_key)``. In the ``custom_shape`` branch ``program`` is
        ``program_criteria_seq[0].program_id``; otherwise it is the first
        element of the matching ``program_map`` key. ``prompt_shape`` is
        indexed as ``[0]`` for the batch-size check and ``[1]`` for the
        prompt-length check.
    """
    if custom_shape:
        # A custom shape selects at most one prompt: the first entry in
        # program_map that equals it. Nothing else runs in this mode, so the
        # generator can simply finish right after yielding.
        for program_criteria_seq, valid_prompt_shapes in program_map.items():
            for valid_prompt_shape in valid_prompt_shapes:
                if valid_prompt_shape == custom_shape:
                    input_ids, extra_kwargs, sample_key = __prepare_inputs(
                        valid_prompt_shape[0],
                        valid_prompt_shape[1],
                        tokenizer,
                        enforce_sizes=[valid_prompt_shape[1]],
                    )
                    yield (
                        program_criteria_seq[0].program_id,
                        custom_shape,
                        input_ids,
                        extra_kwargs,
                        sample_key,
                    )
                    return
        return

    for program_info in programs:
        program_id = program_info.program_id
        batch_size_limit = program_info.batch_size_limit
        batch_size_limit_type = program_info.batch_size_limit_type
        prompt_length_limit = program_info.prompt_length_limit
        prompt_length_limit_type = program_info.prompt_length_limit_type

        # a numeric program id restricts the search to that one program's keys
        filtered_program_map = program_map
        if program_id.isnumeric():
            wanted_criteria = program_criteria_list[int(program_id)]
            filtered_program_map = {
                k: v for k, v in program_map.items() if k[0] == wanted_criteria
            }

        used_keys = set()
        # for each program, we need to check if we have a shape that satisfies
        # the --programs request
        for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
            # "?" or numeric => a single valid key is enough, stop once found
            if (program_id == "?" or program_id.isnumeric()) and used_keys:
                break
            # "*" => one prompt per program; skip programs already satisfied
            if program_id == "*" and program_seq_key[0] in used_keys:
                continue

            for valid_prompt_shape in valid_prompt_shapes:
                # make sure the criteria for batch limit and prompt limit is
                # satisfied
                # NOTE(review): eval is considered safe here only because the
                # limit type and limit value are restricted before this point
                # (outside this block) -- keep that validation in place.
                batch_check = eval(
                    f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit} "
                )
                prompt_check = eval(
                    f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit} "
                )
                if not (batch_check and prompt_check):
                    continue

                # when we enforce homogeneous prompt programs, we will cycle
                # through all sizes between the min of a program and the valid
                # prompt sequence length
                # if there are not enough sequence sizes in this range, we
                # will cycle back to the beginning
                # in the event we don't have enough sequences that satisfy the
                # enforce_sizes, we will repeat sequences and warn the user
                enforce_sizes = [valid_prompt_shape[1]]
                if args.enforce_homogeneous_prompt_programs:
                    # number of bits of the sequence length, shifted to get
                    # the power of 2 that is less than or equal to the
                    # sequence length
                    tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
                    # favor sequences that are close to the valid prompt length
                    possible_seq_lengths = list(
                        reversed(range(tkv_cutoff, valid_prompt_shape[1], 64))
                    )
                    enforce_sizes = enforce_sizes + list(
                        itertools.islice(
                            itertools.cycle(possible_seq_lengths),
                            valid_prompt_shape[0] - 1,
                        )
                    )

                # Only input preparation is expected to raise ValueError, so
                # keep the guarded region to that single call and do the
                # bookkeeping/yield on the success path.
                try:
                    input_ids, extra_kwargs, sample_key = __prepare_inputs(
                        valid_prompt_shape[0],
                        valid_prompt_shape[1],
                        tokenizer,
                        enforce_sizes=enforce_sizes,
                    )
                except ValueError:
                    dprint(
                        f"No valid sample exists in dataset for this input shape {valid_prompt_shape} "
                    )
                else:
                    used_keys.add(program_seq_key[0])
                    yield (
                        program_seq_key[0],
                        valid_prompt_shape,
                        input_ids,
                        extra_kwargs,
                        sample_key,
                    )
                    # first usable shape for this key wins
                    break

        if not used_keys and local_rank == 0:
            dprint(
                f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type} {batch_size_limit} and prompt_length{prompt_length_limit_type} {prompt_length_limit} "
            )
627626
628627
629628# metric calculator based on the cross-entropy and mean diff for each decode step
@@ -642,7 +641,13 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
642641
643642failed_cases = []
644643# for each program and valid prompt (batch size, sequence length)
645- for program_id , valid_prompt , input_ids , extra_kwargs , sample_key in valid_prompts :
644+ for (
645+ program_id ,
646+ valid_prompt ,
647+ input_ids ,
648+ extra_kwargs ,
649+ sample_key ,
650+ ) in get_program_prompt_list ():
646651 extra_kwargs ["attn_name" ] = ATTN_NAME
647652 if (
648653 "granite-3.3-8b-instruct" in model_variant
0 commit comments