|
4 | 4 | from ..llm.scheduler import SchedulerOutput |
5 | 5 |
|
6 | 6 |
|
| 7 | +def extend_to_next_power_of_two(lst): |
| 8 | + """Pad ``lst`` to the next power-of-two length with ``-1``. |
| 9 | +
|
| 10 | + Padding marks unused slots (same convention as ``block_tables``). |
| 11 | + Callers must ``narrow`` to the real length before passing data to kernels. |
| 12 | +
|
| 13 | + Args: |
| 14 | + lst: Input list of numeric offsets or cumulative lengths. |
| 15 | +
|
| 16 | + Returns: |
| 17 | + A new list. Empty input yields ``[0]``; already power-of-two yields a copy. |
| 18 | + """ |
| 19 | + if not lst: |
| 20 | + return [0] |
| 21 | + n = len(lst) |
| 22 | + next_pow = 1 |
| 23 | + while next_pow < n: |
| 24 | + next_pow <<= 1 |
| 25 | + if next_pow == n: |
| 26 | + return lst[:] |
| 27 | + return lst + [-1] * (next_pow - n) |
| 28 | + |
| 29 | + |
7 | 30 | @register_processor("default") |
8 | 31 | class BasicLLMProcessor(InfinilmProcessor): |
9 | 32 | def __init__(self, model_dir_path: str): |
@@ -35,9 +58,13 @@ def apply_chat_template( |
35 | 58 | normalized_conversation = [] |
36 | 59 | for message in conversation: |
37 | 60 | if isinstance(message["content"], list): |
38 | | - assert len(message["content"]) == 1, "Only one content item supported in list" |
| 61 | + assert len(message["content"]) == 1, ( |
| 62 | + "Only one content item supported in list" |
| 63 | + ) |
39 | 64 | content_item = message["content"][0] |
40 | | - assert "type" in content_item and "text" in content_item, "Content dict must have 'type' and 'text' keys" |
| 65 | + assert "type" in content_item and "text" in content_item, ( |
| 66 | + "Content dict must have 'type' and 'text' keys" |
| 67 | + ) |
41 | 68 | normalized_conversation.append( |
42 | 69 | {"role": message["role"], "content": content_item["text"]} |
43 | 70 | ) |
@@ -229,21 +256,44 @@ def _build_model_input_from_batch_scheduler_output( |
229 | 256 | block_tables.append(padded_block_table) |
230 | 257 | cu_seqlens.append(cu_seqlens[-1] + seq_len) |
231 | 258 |
|
232 | | - return { |
233 | | - "input_ids": infinicore.from_list([tokens], dtype=infinicore.int64), |
234 | | - "position_ids": infinicore.from_list(position_ids, dtype=infinicore.int64), |
235 | | - "past_kv_lengths": infinicore.from_list( |
236 | | - cached_lens, dtype=infinicore.int32 |
237 | | - ), |
238 | | - "total_kv_lengths": infinicore.from_list(seq_lens, dtype=infinicore.int32), |
239 | | - "input_offsets": infinicore.from_list(seq_offsets, dtype=infinicore.int32), |
240 | | - "cu_seqlens": infinicore.from_list(cu_seqlens, dtype=infinicore.int32), |
241 | | - "block_tables": infinicore.from_list(block_tables, dtype=infinicore.int32), |
242 | | - "slot_mapping": infinicore.from_list(slot_mapping, dtype=infinicore.int64), |
| 259 | + assert seq_offsets[-1] == len(tokens), ( |
| 260 | + f"seq_offsets[-1]={seq_offsets[-1]} != len(tokens)={len(tokens)}" |
| 261 | + ) |
| 262 | + |
| 263 | + length = len(seq_offsets) |
| 264 | + seq_offsets = extend_to_next_power_of_two(seq_offsets) |
| 265 | + cu_seqlens = extend_to_next_power_of_two(cu_seqlens) |
| 266 | + |
| 267 | + input_ids = infinicore.from_list([tokens], dtype=infinicore.int64) |
| 268 | + position_ids = infinicore.from_list(position_ids, dtype=infinicore.int64) |
| 269 | + past_kv_lengths = infinicore.from_list(cached_lens, dtype=infinicore.int32) |
| 270 | + total_kv_lengths = infinicore.from_list(seq_lens, dtype=infinicore.int32) |
| 271 | + |
| 272 | + input_offsets = infinicore.from_list( |
| 273 | + seq_offsets, dtype=infinicore.int32 |
| 274 | + ).narrow(0, 0, length) |
| 275 | + |
| 276 | + cu_seqlens = infinicore.from_list(cu_seqlens, dtype=infinicore.int32).narrow( |
| 277 | + 0, 0, length |
| 278 | + ) |
| 279 | + |
| 280 | + block_tables = infinicore.from_list(block_tables, dtype=infinicore.int32) |
| 281 | + slot_mapping = infinicore.from_list(slot_mapping, dtype=infinicore.int64) |
| 282 | + |
| 283 | + return_dict = { |
| 284 | + "input_ids": input_ids, |
| 285 | + "position_ids": position_ids, |
| 286 | + "past_kv_lengths": past_kv_lengths, |
| 287 | + "total_kv_lengths": total_kv_lengths, |
| 288 | + "input_offsets": input_offsets, |
| 289 | + "cu_seqlens": cu_seqlens, |
| 290 | + "block_tables": block_tables, |
| 291 | + "slot_mapping": slot_mapping, |
243 | 292 | "temperature": temperature, |
244 | 293 | "top_k": top_k, |
245 | 294 | "top_p": top_p, |
246 | 295 | } |
| 296 | + return return_dict |
247 | 297 |
|
248 | 298 | def get_tokenizer(self): |
249 | 299 | return self.tokenizer |
0 commit comments