@@ -123,6 +123,22 @@ def generate(
123123 if _measure_and_log_time :
124124 time_measurements = []
125125
126+ block_tables = None
127+ max_blocks_per_batch = 0
128+ if self .enable_paged_attn :
129+ max_blocks_per_batch = (
130+ initial_seqlen + generation_config .max_new_tokens + paged_block_size - 1
131+ ) // paged_block_size
132+
133+ block_tables_list = [
134+ range (i * max_blocks_per_batch , (i + 1 ) * max_blocks_per_batch )
135+ for i in range (batch_size )
136+ ]
137+ block_tables = infinicore .from_list (
138+ block_tables_list ,
139+ dtype = infinicore .int64 ,
140+ )
141+
126142 for iter in range (0 , generation_config .max_new_tokens ):
127143 if _measure_and_log_time :
128144 start_time = time .perf_counter ()
@@ -135,28 +151,28 @@ def generate(
135151 list (range (past_seq_len , past_seq_len + seq_len )) * batch_size ,
136152 dtype = infinicore .int64 ,
137153 )
138- block_tables_list = [
139- [
140- i * batch_size + b
154+
155+ if iter == 0 :
156+ slot_mapping_list = []
157+ for b in range (batch_size ):
158+ slot_mapping_list .extend (
159+ [
160+ b * max_blocks_per_batch * paged_block_size + i
161+ for i in range (seq_len )
162+ ]
163+ )
164+ else :
165+ slot_mapping_list = [
166+ i
141167 for i in range (
142- (past_seq_len + seq_len + paged_block_size - 1 )
143- // paged_block_size
168+ past_seq_len ,
169+ max_blocks_per_batch
170+ * paged_block_size
171+ * initial_batch_size ,
172+ max_blocks_per_batch * paged_block_size ,
144173 )
145174 ]
146- for b in range (batch_size )
147- ]
148- slot_mapping_list = [
149- (((past_seq_len + i ) // paged_block_size ) * batch_size + b )
150- * paged_block_size
151- + (past_seq_len + i ) % paged_block_size
152- for b in range (batch_size )
153- for i in range (seq_len )
154- ]
155-
156- block_tables = infinicore .from_list (
157- block_tables_list ,
158- dtype = infinicore .int64 ,
159- )
175+
160176 slot_mapping = infinicore .from_list (
161177 slot_mapping_list ,
162178 dtype = infinicore .int64 ,
@@ -170,7 +186,6 @@ def generate(
170186 dtype = infinicore .int64 ,
171187 )
172188
173- block_tables = None
174189 slot_mapping = None
175190
176191 past_kv_lengths = infinicore .from_list (
@@ -207,9 +222,9 @@ def generate(
207222 ):
208223 break
209224
210- input_ids = infinicore . from_list (
211- [[ output_id ] for output_id in output_id . to_numpy (). tolist ()]
212- )
225+ # start_prepare_time = time.perf_counter()
226+ input_ids = output_id . view ([ batch_size , 1 ])
227+
213228 past_seq_len = past_seq_len + seq_len
214229
215230 if _measure_and_log_time :
0 commit comments