Skip to content

Commit be5878b

Browse files
Merge pull request #207 from InfiniTensor/issue/206
issue/991 optimize input preparation
2 parents 805212c + 2b8699b commit be5878b

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

python/infinilm/infer_engine.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,22 @@ def generate(
123123
if _measure_and_log_time:
124124
time_measurements = []
125125

126+
block_tables = None
127+
max_blocks_per_batch = 0
128+
if self.enable_paged_attn:
129+
max_blocks_per_batch = (
130+
initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1
131+
) // paged_block_size
132+
133+
block_tables_list = [
134+
range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch)
135+
for i in range(batch_size)
136+
]
137+
block_tables = infinicore.from_list(
138+
block_tables_list,
139+
dtype=infinicore.int64,
140+
)
141+
126142
for iter in range(0, generation_config.max_new_tokens):
127143
if _measure_and_log_time:
128144
start_time = time.perf_counter()
@@ -135,28 +151,28 @@ def generate(
135151
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
136152
dtype=infinicore.int64,
137153
)
138-
block_tables_list = [
139-
[
140-
i * batch_size + b
154+
155+
if iter == 0:
156+
slot_mapping_list = []
157+
for b in range(batch_size):
158+
slot_mapping_list.extend(
159+
[
160+
b * max_blocks_per_batch * paged_block_size + i
161+
for i in range(seq_len)
162+
]
163+
)
164+
else:
165+
slot_mapping_list = [
166+
i
141167
for i in range(
142-
(past_seq_len + seq_len + paged_block_size - 1)
143-
// paged_block_size
168+
past_seq_len,
169+
max_blocks_per_batch
170+
* paged_block_size
171+
* initial_batch_size,
172+
max_blocks_per_batch * paged_block_size,
144173
)
145174
]
146-
for b in range(batch_size)
147-
]
148-
slot_mapping_list = [
149-
(((past_seq_len + i) // paged_block_size) * batch_size + b)
150-
* paged_block_size
151-
+ (past_seq_len + i) % paged_block_size
152-
for b in range(batch_size)
153-
for i in range(seq_len)
154-
]
155-
156-
block_tables = infinicore.from_list(
157-
block_tables_list,
158-
dtype=infinicore.int64,
159-
)
175+
160176
slot_mapping = infinicore.from_list(
161177
slot_mapping_list,
162178
dtype=infinicore.int64,
@@ -170,7 +186,6 @@ def generate(
170186
dtype=infinicore.int64,
171187
)
172188

173-
block_tables = None
174189
slot_mapping = None
175190

176191
past_kv_lengths = infinicore.from_list(
@@ -207,9 +222,9 @@ def generate(
207222
):
208223
break
209224

210-
input_ids = infinicore.from_list(
211-
[[output_id] for output_id in output_id.to_numpy().tolist()]
212-
)
225+
# start_prepare_time = time.perf_counter()
226+
input_ids = output_id.view([batch_size, 1])
227+
213228
past_seq_len = past_seq_len + seq_len
214229

215230
if _measure_and_log_time:

0 commit comments

Comments
 (0)