 
 import grpc
 import jax
-import jax.numpy as jnp
 
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
@@ -523,72 +522,83 @@ def _process_prefill_content(
       tokenizer: tokenizer_api.Tokenizer,
       is_bos: bool,
       max_prefill_length: int,
-      chunked_prefill: bool = False,
-      chunk_size: Optional[int] = None,
-  ) -> (
-      Tuple[(jax.Array | np.ndarray), int, jax.Array]
-      | Tuple[
-          list[jax.Array | np.ndarray],
-          list[int],
-          list[jax.Array],
-      ]
-  ):
-    assert (chunked_prefill and chunk_size is not None) or (
-        not chunked_prefill
-    ), "Set chunk_size when chunked_prefill is True to use chunked prefill"
-
+  ) -> Tuple[jax.Array | np.ndarray, int]:
     content = request.prefill_content
     if isinstance(content, str):
       # If it's text input, tokenize and pad the input.
-      tokens, true_length = tokenizer.encode(
+      return tokenizer.encode(
           content,
           is_bos=is_bos,
           max_prefill_length=max_prefill_length,
           jax_padding=self._jax_padding,
       )
-      positions = jnp.expand_dims(
-          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
-      )
-
-      if chunked_prefill and chunk_size is not None:
-        # tokenizer.encode handle the is_bos already,
-        # set is_bos to False while chunking
-        return token_utils.chunk_and_pad_tokens(
-            tokens[:true_length],
-            tokenizer.bos_id,
-            tokenizer.pad_id,
-            is_bos=False,
-            max_prefill_length=max_prefill_length,
-            chunk_size=chunk_size,
-            jax_padding=self._jax_padding,
-        )
-      return tokens, true_length, positions
-
     else:
-      if chunked_prefill and chunk_size is not None:
-        return token_utils.chunk_and_pad_tokens(
-            content,
-            tokenizer.bos_id,
-            tokenizer.pad_id,
-            is_bos=is_bos,
-            max_prefill_length=max_prefill_length,
-            chunk_size=chunk_size,
-            jax_padding=self._jax_padding,
-        )
-
       # If it's token input, pad the input.
-      tokens, true_length = token_utils.pad_tokens(
+      return token_utils.pad_tokens(
           content,
           tokenizer.bos_id,
           tokenizer.pad_id,
           is_bos=is_bos,
           max_prefill_length=max_prefill_length,
           jax_padding=self._jax_padding,
       )
-      positions = jnp.expand_dims(
-          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
+
+  def _do_chunked_prefill(
+      self,
+      prefill_engine: engine_api.Engine,
+      prefill_params: Any,
+      tokenizer: tokenizer_api.Tokenizer,
+      tokens: jax.Array | np.ndarray,
+  ) -> Tuple[engine_api.Prefix, engine_api.ResultTokens]:
+    """Do chunked prefill.
+
+    Should not be used unless the use_chunked_prefill config is enabled.
+    """
+
+    assert prefill_engine.use_chunked_prefill
+
+    prefill_result = None
+    first_token = None
+
+    existing_prefix = None
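+    # Walk the prompt in prefill_chunk_size chunks, threading the KV cache
+    # from each prefill call into the next as the existing prefix.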
+    for start_pos in range(
+        0,
+        len(tokens),
+        prefill_engine.prefill_chunk_size,
+    ):
+      input_token = tokens[
+          start_pos : min(
+              len(tokens), start_pos + prefill_engine.prefill_chunk_size
+          )
+      ]
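+      # Pad the chunk up to the engine's max prefill length. is_bos is False
+      # because BOS was already handled when the prompt was tokenized.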
+      padded_input_token, input_true_length = token_utils.pad_tokens(
+          input_token,
+          tokenizer.bos_id,
+          tokenizer.pad_id,
+          is_bos=False,
+          max_prefill_length=prefill_engine.max_prefill_length,
+          jax_padding=self._jax_padding,
+      )
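+      # Prefill this chunk on top of the prefix built from earlier chunks;
+      # only the final chunk's first_token is ultimately returned.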
+      prefill_result, first_token = prefill_engine.prefill(
+          params=prefill_params,
+          existing_prefix=existing_prefix,
+          padded_tokens=padded_input_token,
+          true_length=input_true_length,
+      )
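+      # The updated cache plus every token consumed so far become the
+      # existing prefix for the next chunk.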
+      existing_prefix = engine_api.ExistingPrefix(
+          cache=prefill_result["cache"],
+          common_prefix_tokens=tokens[
+              0 : min(
+                  len(tokens), start_pos + prefill_engine.prefill_chunk_size
+              )
+          ],
       )
-      return tokens, true_length, positions
+
+    # Should be assigned in the loop.
+    assert prefill_result is not None
+    assert first_token is not None
+
+    return prefill_result, first_token
 
   def _prefill_thread(self, idx: int):
     """Thread which runs in the background performing prefills."""
@@ -616,12 +626,11 @@ def _prefill_thread(self, idx: int):
           f" is_bos: {is_bos}",
       )
       # Tokenize and padding the text or token input.
-      padded_tokens, true_length, _ = self._process_prefill_content(
+      padded_tokens, true_length = self._process_prefill_content(
           request,
           tokenizer,
           is_bos,
           prefill_engine.max_prefill_length,
-          False,
       )
 
       # Compute new kv cache for the prefill_content.
@@ -636,49 +645,25 @@ def _prefill_thread(self, idx: int):
       else:
         # if chunked_prefill is used,
         if prefill_engine.use_chunked_prefill:
-          padded_chunked_tokens, true_lengths_of_chunks, positions_chunks = (
-              self._process_prefill_content(
-                  request,
-                  tokenizer,
-                  is_bos,
-                  prefill_engine.max_prefill_length,
-                  prefill_engine.use_chunked_prefill,
-                  prefill_engine.prefill_chunk_size,
-              )
+          prefill_result, first_token = self._do_chunked_prefill(
+              prefill_engine,
+              prefill_params,
+              tokenizer,
+              padded_tokens[:true_length],
           )
-          prefill_result = None
-          for chunk_num, _ in enumerate(padded_chunked_tokens):
-            cache_so_far = (
-                {} if prefill_result is None else prefill_result["cache"]  # pylint: disable=unsubscriptable-object
-            )
-            prefill_result, first_token = prefill_engine.prefill(
-                params=prefill_params | {"cache": cache_so_far},
-                padded_tokens=padded_chunked_tokens[chunk_num],
-                true_length=true_lengths_of_chunks[chunk_num],
-                positions=positions_chunks[chunk_num],
-                previous_chunk=prefill_result,
-                complete_prompt_true_length=true_length,
-                complete_padded_prompt=padded_tokens,
-            )
-            # true_length_array is arrays of 1 true lengths so far
-            t_l_array = jnp.expand_dims(
-                jnp.arange(
-                    0,
-                    chunk_num * prefill_engine.prefill_chunk_size
-                    + true_lengths_of_chunks[chunk_num],
-                ),
-                0,
-            )
-            prefill_result["true_length_array"] = t_l_array
         else:
           # Compute new kv cache for the prefill_content.
           prefill_result, first_token = prefill_engine.prefill(
               params=prefill_params,
               padded_tokens=padded_tokens,
               true_length=true_length,
           )
+
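+      # One completion flag per sample generated for this request's slot.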
+      request.complete = np.zeros(
+          (prefill_engine.samples_per_slot,), np.bool_
+      )
+
       request.prefill_result = prefill_result
-      request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
 
       # put first token to detokenize queue
       my_detokenize_backlog = self._detokenize_backlogs[idx]