Skip to content

Commit 45a15f7

Browse files
authored
Merge pull request #109 from foundation-model-stack/paged_test
Add a script which drives particular programs during decode when using paged model
2 parents e42c33f + eeee927 commit 45a15f7

2 files changed

Lines changed: 476 additions & 0 deletions

File tree

aiu_fms_testing_utils/utils/paged.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,80 @@ def generate(
452452
if timing != "":
453453
return result, times
454454
return result
455+
456+
457+
VLLM_DT_MAX_BATCH_TKV_LIMIT = 131072
458+
459+
460+
class ProgramCriteria:
461+
def __init__(
462+
self, program_id, max_batch, max_tkv, batch_granularity, tkv_granularity
463+
):
464+
self.program_id = program_id
465+
self.max_batch = max_batch
466+
self.max_tkv = max_tkv
467+
self.batch_granularity = batch_granularity
468+
self.tkv_granularity = tkv_granularity
469+
470+
def is_possible(self, batch_size, tkv):
471+
return batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT
472+
473+
def calculate_padding(self, batch_size, tkv):
474+
min_batch_req = (
475+
((batch_size - 1) // self.batch_granularity) + 1
476+
) * self.batch_granularity
477+
min_tkv_req = (((tkv - 1) // self.tkv_granularity) + 1) * self.tkv_granularity
478+
return (min_batch_req * min_tkv_req) - (batch_size * tkv)
479+
480+
def __str__(self):
481+
return f"ProgramCriteria(program_id={self.program_id})"
482+
483+
def __eq__(self, other):
484+
if not isinstance(other, ProgramCriteria):
485+
return NotImplemented
486+
return (
487+
self.program_id == other.program_id
488+
and self.max_batch == other.max_batch
489+
and self.max_tkv == other.max_tkv
490+
and self.batch_granularity == other.batch_granularity
491+
and self.tkv_granularity == other.tkv_granularity
492+
)
493+
494+
def __hash__(self):
495+
return hash(self.program_id) # Hash based on immutable attributes
496+
497+
498+
def get_programs_prompts(
499+
program_criteria_list, multiple, max_batch_size, max_tkv, program_cycles
500+
):
501+
program_map = {}
502+
503+
for batch_size in range(1, max_batch_size + 1):
504+
for prompt_len in range(multiple, max_tkv - program_cycles, multiple):
505+
possible_program_switches = ((program_cycles - 1) // multiple) + 1
506+
resolved_programs = [None] * possible_program_switches
507+
for program_criteria in program_criteria_list:
508+
for program_index in range(possible_program_switches):
509+
context_length = prompt_len + (multiple * program_index) + 1
510+
511+
if program_criteria.is_possible(batch_size, context_length):
512+
padding = program_criteria.calculate_padding(
513+
batch_size, context_length
514+
)
515+
if (
516+
resolved_programs[program_index] is None
517+
or padding < resolved_programs[program_index][1]
518+
):
519+
resolved_programs[program_index] = (
520+
program_criteria,
521+
padding,
522+
)
523+
524+
if all(p is not None for p in resolved_programs):
525+
key = tuple(p[0] for p in resolved_programs)
526+
if key in program_map:
527+
program_map[key].append((batch_size, prompt_len))
528+
else:
529+
program_map[key] = [(batch_size, prompt_len)]
530+
531+
return program_map

0 commit comments

Comments
 (0)