@@ -452,3 +452,80 @@ def generate(
452452 if timing != "" :
453453 return result , times
454454 return result
455+
456+
457+ VLLM_DT_MAX_BATCH_TKV_LIMIT = 131072
458+
459+
460+ class ProgramCriteria :
461+ def __init__ (
462+ self , program_id , max_batch , max_tkv , batch_granularity , tkv_granularity
463+ ):
464+ self .program_id = program_id
465+ self .max_batch = max_batch
466+ self .max_tkv = max_tkv
467+ self .batch_granularity = batch_granularity
468+ self .tkv_granularity = tkv_granularity
469+
470+ def is_possible (self , batch_size , tkv ):
471+ return batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT
472+
473+ def calculate_padding (self , batch_size , tkv ):
474+ min_batch_req = (
475+ ((batch_size - 1 ) // self .batch_granularity ) + 1
476+ ) * self .batch_granularity
477+ min_tkv_req = (((tkv - 1 ) // self .tkv_granularity ) + 1 ) * self .tkv_granularity
478+ return (min_batch_req * min_tkv_req ) - (batch_size * tkv )
479+
480+ def __str__ (self ):
481+ return f"ProgramCriteria(program_id={ self .program_id } )"
482+
483+ def __eq__ (self , other ):
484+ if not isinstance (other , ProgramCriteria ):
485+ return NotImplemented
486+ return (
487+ self .program_id == other .program_id
488+ and self .max_batch == other .max_batch
489+ and self .max_tkv == other .max_tkv
490+ and self .batch_granularity == other .batch_granularity
491+ and self .tkv_granularity == other .tkv_granularity
492+ )
493+
494+ def __hash__ (self ):
495+ return hash (self .program_id ) # Hash based on immutable attributes
496+
497+
498+ def get_programs_prompts (
499+ program_criteria_list , multiple , max_batch_size , max_tkv , program_cycles
500+ ):
501+ program_map = {}
502+
503+ for batch_size in range (1 , max_batch_size + 1 ):
504+ for prompt_len in range (multiple , max_tkv - program_cycles , multiple ):
505+ possible_program_switches = ((program_cycles - 1 ) // multiple ) + 1
506+ resolved_programs = [None ] * possible_program_switches
507+ for program_criteria in program_criteria_list :
508+ for program_index in range (possible_program_switches ):
509+ context_length = prompt_len + (multiple * program_index ) + 1
510+
511+ if program_criteria .is_possible (batch_size , context_length ):
512+ padding = program_criteria .calculate_padding (
513+ batch_size , context_length
514+ )
515+ if (
516+ resolved_programs [program_index ] is None
517+ or padding < resolved_programs [program_index ][1 ]
518+ ):
519+ resolved_programs [program_index ] = (
520+ program_criteria ,
521+ padding ,
522+ )
523+
524+ if all (p is not None for p in resolved_programs ):
525+ key = tuple (p [0 ] for p in resolved_programs )
526+ if key in program_map :
527+ program_map [key ].append ((batch_size , prompt_len ))
528+ else :
529+ program_map [key ] = [(batch_size , prompt_len )]
530+
531+ return program_map
0 commit comments