from fms.models import get_model
from fms.utils.generation import pad_input_ids
import itertools
+ import warnings
+ import re
import torch
from torch import distributed as dist
from aiu_fms_testing_utils.testing.validation import (
# the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
# this ensures we select the smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
__largest_context = max(common_seq_lengths) + max(common_max_new_tokens)
- __supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192]
+ __supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
    __supported_context_lengths[
        bisect.bisect_left(__supported_context_lengths, __largest_context)
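(A minimal standalone sketch of the selection logic above; the numbers are made up and only the standard-library bisect module is assumed.)

import bisect

supported = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
largest_context = 3000  # e.g. max prompt size + max new tokens
# bisect_left returns the first index whose value is >= largest_context,
# i.e. the smallest supported context length that still fits it
print(supported[bisect.bisect_left(supported, largest_context)])  # 4096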
@@ -301,14 +303,28 @@ def __maybe_get_gptq_kwargs(model_path):


def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
-     prompts_and_sizes = sample_sharegpt_requests(
-         SHARE_GPT_DATASET_PATH,
-         batch_size,
-         tokenizer,
-         seq_length // 2,
-         seq_length,
-         seed,
-     )
+     if "paged" in ATTN_NAME:
+         prompts_and_sizes = sample_sharegpt_requests(
+             SHARE_GPT_DATASET_PATH,
+             batch_size,
+             tokenizer,
+             32,
+             seq_length,
+             seed,
+             enforce_heterogeneous=True,
+             enforce_sizes=[seq_length],  # ensure at least the max seq length is sampled
+             pad_multiple=64,
+         )
+     else:
+         prompts_and_sizes = sample_sharegpt_requests(
+             SHARE_GPT_DATASET_PATH,
+             batch_size,
+             tokenizer,
+             seq_length // 2,
+             seq_length,
+             seed,
+         )
+
    prompt_list = []
    for prompt, _ in prompts_and_sizes:
        prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))
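(For intuition on pad_multiple=64 in the paged branch: a hedged sketch of the usual round-up-to-block-size arithmetic, an assumption about sample_sharegpt_requests's padding semantics rather than a quote of its implementation.)

def pad_to_multiple(size: int, multiple: int = 64) -> int:
    # round a sampled prompt length up to the next block boundary
    return ((size + multiple - 1) // multiple) * multiple

assert pad_to_multiple(32) == 64
assert pad_to_multiple(65) == 128
assert pad_to_multiple(128) == 128  # already-aligned sizes are unchanged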
@@ -341,25 +357,44 @@ def __filter_before_eos(metrics, filter_indexes):


def __get_validation_info_full_path(
-     model_path, batch_size, seq_length, max_new_tokens, seed, device_type="cpu"
+     model_path,
+     batch_size,
+     seq_length,
+     max_new_tokens,
+     seed,
+     attn_type: str,
+     device_type="cpu",
):
-     validation_file_name = f"{get_default_validation_prefix(model_path, max_new_tokens, batch_size, seq_length, 'fp16')}.{device_type}_validation_info.{seed}.out"
+     validation_file_name = f"{get_default_validation_prefix(model_path, max_new_tokens, batch_size, seq_length, 'fp16', attn_type)}.{device_type}_validation_info.{seed}.out"
    full_path = os.path.join(validation_info_dir, validation_file_name)
    return full_path


def __load_validation_info(
-     model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed
+     model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str
):
+     # if the file does not exist and the attention type is not paged, retry without
+     # attn_type in the name, and warn that such legacy paths will be dropped in the future
    full_path = __get_validation_info_full_path(
-         model_path, batch_size, seq_length, max_new_tokens, seed
+         model_path, batch_size, seq_length, max_new_tokens, seed, attn_type
    )

    if os.path.exists(full_path):
        dprint(f"cpu validation info found for seed={seed} -- loading it")
        return load_validation_information(full_path, "logits", batch_size, tokenizer)
-     else:
-         return None
+     elif "paged" not in attn_type:
+         # this regex applies to a very specific file name format
+         modified_full_path = re.sub(r"_attn-type[^.]*", "", full_path)
+
+         if os.path.exists(modified_full_path):
+             warnings.warn(
+                 f"all future validation info paths should contain attn_type in the file name; please migrate {modified_full_path=} to {full_path=}",
+                 stacklevel=2,
+             )
+             dprint(f"cpu validation info found for seed={seed} -- loading it")
+             return load_validation_information(
+                 modified_full_path, "logits", batch_size, tokenizer
+             )
+     return None
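(To see what the legacy fallback strips, assuming get_default_validation_prefix embeds an "_attn-type-..." token as the regex implies; the file name below is hypothetical.)

import re

new_style = "granite-3.0-8b_fp16_attn-type-sdpa.cpu_validation_info.0.out"
# the regex drops everything from "_attn-type" up to the next "."
legacy = re.sub(r"_attn-type[^.]*", "", new_style)
print(legacy)  # granite-3.0-8b_fp16.cpu_validation_info.0.out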


class PersistentModel:
@@ -513,7 +548,7 @@ def test_common_shapes(

    # generate cpu validation info
    cpu_validation_info = __load_validation_info(
-         model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0
+         model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0, ATTN_NAME
    )
    if cpu_validation_info is None:
        cpu_validation_info = extract_validation_information(
@@ -529,7 +564,7 @@ def test_common_shapes(
        if save_validation_info_outputs:
            cpu_validation_info.save(
                __get_validation_info_full_path(
-                     model_path, batch_size, seq_length, max_new_tokens, 0
+                     model_path, batch_size, seq_length, max_new_tokens, 0, ATTN_NAME
                )
            )
        cpu_static_tokens = cpu_validation_info.get_info("tokens")
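(The hunks above and below follow a load-if-cached, else compute-and-persist pattern; a generic sketch of that idiom with made-up helper names, not the harness's actual API.)

import os

def load_or_compute(path, compute, load, save_outputs=True):
    # reuse persisted validation info when present, otherwise regenerate it
    if os.path.exists(path):
        return load(path)
    result = compute()
    if save_outputs:
        result.save(path)  # persist so later runs can skip the CPU pass
    return result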
@@ -591,7 +626,13 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
        )
        extra_kwargs["attn_name"] = ATTN_NAME
        cpu_validation_info = __load_validation_info(
-             model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
+             model_path,
+             batch_size,
+             seq_length,
+             max_new_tokens,
+             tokenizer,
+             i,
+             ATTN_NAME,
        )
        if cpu_validation_info is None:
            cpu_validation_info = extract_validation_information(
@@ -609,7 +650,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
            if save_validation_info_outputs:
                cpu_validation_info.save(
                    __get_validation_info_full_path(
-                         model_path, batch_size, seq_length, max_new_tokens, i
+                         model_path,
+                         batch_size,
+                         seq_length,
+                         max_new_tokens,
+                         i,
+                         ATTN_NAME,
                    )
                )
            cpu_static_tokens = cpu_validation_info.get_info("tokens")
@@ -634,7 +680,13 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
            if save_validation_info_outputs:
                aiu_validation_info.save(
                    __get_validation_info_full_path(
-                         model_path, batch_size, seq_length, max_new_tokens, i, "aiu"
+                         model_path,
+                         batch_size,
+                         seq_length,
+                         max_new_tokens,
+                         i,
+                         ATTN_NAME,
+                         "aiu",
                    )
                )
