Skip to content

Commit 9d19631

Browse files
authored
Added new HuggingFaceTokenizer to token_utils and updated TokenizerParameters to include tokenizer_type and access_token as additional metadata to store. (#229)
1 parent b8ad727 commit 9d19631

19 files changed

Lines changed: 50353 additions & 50 deletions

benchmarks/benchmark_serving.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def get_tokenizer(
208208
model_id: str,
209209
tokenizer_name: str,
210210
use_hf_tokenizer: bool,
211+
access_token: str | None = None,
211212
) -> Any:
212213
"""Return a tokenizer or a tokenizer placholder."""
213214
if tokenizer_name == "test":
@@ -218,7 +219,7 @@ def get_tokenizer(
218219
# follow up instructions below to set up access token
219220
# https://huggingface.co/docs/transformers.js/en/guides/private
220221
print(f"Using HuggingFace tokenizer: {tokenizer_name}")
221-
return AutoTokenizer.from_pretrained(tokenizer_name)
222+
return AutoTokenizer.from_pretrained(tokenizer_name, token=access_token)
222223
elif model_id == "llama-3":
223224
# Llama 3 uses a tiktoken tokenizer.
224225
print(f"Using llama-3 tokenizer: {tokenizer_name}")
@@ -386,11 +387,9 @@ def load_math500_dataset(dataset_path: str) -> list[tuple[Any, Any]]:
386387

387388

388389
def tokenize_dataset(
389-
dataset: list[tuple[Any, Any, Any]],
390-
tokenizer: Any,
390+
dataset: list[tuple[Any, Any, Any]], tokenizer: Any, use_chat_template: bool
391391
) -> list[tuple[str, Any, str, int, int, int]]:
392392
tokenized_dataset = []
393-
394393
for prompt, output, idx in dataset:
395394
if isinstance(output, tuple):
396395
output_len = len(tokenizer.encode(output[0]))
@@ -399,7 +398,12 @@ def tokenize_dataset(
399398
output_len = len(tokenizer.encode(output))
400399
output_tokens = output
401400

402-
prompt_tokens = tokenizer.encode(prompt)
401+
if use_chat_template:
402+
prompt_tokens = tokenizer.apply_chat_template(
403+
[{"role": "user", "content": prompt}], add_generation_prompt=True
404+
)
405+
else:
406+
prompt_tokens = tokenizer.encode(prompt)
403407

404408
tokenized_data = (
405409
prompt,
@@ -470,6 +474,7 @@ def filter_dataset(
470474
def sample_requests(
471475
dataset: list[tuple[Any, Any]],
472476
tokenizer: Any,
477+
use_chat_template: bool,
473478
num_requests: int,
474479
dataset_type: str,
475480
max_output_length: int = 0,
@@ -508,7 +513,9 @@ def sample_requests(
508513
sampled_data = dataset[i] + (dataset_indices[i],)
509514
sampled_dataset.append(sampled_data)
510515

511-
tokenized_dataset = tokenize_dataset(sampled_dataset, tokenizer)
516+
tokenized_dataset = tokenize_dataset(
517+
sampled_dataset, tokenizer, use_chat_template
518+
)
512519

513520
input_requests = filter_dataset(
514521
tokenized_dataset,
@@ -636,6 +643,7 @@ async def grpc_async_request(
636643
async def send_request(
637644
api_url: str,
638645
tokenizer: Any,
646+
use_chat_template: bool,
639647
input_request: InputRequest,
640648
prefill_quota: AsyncCounter,
641649
active_req_quota: AsyncCounter,
@@ -645,7 +653,13 @@ async def send_request(
645653
) -> RequestFuncOutput:
646654
"""Send the request to JetStream server."""
647655
# Tokenize on client side following MLPerf standard.
648-
token_ids = tokenizer.encode(input_request.prompt)
656+
if use_chat_template:
657+
token_ids = tokenizer.apply_chat_template(
658+
[{"role": "user", "content": input_request.prompt}],
659+
add_generation_prompt=True,
660+
)
661+
else:
662+
token_ids = tokenizer.encode(input_request.prompt)
649663

650664
# Send the request
651665
request = jetstream_pb2.DecodeRequest(
@@ -691,6 +705,7 @@ async def send_request(
691705
async def benchmark(
692706
api_url: str,
693707
tokenizer: Any,
708+
use_chat_template: bool,
694709
input_requests: list[InputRequest],
695710
request_rate: float,
696711
disable_tqdm: bool,
@@ -734,6 +749,7 @@ async def benchmark(
734749
send_request(
735750
api_url=api_url,
736751
tokenizer=tokenizer,
752+
use_chat_template=use_chat_template,
737753
input_request=request,
738754
prefill_quota=prefill_quota,
739755
active_req_quota=active_req_quota,
@@ -892,6 +908,23 @@ def parse_args() -> argparse.Namespace:
892908
" to True, and provide name of the tokenizer in the tokenizer flag."
893909
),
894910
)
911+
parser.add_argument(
912+
"--hf-access-token",
913+
type=str,
914+
default="",
915+
help=(
916+
"Access token used to load a tokenizer from an API (i.e. HuggingFace)"
917+
),
918+
)
919+
parser.add_argument(
920+
"--use-chat-template",
921+
type=str2bool,
922+
default=False,
923+
help=(
924+
"Whether the tokenizer should be applying a chat template "
925+
"(used for instruction-tuned models)."
926+
),
927+
)
895928
parser.add_argument(
896929
"--num-prompts",
897930
type=int,
@@ -1051,13 +1084,16 @@ def main(args: argparse.Namespace):
10511084
model_id = args.model
10521085
tokenizer_id = args.tokenizer
10531086
use_hf_tokenizer = args.use_hf_tokenizer
1087+
hf_access_token = args.hf_access_token
1088+
use_chat_template = args.use_chat_template
10541089

10551090
prefill_quota = AsyncCounter(init_value=3)
10561091
active_req_quota = AsyncCounter(init_value=450)
10571092

10581093
api_url = f"{args.server}:{args.port}"
1059-
1060-
tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
1094+
tokenizer = get_tokenizer(
1095+
model_id, tokenizer_id, use_hf_tokenizer, hf_access_token
1096+
)
10611097
if tokenizer == "test" or args.dataset == "test":
10621098
input_requests = mock_requests(
10631099
args.total_mock_requests
@@ -1094,6 +1130,7 @@ def main(args: argparse.Namespace):
10941130
input_requests = sample_requests(
10951131
dataset=dataset,
10961132
tokenizer=tokenizer,
1133+
use_chat_template=use_chat_template,
10971134
num_requests=args.num_prompts,
10981135
dataset_type=args.dataset,
10991136
max_output_length=args.max_output_length,
@@ -1116,6 +1153,7 @@ def main(args: argparse.Namespace):
11161153
benchmark(
11171154
api_url=api_url,
11181155
tokenizer=tokenizer,
1156+
use_chat_template=use_chat_template,
11191157
input_requests=warmup_requests,
11201158
request_rate=args.request_rate,
11211159
disable_tqdm=args.disable_tqdm,
@@ -1134,6 +1172,7 @@ def main(args: argparse.Namespace):
11341172
benchmark(
11351173
api_url=api_url,
11361174
tokenizer=tokenizer,
1175+
use_chat_template=use_chat_template,
11371176
input_requests=input_requests,
11381177
request_rate=args.request_rate,
11391178
disable_tqdm=args.disable_tqdm,

benchmarks/tests/test_benchmark_serving.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class TestBenchmarkServing(unittest.IsolatedAsyncioTestCase):
2727
async def test_benchmark(self):
2828
api_url = "test_url"
2929
tokenizer = mock.MagicMock()
30+
use_chat_template = False
3031
tokenizer.encode = mock.MagicMock(return_value=[1, 2, 3])
3132
tokenizer.decode = mock.MagicMock(return_value="test_decode")
3233
input_requests = [
@@ -78,6 +79,7 @@ def mock_orchestrator_factory(*args, **kwargs):
7879
metrics, outputs = await benchmark_serving.benchmark(
7980
api_url,
8081
tokenizer,
82+
use_chat_template,
8183
input_requests,
8284
request_rate,
8385
disable_tqdm,

jetstream/engine/mock_engine.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,11 @@ def get_prefix_destination_sharding(self) -> Any:
450450

451451
def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
452452
"""Return a protobuf of tokenizer info, callable from Py or C++."""
453-
return tokenizer_pb2.TokenizerParameters(path="test", extra_ids=0)
453+
return tokenizer_pb2.TokenizerParameters(
454+
path="test",
455+
tokenizer_type=tokenizer_pb2.TokenizerType.sentencepiece,
456+
extra_ids=0,
457+
)
454458

455459
def init_decode_state(self) -> DecodeState:
456460
"""Initialises any state which a generation step transforms."""

jetstream/engine/token_utils.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
from typing import Any, Iterable, List, Optional, Tuple, Union
2020

21+
from transformers import AutoTokenizer
2122
import jax
2223
import jax.numpy as jnp
2324
import numpy as np
@@ -200,7 +201,7 @@ def pad_tokens(
200201
tokens: Tokens.
201202
bos_id: Bos ID.
202203
pad_id: Pad ID.
203-
is_bos: Add a beginning of sequence token if this is ture.
204+
is_bos: Add a beginning of sequence token if this is true.
204205
prefill_lengths: Buckets to pad the sequence to for static compilation.
205206
max_prefill_length: Maximum bucket to use.
206207
jax_padding: convert to JAX padded tokens if True.
@@ -506,3 +507,103 @@ def eos_id(self) -> int:
506507
def bos_id(self) -> int:
507508
"""ID of the BOS token."""
508509
return self.tokenizer.bos_id
510+
511+
512+
class HuggingFaceTokenizer(tokenizer_api.Tokenizer):
513+
"""Tokenizer to convert strings to token ids and vice-versa."""
514+
515+
def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
516+
self.tokenizer = AutoTokenizer.from_pretrained(
517+
metadata.path, token=metadata.access_token
518+
)
519+
self.metadata = metadata
520+
521+
def encode(
522+
self, s: str, **kwargs
523+
) -> Tuple[Union[jax.Array, np.ndarray], int]:
524+
"""Tokenize a string.
525+
Args:
526+
s: String to tokenize.
527+
**kwargs: Additional keyword arguments
528+
Returns:
529+
tokens: Tokenized into integers.
530+
true_length: Actual length of the non-padded sequence
531+
if padding is used.
532+
"""
533+
is_bos = kwargs.pop("is_bos", True)
534+
prefill_lengths = kwargs.pop("prefill_lengths", None)
535+
max_prefill_length = kwargs.pop("max_prefill_length", None)
536+
jax_padding = kwargs.pop("jax_padding", True)
537+
if getattr(self.metadata, "use_chat_template", False):
538+
tokens = self.tokenizer.apply_chat_template(
539+
[{"role": "user", "content": s}],
540+
add_generation_prompt=True,
541+
return_tensors="np",
542+
).squeeze()
543+
if is_bos:
544+
logging.warning(
545+
"Overriding is_bos to False because use_chat_template "
546+
"is set to True."
547+
)
548+
is_bos = False
549+
else:
550+
tokens = self.tokenizer.encode(
551+
s, add_special_tokens=False, return_tensors="np"
552+
).squeeze()
553+
554+
tokens, true_length = pad_tokens(
555+
tokens,
556+
self.bos_id,
557+
self.pad_id,
558+
is_bos=is_bos,
559+
prefill_lengths=prefill_lengths,
560+
max_prefill_length=max_prefill_length,
561+
jax_padding=jax_padding,
562+
)
563+
return tokens, true_length
564+
565+
def decode(self, token_ids: list[int]) -> str:
566+
"""Processess input token ids to generate a string.
567+
Args:
568+
token_ids: List of token ids.
569+
Returns:
570+
str: String generated from the token ids.
571+
"""
572+
return self.tokenizer.decode(token_ids, skip_special_tokens=True)
573+
574+
@property
575+
def pad_id(self) -> Union[None, int]:
576+
"""ID of the pad token."""
577+
if getattr(self.tokenizer, "pad_token_id", None):
578+
return self.tokenizer.pad_token_id
579+
elif getattr(self.tokenizer, "pad_token", None):
580+
try:
581+
return self.tokenizer.encode(self.tokenizer.pad_token)[0]
582+
except ValueError as _:
583+
raise ValueError(
584+
"Tokenizer does not contain a special" " pad token!"
585+
) from None
586+
587+
@property
588+
def eos_id(self) -> Union[None, int]:
589+
if getattr(self.tokenizer, "eos_token_id", None):
590+
return self.tokenizer.eos_token_id
591+
elif getattr(self.tokenizer, "eos_token", None):
592+
try:
593+
return self.tokenizer.encode(self.tokenizer.eos_token)[0]
594+
except ValueError as _:
595+
raise ValueError(
596+
"Tokenizer does not contain a special " "eos token!"
597+
) from None
598+
599+
@property
600+
def bos_id(self) -> Union[None, int]:
601+
if getattr(self.tokenizer, "bos_token_id", None):
602+
return self.tokenizer.bos_token_id
603+
elif getattr(self.tokenizer, "bos_token", None):
604+
try:
605+
return self.tokenizer.encode(self.tokenizer.bos_token)[0]
606+
except ValueError as _:
607+
raise ValueError(
608+
"Tokenizer does not contain a special " "bos token!"
609+
) from None

jetstream/engine/tokenizer.proto

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,14 @@ option java_multiple_files = true;
2020

2121
message TokenizerParameters {
2222
string path = 1;
23-
int32 extra_ids = 2;
23+
TokenizerType tokenizer_type = 2;
24+
string access_token = 3;
25+
bool use_chat_template = 4;
26+
int32 extra_ids = 5;
27+
}
28+
29+
enum TokenizerType {
30+
tiktoken = 0;
31+
sentencepiece = 1;
32+
huggingface = 2;
2433
}

0 commit comments

Comments
 (0)