@@ -208,6 +208,7 @@ def get_tokenizer(
     model_id: str,
     tokenizer_name: str,
     use_hf_tokenizer: bool,
+    access_token: str | None = None,
 ) -> Any:
     """Return a tokenizer or a tokenizer placeholder."""
     if tokenizer_name == "test":
@@ -218,7 +219,7 @@ def get_tokenizer(
         # follow up instructions below to set up access token
         # https://huggingface.co/docs/transformers.js/en/guides/private
         print(f"Using HuggingFace tokenizer: {tokenizer_name}")
-        return AutoTokenizer.from_pretrained(tokenizer_name)
+        return AutoTokenizer.from_pretrained(tokenizer_name, token=access_token)
     elif model_id == "llama-3":
         # Llama 3 uses a tiktoken tokenizer.
         print(f"Using llama-3 tokenizer: {tokenizer_name}")
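The new `access_token` argument is threaded straight into HuggingFace's `from_pretrained`. A minimal sketch of that path, assuming a gated checkpoint (the model id below is illustrative) and a token created at https://huggingface.co/settings/tokens:

```python
from transformers import AutoTokenizer

access_token = "hf_..."  # placeholder for a real HuggingFace access token
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # hypothetical gated model id
    token=access_token,
)
```

Passing `token=None` (the new default) falls back to anonymous access, so public tokenizers keep working without a token.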
@@ -386,11 +387,9 @@ def load_math500_dataset(dataset_path: str) -> list[tuple[Any, Any]]:
 
 
 def tokenize_dataset(
-    dataset: list[tuple[Any, Any, Any]],
-    tokenizer: Any,
+    dataset: list[tuple[Any, Any, Any]], tokenizer: Any, use_chat_template: bool
 ) -> list[tuple[str, Any, str, int, int, int]]:
     tokenized_dataset = []
-
     for prompt, output, idx in dataset:
         if isinstance(output, tuple):
             output_len = len(tokenizer.encode(output[0]))
@@ -399,7 +398,12 @@ def tokenize_dataset(
             output_len = len(tokenizer.encode(output))
             output_tokens = output
 
-        prompt_tokens = tokenizer.encode(prompt)
+        if use_chat_template:
+            prompt_tokens = tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}], add_generation_prompt=True
+            )
+        else:
+            prompt_tokens = tokenizer.encode(prompt)
 
         tokenized_data = (
             prompt,
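The difference between the two branches is easiest to see side by side. A small sketch, assuming a tokenizer from an instruction-tuned checkpoint that defines a chat template:

```python
prompt = "What is 2 + 2?"

# Plain path: token ids of the raw prompt text.
plain_ids = tokenizer.encode(prompt)

# Chat path: the prompt is first wrapped in the model's chat markup.
# apply_chat_template tokenizes by default (tokenize=True), so this already
# returns a list of token ids, ending with the generation-prompt tokens.
chat_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], add_generation_prompt=True
)

assert len(chat_ids) > len(plain_ids)  # chat markup adds extra tokens
```

Prompt lengths therefore differ between the two modes, which affects the length-based filtering in `filter_dataset` downstream.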
@@ -470,6 +474,7 @@ def filter_dataset(
 def sample_requests(
     dataset: list[tuple[Any, Any]],
     tokenizer: Any,
+    use_chat_template: bool,
     num_requests: int,
     dataset_type: str,
     max_output_length: int = 0,
@@ -508,7 +513,9 @@ def sample_requests(
         sampled_data = dataset[i] + (dataset_indices[i],)
         sampled_dataset.append(sampled_data)
 
-    tokenized_dataset = tokenize_dataset(sampled_dataset, tokenizer)
+    tokenized_dataset = tokenize_dataset(
+        sampled_dataset, tokenizer, use_chat_template
+    )
 
     input_requests = filter_dataset(
         tokenized_dataset,
@@ -636,6 +643,7 @@ async def grpc_async_request(
 async def send_request(
     api_url: str,
     tokenizer: Any,
+    use_chat_template: bool,
     input_request: InputRequest,
     prefill_quota: AsyncCounter,
     active_req_quota: AsyncCounter,
@@ -645,7 +653,13 @@ async def send_request(
 ) -> RequestFuncOutput:
     """Send the request to JetStream server."""
     # Tokenize on client side following MLPerf standard.
-    token_ids = tokenizer.encode(input_request.prompt)
+    if use_chat_template:
+        token_ids = tokenizer.apply_chat_template(
+            [{"role": "user", "content": input_request.prompt}],
+            add_generation_prompt=True,
+        )
+    else:
+        token_ids = tokenizer.encode(input_request.prompt)
 
     # Send the request
     request = jetstream_pb2.DecodeRequest(
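When debugging this client-side path, it can help to look at the templated prompt as text rather than ids; `tokenize=False` does that. A sketch, under the same chat-template assumption as above:

```python
rendered = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
    tokenize=False,  # return the formatted string instead of token ids
)
print(rendered)  # the chat-formatted text whose ids are sent to the server
```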
@@ -691,6 +705,7 @@ async def send_request(
 async def benchmark(
     api_url: str,
     tokenizer: Any,
+    use_chat_template: bool,
     input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
@@ -734,6 +749,7 @@ async def benchmark(
                 send_request(
                     api_url=api_url,
                     tokenizer=tokenizer,
+                    use_chat_template=use_chat_template,
                     input_request=request,
                     prefill_quota=prefill_quota,
                     active_req_quota=active_req_quota,
@@ -892,6 +908,23 @@ def parse_args() -> argparse.Namespace:
             " to True, and provide name of the tokenizer in the tokenizer flag."
         ),
     )
+    parser.add_argument(
+        "--hf-access-token",
+        type=str,
+        default="",
+        help=(
+            "Access token used to load a tokenizer from an API (e.g. HuggingFace)."
+        ),
+    )
+    parser.add_argument(
+        "--use-chat-template",
+        type=str2bool,
+        default=False,
+        help=(
+            "Whether the tokenizer should apply a chat template "
+            "(used for instruction-tuned models)."
+        ),
+    )
     parser.add_argument(
         "--num-prompts",
         type=int,
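`--use-chat-template` is parsed with the script's existing `str2bool` helper, which this diff does not show. A typical implementation of that argparse pattern, offered as an assumption rather than the verbatim helper:

```python
import argparse

def str2bool(v: str) -> bool:
    # Accept common spellings so both --use-chat-template=true and
    # --use-chat-template yes work on the command line.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")
```

An invocation then adds, e.g., `--use-chat-template true --hf-access-token <token>` alongside the existing flags.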
@@ -1051,13 +1084,16 @@ def main(args: argparse.Namespace):
     model_id = args.model
     tokenizer_id = args.tokenizer
     use_hf_tokenizer = args.use_hf_tokenizer
+    hf_access_token = args.hf_access_token
+    use_chat_template = args.use_chat_template
 
     prefill_quota = AsyncCounter(init_value=3)
     active_req_quota = AsyncCounter(init_value=450)
 
     api_url = f"{args.server}:{args.port}"
-
-    tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+    tokenizer = get_tokenizer(
+        model_id, tokenizer_id, use_hf_tokenizer, hf_access_token
+    )
     if tokenizer == "test" or args.dataset == "test":
         input_requests = mock_requests(
             args.total_mock_requests
@@ -1094,6 +1130,7 @@ def main(args: argparse.Namespace):
         input_requests = sample_requests(
             dataset=dataset,
             tokenizer=tokenizer,
+            use_chat_template=use_chat_template,
             num_requests=args.num_prompts,
             dataset_type=args.dataset,
             max_output_length=args.max_output_length,
@@ -1116,6 +1153,7 @@ def main(args: argparse.Namespace):
         benchmark(
             api_url=api_url,
             tokenizer=tokenizer,
+            use_chat_template=use_chat_template,
             input_requests=warmup_requests,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
@@ -1134,6 +1172,7 @@ def main(args: argparse.Namespace):
         benchmark(
             api_url=api_url,
             tokenizer=tokenizer,
+            use_chat_template=use_chat_template,
             input_requests=input_requests,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,