@@ -101,41 +101,55 @@ def tokenize_and_pad(
101101
102102
103103def chunk_and_pad_tokens (
104- tokens ,
104+ tokens : np . ndarray ,
105105 bos_id : int ,
106106 pad_id : int ,
107- is_bos : bool = True ,
107+ is_bos : bool ,
108+ chunk_size : int ,
108109 prefill_lengths : Optional [List [int ]] = None ,
109110 max_prefill_length : Optional [int ] = None ,
110- chunk_size : Optional [int ] = None ,
111111 jax_padding : bool = True ,
112112) -> Tuple [
113113 List [Union [jax .Array , np .ndarray ]],
114- List [Union [ jax . Array , np . ndarray ] ],
115- List [Union [ jax .Array , np . ndarray ] ],
114+ List [int ],
115+ List [jax .Array ],
116116]:
117- """Chunks and pads tokens for chunked prefill
118- if total token size is 520 and chunk size is 256,
117+ """Chunks and pads tokens for chunked prefill.
118+
119+ If total token size is 520 and chunk size is 256,
119120 the function will return 3 chunks and return tuple is as follows-
120121 [[t0,..t255][t256,..t511][t512,..t519]],
121122 [256, 256, 7],
122- [[0,..255],[ 256,..511],[ 512..518..]]
123+ [[[ 0,..255]],[[ 256,..511]],[[ 512..518..] ]]
123124
124125 Args:
125126 tokens: Tokens.
126127 bos_id: Bos ID.
127128 pad_id: Pad ID.
128129 is_bos: Add a beginning of sequence token if this is ture.
130+ chunk_size: maximum size of each chunk
129131 prefill_lengths: Buckets to pad the sequence to for static compilation.
130132 max_prefill_length: Maximum bucket to use.
131- chunk_size: maximum size of each chunk
132133 jax_padding: convert to JAX padded tokens if True.
133134
134135 Returns:
135136 chunk_padded_tokens: List of chunked and padded tokens.
136137 padded_chunk_true_lengths: List of integers - true length of each chunk
137138 positions:list of position of each token in the chunk
138139 """
140+ # Add a beginning of sequence token if this is the beginning.
141+ if is_bos :
142+ tokens = np .concatenate (
143+ [
144+ np .array (
145+ [
146+ bos_id ,
147+ ]
148+ ),
149+ tokens ,
150+ ],
151+ axis = - 1 ,
152+ )
139153
140154 num_tokens = len (tokens )
141155 num_chunks = int (math .ceil (num_tokens / chunk_size ))
@@ -147,33 +161,22 @@ def chunk_and_pad_tokens(
147161
148162 # positions of tokens in each chunk
149163 positions = []
150- # to be able to slice the tokens
151- tokens = jnp .array (tokens )
164+
152165 for chunk_num in range (num_chunks ):
153- start = int (chunk_num * chunk_size )
154- end = jnp .minimum ((chunk_num + 1 ) * chunk_size , num_tokens )
155- chunk_tokens = jax .lax .slice (tokens , (start ,), (end ,))
156- if chunk_num == 0 :
157- padded_chunk , padded_chunk_true_length = pad_tokens (
158- chunk_tokens ,
159- bos_id ,
160- pad_id ,
161- is_bos ,
162- prefill_lengths ,
163- max_prefill_length ,
164- jax_padding ,
165- )
166- else :
167- # is_bos should be false in subsequent chunks.
168- padded_chunk , padded_chunk_true_length = pad_tokens (
169- chunk_tokens ,
170- bos_id ,
171- pad_id ,
172- False ,
173- prefill_lengths ,
174- max_prefill_length ,
175- jax_padding ,
176- )
166+ start : int = chunk_num * chunk_size
167+ end : int = min ((chunk_num + 1 ) * chunk_size , num_tokens )
168+ chunk_tokens = tokens [start :end ]
169+ # the bos is added at the begin of the function.
170+ # is_bos should be false in chunks.
171+ padded_chunk , padded_chunk_true_length = pad_tokens (
172+ chunk_tokens ,
173+ bos_id ,
174+ pad_id ,
175+ False ,
176+ prefill_lengths ,
177+ max_prefill_length ,
178+ jax_padding ,
179+ )
177180
178181 positions_chunk = jnp .expand_dims (
179182 jnp .arange (start , start + len (padded_chunk ), dtype = jnp .int32 ), 0
0 commit comments