
Commit e539007

Refactor chunked prefill (#232)
1 parent b8b9cb2 · commit e539007

2 files changed: 77 additions & 90 deletions

jetstream/core/orchestrator.py

Lines changed: 70 additions & 85 deletions
```diff
@@ -91,7 +91,6 @@

 import grpc
 import jax
-import jax.numpy as jnp

 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
```

```diff
@@ -523,72 +522,83 @@ def _process_prefill_content(
       tokenizer: tokenizer_api.Tokenizer,
       is_bos: bool,
       max_prefill_length: int,
-      chunked_prefill: bool = False,
-      chunk_size: Optional[int] = None,
-  ) -> (
-      Tuple[(jax.Array | np.ndarray), int, jax.Array]
-      | Tuple[
-          list[jax.Array | np.ndarray],
-          list[int],
-          list[jax.Array],
-      ]
-  ):
-    assert (chunked_prefill and chunk_size is not None) or (
-        not chunked_prefill
-    ), "Set chunk_size when chunked_prefill is True to use chunked prefill"
-
+  ) -> Tuple[jax.Array | np.ndarray, int]:
     content = request.prefill_content
     if isinstance(content, str):
       # If it's text input, tokenize and pad the input.
-      tokens, true_length = tokenizer.encode(
+      return tokenizer.encode(
           content,
           is_bos=is_bos,
           max_prefill_length=max_prefill_length,
           jax_padding=self._jax_padding,
       )
-      positions = jnp.expand_dims(
-          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
-      )
-
-      if chunked_prefill and chunk_size is not None:
-        # tokenizer.encode handle the is_bos already,
-        # set is_bos to False while chunking
-        return token_utils.chunk_and_pad_tokens(
-            tokens[:true_length],
-            tokenizer.bos_id,
-            tokenizer.pad_id,
-            is_bos=False,
-            max_prefill_length=max_prefill_length,
-            chunk_size=chunk_size,
-            jax_padding=self._jax_padding,
-        )
-      return tokens, true_length, positions
-
     else:
-      if chunked_prefill and chunk_size is not None:
-        return token_utils.chunk_and_pad_tokens(
-            content,
-            tokenizer.bos_id,
-            tokenizer.pad_id,
-            is_bos=is_bos,
-            max_prefill_length=max_prefill_length,
-            chunk_size=chunk_size,
-            jax_padding=self._jax_padding,
-        )
-
       # If it's token input, pad the input.
-      tokens, true_length = token_utils.pad_tokens(
+      return token_utils.pad_tokens(
           content,
           tokenizer.bos_id,
           tokenizer.pad_id,
           is_bos=is_bos,
           max_prefill_length=max_prefill_length,
           jax_padding=self._jax_padding,
       )
-      positions = jnp.expand_dims(
-          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
+
+  def _do_chunked_prefill(
+      self,
+      prefill_engine: engine_api.Engine,
+      prefill_params: Any,
+      tokenizer: tokenizer_api.Tokenizer,
+      tokens: jax.Array | np.ndarray,
+  ) -> Tuple[engine_api.Prefix, engine_api.ResultTokens]:
+    """Do chunked prefill.
+
+    Should not be used without enabling the use_chunked_prefill config.
+    """
+
+    assert prefill_engine.use_chunked_prefill
+
+    prefill_result = None
+    first_token = None
+
+    existing_prefix = None
+    for start_pos in range(
+        0,
+        len(tokens),
+        prefill_engine.prefill_chunk_size,
+    ):
+      input_token = tokens[
+          start_pos : min(
+              len(tokens), start_pos + prefill_engine.prefill_chunk_size
+          )
+      ]
+      padded_input_token, input_true_length = token_utils.pad_tokens(
+          input_token,
+          tokenizer.bos_id,
+          tokenizer.pad_id,
+          is_bos=False,
+          max_prefill_length=prefill_engine.max_prefill_length,
+          jax_padding=self._jax_padding,
+      )
+      prefill_result, first_token = prefill_engine.prefill(
+          params=prefill_params,
+          existing_prefix=existing_prefix,
+          padded_tokens=padded_input_token,
+          true_length=input_true_length,
+      )
+      existing_prefix = engine_api.ExistingPrefix(
+          cache=prefill_result["cache"],
+          common_prefix_tokens=tokens[
+              0 : min(
+                  len(tokens), start_pos + prefill_engine.prefill_chunk_size
+              )
+          ],
       )
-      return tokens, true_length, positions
+
+    # Should have been assigned in the loop.
+    assert prefill_result is not None
+    assert first_token is not None
+
+    return prefill_result, first_token

   def _prefill_thread(self, idx: int):
     """Thread which runs in the background performing prefills."""
```

```diff
@@ -616,12 +626,11 @@ def _prefill_thread(self, idx: int):
           f" is_bos: {is_bos}",
       )
       # Tokenize and padding the text or token input.
-      padded_tokens, true_length, _ = self._process_prefill_content(
+      padded_tokens, true_length = self._process_prefill_content(
           request,
           tokenizer,
           is_bos,
           prefill_engine.max_prefill_length,
-          False,
       )


```

```diff
@@ -636,49 +645,25 @@ def _prefill_thread(self, idx: int):
       else:
         # if chunked_prefill is used,
         if prefill_engine.use_chunked_prefill:
-          padded_chunked_tokens, true_lengths_of_chunks, positions_chunks = (
-              self._process_prefill_content(
-                  request,
-                  tokenizer,
-                  is_bos,
-                  prefill_engine.max_prefill_length,
-                  prefill_engine.use_chunked_prefill,
-                  prefill_engine.prefill_chunk_size,
-              )
+          prefill_result, first_token = self._do_chunked_prefill(
+              prefill_engine,
+              prefill_params,
+              tokenizer,
+              padded_tokens[:true_length],
           )
-          prefill_result = None
-          for chunk_num, _ in enumerate(padded_chunked_tokens):
-            cache_so_far = (
-                {} if prefill_result is None else prefill_result["cache"]  # pylint: disable=unsubscriptable-object
-            )
-            prefill_result, first_token = prefill_engine.prefill(
-                params=prefill_params | {"cache": cache_so_far},
-                padded_tokens=padded_chunked_tokens[chunk_num],
-                true_length=true_lengths_of_chunks[chunk_num],
-                positions=positions_chunks[chunk_num],
-                previous_chunk=prefill_result,
-                complete_prompt_true_length=true_length,
-                complete_padded_prompt=padded_tokens,
-            )
-            # true_length_array is arrays of 1 true lengths so far
-            t_l_array = jnp.expand_dims(
-                jnp.arange(
-                    0,
-                    chunk_num * prefill_engine.prefill_chunk_size
-                    + true_lengths_of_chunks[chunk_num],
-                ),
-                0,
-            )
-            prefill_result["true_length_array"] = t_l_array
         else:
           # Compute new kv cache for the prefill_content.
           prefill_result, first_token = prefill_engine.prefill(
               params=prefill_params,
               padded_tokens=padded_tokens,
               true_length=true_length,
           )
+
+        request.complete = np.zeros(
+            (prefill_engine.samples_per_slot,), np.bool_
+        )
+
       request.prefill_result = prefill_result
-      request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)

       # put first token to detokenize queue
       my_detokenize_backlog = self._detokenize_backlogs[idx]
```
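
Note the shape contract in this dispatch: `_process_prefill_content` returns tokens already padded out to a prefill bucket, and the chunked path slices that padding back off with `padded_tokens[:true_length]` before `_do_chunked_prefill` re-pads each chunk individually. A toy illustration of the pad/slice round trip (the power-of-two bucketing rule here is an assumption for illustration, not necessarily `token_utils.pad_tokens`' exact policy):

```python
import numpy as np


def toy_pad_tokens(tokens, pad_id, max_prefill_length):
  """Pad tokens to a power-of-two bucket, capped at max_prefill_length."""
  true_length = min(len(tokens), max_prefill_length)
  tokens = tokens[:true_length]
  bucket = 1
  while bucket < true_length:
    bucket *= 2
  padded = np.full((bucket,), pad_id, dtype=np.int32)
  padded[:true_length] = tokens
  return padded, true_length


padded_tokens, true_length = toy_pad_tokens(
    np.arange(1, 11), pad_id=0, max_prefill_length=1024
)
assert padded_tokens.shape == (16,)  # 10 tokens land in the 16-wide bucket
# The chunked path recovers the unpadded prompt before re-chunking:
assert list(padded_tokens[:true_length]) == list(range(1, 11))
```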

jetstream/engine/engine_api.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -47,6 +47,12 @@
 PRNGKeyType = Any


+@struct.dataclass
+class ExistingPrefix:
+  cache: Any
+  common_prefix_tokens: jax.Array
+
+
 @struct.dataclass
 class SlotData:
   """Class to store slot data."""
```

```diff
@@ -157,14 +163,10 @@ def prefill(
       self,
       *,
       params: Params,
-      existing_prefix: Optional[Prefix] = None,
+      existing_prefix: Optional[ExistingPrefix] = None,
       padded_tokens: jax.Array,
       true_length: int,
       sampler: Optional[Callable[[Any], Any]] = None,
-      complete_prompt_true_length: Optional[int] = None,
-      complete_padded_prompt: Optional[jax.Array] = None,
-      positions: Optional[jax.Array] = None,
-      previous_chunk: Optional[Any] = None,
       request_id: Optional[uuid.UUID] = None,
   ) -> Tuple[Prefix, ResultTokens]:
     """Computes a kv-cache for a set of tokens conditional on existing cache.
```