|
26 | 26 | from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed |
27 | 27 | from lightllm.utils.log_utils import init_logger |
28 | 28 | from lightllm.utils.dist_utils import get_dp_world_size |
29 | | -from lightllm.utils.envs_utils import ( |
30 | | - get_env_start_args, |
31 | | - get_llm_data_type, |
32 | | - get_added_mtp_kv_layer_num, |
33 | | -) |
| 29 | +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num |
34 | 30 | from lightllm.distributed.communication_op import dist_group_manager |
35 | 31 | from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput |
36 | 32 | from lightllm.common.triton_utils.autotuner import AutotuneLevel |
@@ -381,105 +377,36 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s |
381 | 377 | is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0 |
382 | 378 | if is_mtp_grouped_decode: |
383 | 379 | mtp_size = self.args.mtp_step + 1 |
384 | | - assert model_input.batch_size % mtp_size == 0 |
385 | | - assert new_batch_size % mtp_size == 0 |
386 | 380 | assert padded_batch_size % mtp_size == 0 |
387 | 381 | padded_req_num = padded_batch_size // mtp_size |
388 | | - |
389 | | - pad_mtp_index = torch.arange( |
390 | | - mtp_size, |
391 | | - dtype=new_model_input.b_mtp_index.dtype, |
392 | | - device=new_model_input.b_mtp_index.device, |
393 | | - ).repeat(padded_req_num) |
394 | | - pad_seq_len = torch.arange( |
395 | | - 2, |
396 | | - mtp_size + 2, |
397 | | - dtype=new_model_input.b_seq_len.dtype, |
398 | | - device=new_model_input.b_seq_len.device, |
399 | | - ).repeat(padded_req_num) |
400 | 382 | new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2) |
401 | 383 | new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len) |
402 | | - new_model_input.input_ids = torch.cat( |
403 | | - ( |
404 | | - new_model_input.input_ids, |
405 | | - torch.ones( |
406 | | - padded_batch_size, |
407 | | - dtype=new_model_input.input_ids.dtype, |
408 | | - device=new_model_input.input_ids.device, |
409 | | - ), |
410 | | - ), |
411 | | - dim=0, |
412 | | - ) |
413 | | - new_model_input.b_req_idx = torch.cat( |
414 | | - ( |
415 | | - new_model_input.b_req_idx, |
416 | | - torch.full( |
417 | | - (padded_batch_size,), |
418 | | - self.req_manager.HOLD_REQUEST_ID, |
419 | | - dtype=new_model_input.b_req_idx.dtype, |
420 | | - device=new_model_input.b_req_idx.device, |
421 | | - ), |
422 | | - ), |
423 | | - dim=0, |
424 | | - ) |
425 | | - new_model_input.b_mtp_index = torch.cat((new_model_input.b_mtp_index, pad_mtp_index), dim=0) |
| 384 | + pad_seq_len = torch.arange( |
| 385 | + 2, mtp_size + 2, dtype=new_model_input.b_seq_len.dtype, device=new_model_input.b_seq_len.device |
| 386 | + ).repeat(padded_req_num) |
426 | 387 | new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0) |
427 | | - new_model_input.mem_indexes = torch.cat( |
428 | | - ( |
429 | | - new_model_input.mem_indexes, |
430 | | - torch.full( |
431 | | - (padded_batch_size,), |
432 | | - self.mem_manager.HOLD_TOKEN_MEMINDEX, |
433 | | - dtype=new_model_input.mem_indexes.dtype, |
434 | | - device=new_model_input.mem_indexes.device, |
435 | | - ), |
436 | | - ), |
437 | | - dim=0, |
438 | | - ) |
439 | | - new_model_input.b_num_accepted_tokens = torch.cat( |
440 | | - ( |
441 | | - new_model_input.b_num_accepted_tokens, |
442 | | - torch.ones( |
443 | | - padded_req_num, |
444 | | - dtype=new_model_input.b_num_accepted_tokens.dtype, |
445 | | - device=new_model_input.b_num_accepted_tokens.device, |
446 | | - ), |
447 | | - ), |
448 | | - dim=0, |
449 | | - ) |
| 388 | + # b_num_accepted_tokens 不再随 model_input 流转/补齐:它在 GDN 的 init_mtp_verify_extra_state |
| 389 | + # 里按 req_first 从 req_to_accept_len gather,padding 组 req_first=HOLD(槽恒为 1)自然得 1。 |
450 | 390 | else: |
451 | 391 | new_model_input.total_token_num += padded_batch_size * 2 |
452 | 392 | new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) |
453 | | - new_model_input.input_ids = F.pad( |
454 | | - new_model_input.input_ids, |
455 | | - (0, padded_batch_size), |
456 | | - mode="constant", |
457 | | - value=1, |
458 | | - ) |
459 | | - new_model_input.b_req_idx = F.pad( |
460 | | - new_model_input.b_req_idx, |
461 | | - (0, padded_batch_size), |
462 | | - mode="constant", |
463 | | - value=self.req_manager.HOLD_REQUEST_ID, |
464 | | - ) |
465 | | - new_model_input.b_mtp_index = F.pad( |
466 | | - new_model_input.b_mtp_index, |
467 | | - (0, padded_batch_size), |
468 | | - mode="constant", |
469 | | - value=0, |
470 | | - ) |
471 | 393 | new_model_input.b_seq_len = F.pad( |
472 | | - new_model_input.b_seq_len, |
473 | | - (0, padded_batch_size), |
474 | | - mode="constant", |
475 | | - value=2, |
476 | | - ) |
477 | | - new_model_input.mem_indexes = F.pad( |
478 | | - new_model_input.mem_indexes, |
479 | | - (0, padded_batch_size), |
480 | | - mode="constant", |
481 | | - value=self.mem_manager.HOLD_TOKEN_MEMINDEX, |
| 394 | + new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2 |
482 | 395 | ) |
| 396 | + |
| 397 | + new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) |
| 398 | + new_model_input.b_req_idx = F.pad( |
| 399 | + new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID |
| 400 | + ) |
| 401 | + new_model_input.b_mtp_index = F.pad( |
| 402 | + new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0 |
| 403 | + ) |
| 404 | + new_model_input.mem_indexes = F.pad( |
| 405 | + new_model_input.mem_indexes, |
| 406 | + (0, padded_batch_size), |
| 407 | + mode="constant", |
| 408 | + value=self.mem_manager.HOLD_TOKEN_MEMINDEX, |
| 409 | + ) |
483 | 410 | new_model_input.multimodal_params = new_model_input.multimodal_params + [ |
484 | 411 | {"images": [], "audios": []} for _ in range(padded_batch_size) |
485 | 412 | ] |
@@ -698,6 +625,7 @@ def _decode( |
698 | 625 |
|
699 | 626 | @final |
700 | 627 | def _context_forward(self, infer_state: InferStateInfo): |
| 628 | + |
701 | 629 | input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight) |
702 | 630 | if self.args.enable_dp_prefill_balance: |
703 | 631 | assert not self.args.enable_prefill_cudagraph, "not support now" |
|
0 commit comments