diff --git a/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py index 49737e4e..64b69e17 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py @@ -765,7 +765,7 @@ def _update_causal_mask( target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 + else past_seen_tokens + sequence_length ) # In case the provided `attention` mask is 2D, we generate a @@ -849,50 +849,61 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device, - ) - diagonal_attend_mask = torch.arange( - target_length, device=device - ) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond - # sliding window length, so we mask them out also the check is needed - # to verify is current checkpoint was trained with sliding window or not - if ( - not isinstance(past_key_values, SlidingWindowCache) - or sequence_length > target_length - ): - sliding_attend_mask = torch.arange( - target_length, device=device - ) <= (cache_position.reshape(-1, 1) - config.sliding_window) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = ( - causal_mask.clone() - ) # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ - :, None, None, : - ].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[ - :, :, :, :mask_length - ].masked_fill(padding_mask, min_dtype) - - if hasattr(self, "tree_mask") and self.tree_mask is not None: - tree_mask = self.tree_mask - tree_len = tree_mask.size(-1) - causal_mask[:, :, -tree_len:, -tree_len:][ - tree_mask == 0 - ] = causal_mask.min() + if sequence_length == target_length: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens + # beyond sliding window length, so we mask them out also the + # check is needed to verify is current checkpoint was + # trained with sliding window or not + if ( + not isinstance(past_key_values, SlidingWindowCache) + or sequence_length > target_length + ): + sliding_attend_mask = torch.arange( + target_length, device=device + ) <= (cache_position.reshape(-1, 1) - config.sliding_window) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand( + batch_size, 1, -1, -1 + ) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + else: + causal_mask = torch.zeros( + (sequence_length, target_length), dtype=dtype, device=device + ) + causal_mask = causal_mask[None, None, :, :].expand( + batch_size, 1, -1, -1 + ) + + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + causal_mask[:, :, -tree_len:, -tree_len:][ + tree_mask == 0 + ] = min_dtype return causal_mask