@@ -209,6 +209,25 @@ def to_dict(self):
209209 }
210210
211211
class PrefixCacheTestTokenizer:
  """A minimal character-level tokenizer for prefix-cache tests.

  Encoding maps every character of a string to its Unicode ordinal, and
  decoding maps a list of ordinals back to the original string. This
  gives a trivial, fully predictable tokenizer for test scenarios —
  in particular prefix-caching tests, where identical leading characters
  must always produce identical leading tokens.
  """

  def encode(self, s: str, **kwargs) -> list[int]:
    """Encode *s* as one token (the Unicode ordinal) per character."""
    del kwargs  # accepted for interface compatibility; unused
    return list(map(ord, s))

  def decode(self, token_ids: list[int], **kwargs) -> str:
    """Decode a list of Unicode ordinals back into a string."""
    del kwargs  # accepted for interface compatibility; unused
    return "".join(map(chr, token_ids))
230+
212231def get_tokenizer (
213232 model_id : str ,
214233 tokenizer_name : str ,
@@ -219,6 +238,9 @@ def get_tokenizer(
219238 if tokenizer_name == "test" :
220239 print ("Using test tokenizer" )
221240 return "test"
241+ elif tokenizer_name == "prefix_cache_test" :
242+ print ("Using prefix_cache_test tokenizer" )
243+ return PrefixCacheTestTokenizer ()
222244 elif use_hf_tokenizer :
223245 # Please accept agreement to access private/gated models in HF, and
224246 # follow up instructions below to set up access token
@@ -329,6 +351,98 @@ def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
329351 return combined_dataset , prompts_per_subject
330352
331353
def load_mock_prefix_cache_test_input_requests(
    prompt_len: int,
    output_len: int,
    common_prefix_len: int,
    num_samples: int,
) -> list[InputRequest]:
  """Generates a mock dataset for testing prefix cache.

  Every prompt is built as <shared prefix> + <unique suffix>. All shared
  prefixes are slices of one master string, so equal-length prefixes are
  identical and a shorter prefix is always a true prefix of a longer one.
  Each sample's prefix length is drawn from a normal distribution whose
  mean is ``common_prefix_len``, then clipped to ``[0, prompt_len]``.
  The tokenizer is assumed to treat each character as one token.

  Args:
    prompt_len: Total character length of each generated prompt.
    output_len: Character length of each generated output.
    common_prefix_len: Target mean for the shared-prefix length of each
      prompt; prefixes are derived from a single master string.
    num_samples: Number of (prompt, output) pairs to generate.

  Returns:
    A list of InputRequest objects.

  Raises:
    ValueError: If ``common_prefix_len`` lies outside ``[0, prompt_len]``,
      or if ``prompt_len``, ``output_len``, or ``num_samples`` is not
      positive.
  """
  if not 0 <= common_prefix_len <= prompt_len:
    raise ValueError(
        "Target mean common_prefix_len must be between 0 and prompt_len,"
        f" inclusive. Got common_prefix_len={common_prefix_len}, "
        f"prompt_len={prompt_len}"
    )
  if any(arg <= 0 for arg in [prompt_len, output_len, num_samples]):
    raise ValueError(
        "prompt_len, output_len, and num_samples cannot be 0 or negative."
    )

  # Single master string; slicing it guarantees that prefixes of the same
  # length match exactly and shorter prefixes nest inside longer ones.
  master_prefix = "".join(
      random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=prompt_len)
  )

  # Per-sample prefix lengths ~ N(common_prefix_len, (prompt_len/3)^2),
  # clipped into the valid range and rounded to whole characters.
  sigma = prompt_len / 3.0
  prefix_lengths = (
      np.clip(
          np.random.normal(
              loc=common_prefix_len, scale=sigma, size=num_samples
          ),
          0,
          prompt_len,
      )
      .round()
      .astype(int)
  )

  requests: list[InputRequest] = []
  for idx, raw_len in enumerate(prefix_lengths):
    # min() is only a safeguard; the clip above already bounds raw_len.
    shared_len = min(int(raw_len), prompt_len)
    shared_part = master_prefix[:shared_len]

    unique_part = "".join(
        random.choices(
            "abcdefghijklmnopqrstuvwxyz0123456789",
            k=prompt_len - shared_len,
        )
    )
    prompt_str = shared_part + unique_part

    output_str = "".join(random.choices("!@#$%^&*()_+", k=output_len))

    requests.append(
        InputRequest(
            prompt=prompt_str,
            prompt_len=len(prompt_str),
            output=output_str,
            output_len=len(output_str),
            sample_idx=idx,
        )
    )
  return requests
444+
445+
332446def gen_mmlu_qa (data : Any , mmlu_method : str = "" ) -> str :
333447
334448 output = ""
@@ -893,6 +1007,7 @@ def parse_args() -> argparse.Namespace:
8931007 "mmlu" ,
8941008 "math500" ,
8951009 "longcontext" ,
1010+ "prefix_cache_test" ,
8961011 ],
8971012 help = "The dataset name." ,
8981013 )
@@ -1086,6 +1201,12 @@ def parse_args() -> argparse.Namespace:
10861201 choices = ["HELM" , "Harness" , "" ],
10871202 help = "mmlu method/format to generate shots" ,
10881203 )
1204+ parser .add_argument (
1205+ "--prefix-cache-test-common-len" ,
1206+ type = int ,
1207+ default = 64 ,
1208+ help = "Common prefix length for the prefix cache test dataset." ,
1209+ )
10891210 return parser .parse_args ()
10901211
10911212
@@ -1112,6 +1233,13 @@ def main(args: argparse.Namespace):
11121233 input_requests = mock_requests (
11131234 args .total_mock_requests
11141235 ) # e.g. [("AB", 2, "AB", 3)]
1236+ elif args .dataset == "prefix_cache_test" :
1237+ input_requests = load_mock_prefix_cache_test_input_requests (
1238+ prompt_len = args .max_input_length ,
1239+ output_len = args .max_output_length ,
1240+ common_prefix_len = args .prefix_cache_test_common_len ,
1241+ num_samples = args .num_prompts ,
1242+ )
11151243 else :
11161244 dataset = []
11171245 if args .dataset == "openorca" :
0 commit comments