Skip to content

Commit 91cd578

Browse files
authored
Fix tree mask conflict (#62)
1 parent 989cebd commit 91cd578

1 file changed

Lines changed: 56 additions & 45 deletions

File tree

angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def _update_causal_mask(
765765
target_length = (
766766
attention_mask.shape[-1]
767767
if isinstance(attention_mask, torch.Tensor)
768-
else past_seen_tokens + sequence_length + 1
768+
else past_seen_tokens + sequence_length
769769
)
770770

771771
# In case the provided `attention` mask is 2D, we generate a
@@ -849,50 +849,61 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
849849
causal_mask = attention_mask
850850
else:
851851
min_dtype = torch.finfo(dtype).min
852-
causal_mask = torch.full(
853-
(sequence_length, target_length),
854-
fill_value=min_dtype,
855-
dtype=dtype,
856-
device=device,
857-
)
858-
diagonal_attend_mask = torch.arange(
859-
target_length, device=device
860-
) > cache_position.reshape(-1, 1)
861-
if config.sliding_window is not None:
862-
# if we have sliding window, we should not attend to tokens beyond
863-
# sliding window length, so we mask them out also the check is needed
864-
# to verify is current checkpoint was trained with sliding window or not
865-
if (
866-
not isinstance(past_key_values, SlidingWindowCache)
867-
or sequence_length > target_length
868-
):
869-
sliding_attend_mask = torch.arange(
870-
target_length, device=device
871-
) <= (cache_position.reshape(-1, 1) - config.sliding_window)
872-
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
873-
causal_mask *= diagonal_attend_mask
874-
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
875-
if attention_mask is not None:
876-
causal_mask = (
877-
causal_mask.clone()
878-
) # copy to contiguous memory for in-place edit
879-
if attention_mask.shape[-1] > target_length:
880-
attention_mask = attention_mask[:, :target_length]
881-
mask_length = attention_mask.shape[-1]
882-
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
883-
:, None, None, :
884-
].to(causal_mask.device)
885-
padding_mask = padding_mask == 0
886-
causal_mask[:, :, :, :mask_length] = causal_mask[
887-
:, :, :, :mask_length
888-
].masked_fill(padding_mask, min_dtype)
889-
890-
if hasattr(self, "tree_mask") and self.tree_mask is not None:
891-
tree_mask = self.tree_mask
892-
tree_len = tree_mask.size(-1)
893-
causal_mask[:, :, -tree_len:, -tree_len:][
894-
tree_mask == 0
895-
] = causal_mask.min()
852+
if sequence_length == target_length:
853+
causal_mask = torch.full(
854+
(sequence_length, target_length),
855+
fill_value=min_dtype,
856+
dtype=dtype,
857+
device=device,
858+
)
859+
diagonal_attend_mask = torch.arange(
860+
target_length, device=device
861+
) > cache_position.reshape(-1, 1)
862+
if config.sliding_window is not None:
863+
# if we have sliding window, we should not attend to tokens
864+
# beyond sliding window length, so we mask them out. Also, the
865+
# check is needed to verify whether the current checkpoint was
866+
# trained with sliding window or not
867+
if (
868+
not isinstance(past_key_values, SlidingWindowCache)
869+
or sequence_length > target_length
870+
):
871+
sliding_attend_mask = torch.arange(
872+
target_length, device=device
873+
) <= (cache_position.reshape(-1, 1) - config.sliding_window)
874+
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
875+
causal_mask *= diagonal_attend_mask
876+
causal_mask = causal_mask[None, None, :, :].expand(
877+
batch_size, 1, -1, -1
878+
)
879+
if attention_mask is not None:
880+
causal_mask = (
881+
causal_mask.clone()
882+
) # copy to contiguous memory for in-place edit
883+
if attention_mask.shape[-1] > target_length:
884+
attention_mask = attention_mask[:, :target_length]
885+
mask_length = attention_mask.shape[-1]
886+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
887+
:, None, None, :
888+
].to(causal_mask.device)
889+
padding_mask = padding_mask == 0
890+
causal_mask[:, :, :, :mask_length] = causal_mask[
891+
:, :, :, :mask_length
892+
].masked_fill(padding_mask, min_dtype)
893+
else:
894+
causal_mask = torch.zeros(
895+
(sequence_length, target_length), dtype=dtype, device=device
896+
)
897+
causal_mask = causal_mask[None, None, :, :].expand(
898+
batch_size, 1, -1, -1
899+
)
900+
901+
if hasattr(self, "tree_mask") and self.tree_mask is not None:
902+
tree_mask = self.tree_mask
903+
tree_len = tree_mask.size(-1)
904+
causal_mask[:, :, -tree_len:, -tree_len:][
905+
tree_mask == 0
906+
] = min_dtype
896907
return causal_mask
897908

898909

0 commit comments

Comments
 (0)