|
28 | 28 | from ..multi_stream_utils import maybe_execute_in_parallel |
29 | 29 | from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
30 | 30 | from .causal_conv1d_triton import causal_conv1d_update as causal_conv1d_update_triton |
| 31 | +from .fuse_elementwise_ops import extract_transpose_prefill_slice |
31 | 32 | from .layernorm_gated import RMSNorm as RMSNormGated |
32 | 33 | from .mamba2_metadata import Mamba2Metadata |
33 | 34 |
|
@@ -544,16 +545,22 @@ def forward_extend( |
544 | 545 | query_start_loc_p = query_start_loc[: num_prefill + 1] |
545 | 546 | has_initial_states_p = has_initial_states[:num_prefill] |
546 | 547 |
|
547 | | - mixed_qkv_p = causal_conv1d_fn( |
548 | | - mixed_qkv_p.transpose(0, 1), |
| 548 | + mixed_qkv_p_t = extract_transpose_prefill_slice( |
| 549 | + mixed_qkv_p, |
| 550 | + mixed_qkv_p.shape[0], |
| 551 | + 0, |
| 552 | + mixed_qkv_p.shape[1], |
| 553 | + ) |
| 554 | + mixed_qkv_p_t = causal_conv1d_fn( |
| 555 | + mixed_qkv_p_t, |
549 | 556 | self.conv1d.weight, |
550 | 557 | self.conv1d.bias, |
551 | 558 | activation=self.activation, |
552 | 559 | conv_states=conv_states_to_use, |
553 | 560 | has_initial_state=has_initial_states_p, |
554 | 561 | cache_indices=state_indices_p, |
555 | 562 | query_start_loc=query_start_loc_p, |
556 | | - ).transpose(0, 1) |
| 563 | + ) |
557 | 564 |
|
558 | 565 | if is_target_verify: |
559 | 566 | draft_token_num = spec_metadata.max_draft_len + 1 |
@@ -588,10 +595,17 @@ def forward_extend( |
588 | 595 | activation=self.activation, |
589 | 596 | conv_state_indices=state_indices_d, |
590 | 597 | ) |
| 598 | + mixed_qkv_p.copy_(mixed_qkv_p_t.transpose(0, 1)) |
591 | 599 | mixed_qkv = torch.cat((mixed_qkv_p, mixed_qkv_d), dim=0) |
592 | 600 | else: |
| 601 | + mixed_qkv_t = extract_transpose_prefill_slice( |
| 602 | + mixed_qkv, |
| 603 | + mixed_qkv.shape[0], |
| 604 | + 0, |
| 605 | + mixed_qkv.shape[1], |
| 606 | + ) |
593 | 607 | mixed_qkv = causal_conv1d_fn( |
594 | | - mixed_qkv.transpose(0, 1), |
| 608 | + mixed_qkv_t, |
595 | 609 | self.conv1d.weight, |
596 | 610 | self.conv1d.bias, |
597 | 611 | activation=self.activation, |
|
0 commit comments