
Commit 698f33f

Chunked Prefill (#188)
1 parent f602565 commit 698f33f

5 files changed

Lines changed: 340 additions & 27 deletions


jetstream/core/orchestrator.py

Lines changed: 88 additions & 15 deletions
@@ -91,6 +91,8 @@
 
 import grpc
 import jax
+import jax.numpy as jnp
+
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.core.utils import async_multifuture
@@ -519,26 +521,59 @@ def _process_prefill_content(
       tokenizer: tokenizer_api.Tokenizer,
       is_bos: bool,
       max_prefill_length: int,
-  ) -> Tuple[jax.Array | np.ndarray, int]:
+      chunked_prefill: bool = False,
+      chunk_size: Optional[int] = None,
+  ) -> Tuple[jax.Array | np.ndarray, jax.Array, jax.Array | np.ndarray]:
     content = request.prefill_content
     if isinstance(content, str):
       # If it's text input, tokenize and pad the input.
-      return tokenizer.encode(
+      tokens, true_length = tokenizer.encode(
          content,
          is_bos=is_bos,
          max_prefill_length=max_prefill_length,
          jax_padding=self._jax_padding,
      )
+      positions = jnp.expand_dims(
+          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
+      )
+
+      if chunked_prefill:
+        return token_utils.chunk_and_pad_tokens(
+            tokens[:true_length],
+            tokenizer.bos_id,
+            tokenizer.pad_id,
+            is_bos=is_bos,
+            max_prefill_length=max_prefill_length,
+            chunk_size=chunk_size,
+            jax_padding=self._jax_padding,
+        )
+      return tokens, true_length, positions
+
     else:
+      if chunked_prefill:
+        return token_utils.chunk_and_pad_tokens(
+            content,
+            tokenizer.bos_id,
+            tokenizer.pad_id,
+            is_bos=is_bos,
+            max_prefill_length=max_prefill_length,
+            chunk_size=chunk_size,
+            jax_padding=self._jax_padding,
+        )
+
       # If it's token input, pad the input.
-      return token_utils.pad_tokens(
+      tokens, true_length = token_utils.pad_tokens(
          content,
          tokenizer.bos_id,
          tokenizer.pad_id,
          is_bos=is_bos,
          max_prefill_length=max_prefill_length,
          jax_padding=self._jax_padding,
      )
+      positions = jnp.expand_dims(
+          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
+      )
+      return tokens, true_length, positions
 
   def _prefill_thread(self, idx: int):
     """Thread which runs in the background performing prefills."""
@@ -566,8 +601,12 @@ def _prefill_thread(self, idx: int):
          f" is_bos: {is_bos}",
      )
      # Tokenize and padding the text or token input.
-      padded_tokens, true_length = self._process_prefill_content(
-          request, tokenizer, is_bos, prefill_engine.max_prefill_length
+      padded_tokens, true_length, _ = self._process_prefill_content(
+          request,
+          tokenizer,
+          is_bos,
+          prefill_engine.max_prefill_length,
+          False,
      )
 
      # Compute new kv cache for the prefill_content.
@@ -580,17 +619,51 @@ def _prefill_thread(self, idx: int):
        )
        request.complete = np.zeros((request.num_samples,), np.bool_)
      else:
-        prefill_result, first_token = prefill_engine.prefill(
-            params=prefill_params,
-            padded_tokens=padded_tokens,
-            true_length=true_length,
-            request_id=request.request_id,
-        )
-        request.complete = np.zeros(
-            (prefill_engine.samples_per_slot,), np.bool_
-        )
-
+        # If chunked prefill is enabled, prefill the prompt one chunk at a time.
+        if prefill_engine.use_chunked_prefill:
+          padded_chunked_tokens, true_lengths_of_chunks, positions_chunks = (
+              self._process_prefill_content(
+                  request,
+                  tokenizer,
+                  is_bos,
+                  prefill_engine.max_prefill_length,
+                  prefill_engine.use_chunked_prefill,
+                  prefill_engine.chunk_size,
+              )
+          )
+          prefill_result = None
+          for chunk_num, _ in enumerate(padded_chunked_tokens):
+            cache_so_far = (
+                {} if prefill_result is None else prefill_result["cache"]  # pylint: disable=unsubscriptable-object
+            )
+            prefill_result, first_token = prefill_engine.prefill(
+                params=prefill_params | {"cache": cache_so_far},
+                padded_tokens=padded_chunked_tokens[chunk_num],
+                true_length=true_lengths_of_chunks[chunk_num],
+                positions=positions_chunks[chunk_num],
+                previous_chunk=prefill_result,
+                complete_prompt_true_length=true_length,
+                complete_padded_prompt=padded_tokens,
+            )
+            # Positions processed so far: all previous chunks plus this chunk's true length.
+            t_l_array = jnp.expand_dims(
+                jnp.arange(
+                    0,
+                    chunk_num * prefill_engine.chunk_size
+                    + true_lengths_of_chunks[chunk_num],
+                ),
+                1,
+            )
+            prefill_result["true_length_array"] = t_l_array
+        else:
+          # Compute new kv cache for the prefill_content.
+          prefill_result, first_token = prefill_engine.prefill(
+              params=prefill_params,
+              padded_tokens=padded_tokens,
+              true_length=true_length,
+          )
      request.prefill_result = prefill_result
+      request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
 
      # put first token to detokenize queue
      my_detokenize_backlog = self._detokenize_backlogs[idx]
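
The loop above is the core of the feature: each chunk is prefilled with the KV cache accumulated from the previous chunks (params=prefill_params | {"cache": cache_so_far}), so after the last chunk the engine holds a cache equivalent to prefilling the full prompt in one pass. Below is a standalone toy sketch of that cache-threading pattern; toy_prefill and its running-sum "cache" are hypothetical stand-ins for the real engine:

import jax.numpy as jnp

def toy_prefill(cache_so_far, chunk_tokens):
  """Dummy stand-in for engine.prefill: the "cache" is just a running sum."""
  new_cache = cache_so_far + jnp.sum(chunk_tokens)
  first_token = new_cache.astype(jnp.int32)
  return {"cache": new_cache}, first_token

tokens = jnp.arange(10, dtype=jnp.int32)
chunk_size = 4
chunks = [tokens[i : i + chunk_size] for i in range(0, len(tokens), chunk_size)]

result = None
for chunk in chunks:
  # Empty cache for the first chunk, then reuse the previous chunk's cache.
  cache_so_far = jnp.int32(0) if result is None else result["cache"]
  result, _ = toy_prefill(cache_so_far, chunk)

# Equivalent to prefilling the whole prompt in a single call.
assert result["cache"] == jnp.sum(tokens)
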

jetstream/engine/mock_engine.py

Lines changed: 60 additions & 12 deletions
@@ -86,6 +86,7 @@ def __init__(
      cache_length: int,
      weight: float,
      vocab_size: int = 1024,
+      use_chunked_prefill: bool = False,
  ):
    self.prefill_cache_batch = batch_size
    self.generate_cache_batch = batch_size
@@ -96,17 +97,24 @@ def __init__(
        mesh_utils.create_device_mesh((1, 1, 1), jax.devices()), ("x", "y", "z")
    )
    self._prng_key = jax.random.PRNGKey(42)
+    self._use_chunked_prefill = use_chunked_prefill
 
  def load_params(self) -> Params:
    """Loads model weights."""
    # An integer, used to multiply inputs.
    return jnp.array([self.weight], dtype=jnp.float32)
 
+  def load_params_dict(self) -> Params:
+    """Loads model weights."""
+    # An integer, used to multiply inputs.
+    return {"params": jnp.array([self.weight], dtype=jnp.float32)}
+
  @functools.partial(
      jax.jit,
      static_argnums=(0,),
      static_argnames=("request_id",),
  )
+  # pylint: disable=unused-argument
  def prefill(
      self,
      *,
@@ -115,6 +123,10 @@ def prefill(
      padded_tokens: jax.Array,
      true_length: int,
      request_id: Optional[uuid.UUID] = None,
+      previous_chunk=None,
+      complete_padded_prompt=None,
+      complete_prompt_true_length=None,
+      positions=None,
  ) -> Tuple[Prefix, engine_api.ResultTokens]:
    """Computes a kv-cache for a new generate request.
 
@@ -133,20 +145,33 @@ def prefill(
    assert padded_tokens.ndim == 1
 
    # Generate dummy prefill cache content
-    prefill_cache = padded_tokens[None, :] * params
+    if not self._use_chunked_prefill:
+      prefill_cache = padded_tokens[None, :] * params
+    else:
+      prefill_cache = padded_tokens[None, :]
 
    # Create a dummy first generated token.
    first_generated_token = (prefill_cache.sum(axis=-1).astype(jnp.int32))[
        :, jnp.newaxis
    ]
 
-    prefix = Prefix(
-        logits=jax.random.normal(self._prng_key, (1, self.vocab_size)),
-        cache=prefill_cache,
-        next_pos=jnp.full((1, 1), true_length, dtype=jnp.int32),
-        num_generated_tokens=jnp.zeros((1, 1), dtype=jnp.int32),
-        first_token=first_generated_token,
-    )
+    if not self._use_chunked_prefill:
+      prefix = Prefix(
+          logits=jax.random.normal(self._prng_key, (1, self.vocab_size)),
+          cache=prefill_cache,
+          next_pos=jnp.full((1, 1), true_length, dtype=jnp.int32),
+          num_generated_tokens=jnp.zeros((1, 1), dtype=jnp.int32),
+          first_token=first_generated_token,
+      )
+    else:
+      prefix = {
+          "logits": jax.random.normal(self._prng_key, (1, self.vocab_size)),
+          "cache": prefill_cache,
+          "next_pos": jnp.full((1, 1), true_length, dtype=jnp.int32),
+          "generated_tokens": jnp.zeros((1, 1), dtype=jnp.int32),
+          "tokens": first_generated_token,
+          "first_token": first_generated_token,
+      }
 
    speculations = first_generated_token.shape[1]
    result_tokens = engine_api.ResultTokens(
@@ -319,15 +344,19 @@ def generate(
  )
  def insert(
      self,
-      prefix: Prefix,
+      prefix: Any,
      decode_state: DecodeState,
      slot: int,
      request_id: Optional[uuid.UUID] = None,
  ) -> DecodeState:
    """Adds `prefix` into `decode_state` at `slot`."""
-    prefill_cache = prefix.cache
+    if not self._use_chunked_prefill:
+      prefill_cache = prefix.cache
+    else:
+      prefill_cache = prefix["cache"]
+
    prefill_cache = jax.lax.dynamic_update_slice_in_dim(
-        decode_state.prefill_cache, prefill_cache, slot, axis=0
+        decode_state.prefill_cache, prefill_cache * 1.0, slot, axis=0
    )
    generate_cache = jax.lax.dynamic_update_slice_in_dim(
        decode_state.generate_cache,
@@ -342,9 +371,13 @@ def insert(
        slot * samples_per_slot,
        axis=0,
    )
+    if not self._use_chunked_prefill:
+      first_token = prefix.first_token
+    else:
+      first_token = prefix["first_token"]
    generate_tokens = jax.lax.dynamic_update_slice_in_dim(
        decode_state.generate_tokens,
-        prefix.first_token,
+        first_token,
        slot * samples_per_slot,
        axis=0,
    )
@@ -455,3 +488,18 @@ def mesh(self) -> jax.sharding.Mesh:
  def colocated_cpus(self) -> None:
    """CPU devices colocated with the engine's accelerators."""
    raise NotImplementedError
+
+  @property
+  def use_chunked_prefill(self) -> bool:
+    """Whether chunked prefill is enabled."""
+    return self._use_chunked_prefill
+
+  @property
+  def chunk_size(self) -> int:
+    """Chunk size used for chunked prefill."""
+    return 2
+
+  @property
+  def prefill_chunk_size(self) -> int:
+    """Prefill chunk size."""
+    return 64
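
With use_chunked_prefill enabled, the mock engine's prefill() returns the prefix as a plain dict rather than the Prefix NamedTuple, so insert() reads the cache and first token through the branches added above. Here is a standalone sketch of that access pattern; the minimal Prefix tuple and read_prefix helper are illustrative stand-ins, not the definitions in mock_engine.py:

from typing import Any, NamedTuple
import jax.numpy as jnp

class Prefix(NamedTuple):  # minimal stand-in for the mock engine's Prefix
  cache: jnp.ndarray
  first_token: jnp.ndarray

def read_prefix(prefix: Any, use_chunked_prefill: bool):
  """Mirrors the branching in insert(): attribute access for the NamedTuple
  path, key access for the chunked (dict) path."""
  if not use_chunked_prefill:
    return prefix.cache, prefix.first_token
  return prefix["cache"], prefix["first_token"]

tuple_prefix = Prefix(cache=jnp.ones((1, 4)), first_token=jnp.array([[3]]))
dict_prefix = {"cache": jnp.ones((1, 4)), "first_token": jnp.array([[3]])}

assert read_prefix(tuple_prefix, False)[0].shape == (1, 4)
assert read_prefix(dict_prefix, True)[0].shape == (1, 4)
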

jetstream/engine/token_utils.py

Lines changed: 86 additions & 0 deletions
@@ -21,6 +21,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import math
 from seqio.vocabularies import SentencePieceVocabulary
 from seqio.vocabularies import Vocabulary
 
@@ -98,6 +99,91 @@ def tokenize_and_pad(
  return padded_tokens, true_length
 
 
+def chunk_and_pad_tokens(
+    tokens,
+    bos_id: int,
+    pad_id: int,
+    is_bos: bool = True,
+    prefill_lengths: Optional[List[int]] = None,
+    max_prefill_length: Optional[int] = None,
+    chunk_size: Optional[int] = None,
+    jax_padding: bool = True,
+) -> Tuple[
+    List[Union[jax.Array, np.ndarray]],
+    List[Union[jax.Array, np.ndarray]],
+    List[Union[jax.Array, np.ndarray]],
+]:
+  """Chunks and pads tokens for chunked prefill.
+  If the total token count is 520 and the chunk size is 256,
+  the function returns 3 chunks and the return tuple is as follows:
+  [[t0,..t255],[t256,..t511],[t512,..t519]],
+  [256, 256, 8],
+  [[0,..255],[256,..511],[512,..519,..]]
+
+  Args:
+    tokens: Tokens.
+    bos_id: Bos ID.
+    pad_id: Pad ID.
+    is_bos: Add a beginning of sequence token if this is true.
+    prefill_lengths: Buckets to pad the sequence to for static compilation.
+    max_prefill_length: Maximum bucket to use.
+    chunk_size: Maximum size of each chunk.
+    jax_padding: Convert to JAX padded tokens if True.
+
+  Returns:
+    chunk_padded_tokens: List of chunked and padded tokens.
+    padded_chunk_true_lengths: List of integers - true length of each chunk.
+    positions: List of the positions of each token in the chunk.
+  """
+
+  num_tokens = len(tokens)
+  num_chunks = int(math.ceil(num_tokens / chunk_size))
+  # every entry in chunk_padded_tokens is a padded chunk
+  chunk_padded_tokens = []
+
+  # true lengths for each chunk
+  padded_chunk_true_lengths = []
+
+  # positions of tokens in each chunk
+  positions = []
+  # to be able to slice the tokens
+  tokens = jnp.array(tokens)
+  for chunk_num in range(num_chunks):
+    start = int(chunk_num * chunk_size)
+    end = jnp.minimum((chunk_num + 1) * chunk_size, num_tokens)
+    chunk_tokens = jax.lax.slice(tokens, (start,), (end,))
+    if chunk_num == 0:
+      padded_chunk, padded_chunk_true_length = pad_tokens(
+          chunk_tokens,
+          bos_id,
+          pad_id,
+          is_bos,
+          prefill_lengths,
+          max_prefill_length,
+          jax_padding,
+      )
+    else:
+      # is_bos should be false in subsequent chunks.
+      padded_chunk, padded_chunk_true_length = pad_tokens(
+          chunk_tokens,
+          bos_id,
+          pad_id,
+          False,
+          prefill_lengths,
+          max_prefill_length,
+          jax_padding,
+      )
+
+    positions_chunk = jnp.expand_dims(
+        jnp.arange(start, start + len(padded_chunk), dtype=jnp.int32), 0
+    )
+    chunk_padded_tokens.append(padded_chunk)
+    padded_chunk_true_lengths.append(padded_chunk_true_length)
+    positions.append(positions_chunk)
+
+  return chunk_padded_tokens, padded_chunk_true_lengths, positions
+
+
 def pad_tokens(
    tokens: np.ndarray,
    bos_id: int,
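
A hedged usage sketch of the new helper, reproducing the 520-token example from its docstring; the bos_id/pad_id values are placeholders, and is_bos is set to False so no BOS token is prepended to the first chunk:

import jax.numpy as jnp
from jetstream.engine import token_utils

tokens = jnp.arange(520, dtype=jnp.int32)  # stand-in prompt of 520 tokens
chunks, true_lengths, positions = token_utils.chunk_and_pad_tokens(
    tokens,
    bos_id=1,   # placeholder ids
    pad_id=0,
    is_bos=False,
    max_prefill_length=256,
    chunk_size=256,
)
assert len(chunks) == 3         # 256 + 256 + 8 real tokens
print(true_lengths)             # expected: [256, 256, 8]
print(int(positions[1][0, 0]))  # expected: 256, positions continue across chunks
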
