@@ -765,7 +765,7 @@ def _update_causal_mask(
765765 target_length = (
766766 attention_mask .shape [- 1 ]
767767 if isinstance (attention_mask , torch .Tensor )
768- else past_seen_tokens + sequence_length + 1
768+ else past_seen_tokens + sequence_length
769769 )
770770
771771 # In case the provided `attention` mask is 2D, we generate a
@@ -849,50 +849,61 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
849849 causal_mask = attention_mask
850850 else :
851851 min_dtype = torch .finfo (dtype ).min
852- causal_mask = torch .full (
853- (sequence_length , target_length ),
854- fill_value = min_dtype ,
855- dtype = dtype ,
856- device = device ,
857- )
858- diagonal_attend_mask = torch .arange (
859- target_length , device = device
860- ) > cache_position .reshape (- 1 , 1 )
861- if config .sliding_window is not None :
862- # if we have sliding window, we should not attend to tokens beyond
863- # sliding window length, so we mask them out also the check is needed
864- # to verify is current checkpoint was trained with sliding window or not
865- if (
866- not isinstance (past_key_values , SlidingWindowCache )
867- or sequence_length > target_length
868- ):
869- sliding_attend_mask = torch .arange (
870- target_length , device = device
871- ) <= (cache_position .reshape (- 1 , 1 ) - config .sliding_window )
872- diagonal_attend_mask .bitwise_or_ (sliding_attend_mask )
873- causal_mask *= diagonal_attend_mask
874- causal_mask = causal_mask [None , None , :, :].expand (batch_size , 1 , - 1 , - 1 )
875- if attention_mask is not None :
876- causal_mask = (
877- causal_mask .clone ()
878- ) # copy to contiguous memory for in-place edit
879- if attention_mask .shape [- 1 ] > target_length :
880- attention_mask = attention_mask [:, :target_length ]
881- mask_length = attention_mask .shape [- 1 ]
882- padding_mask = causal_mask [:, :, :, :mask_length ] + attention_mask [
883- :, None , None , :
884- ].to (causal_mask .device )
885- padding_mask = padding_mask == 0
886- causal_mask [:, :, :, :mask_length ] = causal_mask [
887- :, :, :, :mask_length
888- ].masked_fill (padding_mask , min_dtype )
889-
890- if hasattr (self , "tree_mask" ) and self .tree_mask is not None :
891- tree_mask = self .tree_mask
892- tree_len = tree_mask .size (- 1 )
893- causal_mask [:, :, - tree_len :, - tree_len :][
894- tree_mask == 0
895- ] = causal_mask .min ()
852+ if sequence_length == target_length :
853+ causal_mask = torch .full (
854+ (sequence_length , target_length ),
855+ fill_value = min_dtype ,
856+ dtype = dtype ,
857+ device = device ,
858+ )
859+ diagonal_attend_mask = torch .arange (
860+ target_length , device = device
861+ ) > cache_position .reshape (- 1 , 1 )
862+ if config .sliding_window is not None :
863+ # if we have sliding window, we should not attend to tokens
864+ # beyond sliding window length, so we mask them out also the
865+ # check is needed to verify is current checkpoint was
866+ # trained with sliding window or not
867+ if (
868+ not isinstance (past_key_values , SlidingWindowCache )
869+ or sequence_length > target_length
870+ ):
871+ sliding_attend_mask = torch .arange (
872+ target_length , device = device
873+ ) <= (cache_position .reshape (- 1 , 1 ) - config .sliding_window )
874+ diagonal_attend_mask .bitwise_or_ (sliding_attend_mask )
875+ causal_mask *= diagonal_attend_mask
876+ causal_mask = causal_mask [None , None , :, :].expand (
877+ batch_size , 1 , - 1 , - 1
878+ )
879+ if attention_mask is not None :
880+ causal_mask = (
881+ causal_mask .clone ()
882+ ) # copy to contiguous memory for in-place edit
883+ if attention_mask .shape [- 1 ] > target_length :
884+ attention_mask = attention_mask [:, :target_length ]
885+ mask_length = attention_mask .shape [- 1 ]
886+ padding_mask = causal_mask [:, :, :, :mask_length ] + attention_mask [
887+ :, None , None , :
888+ ].to (causal_mask .device )
889+ padding_mask = padding_mask == 0
890+ causal_mask [:, :, :, :mask_length ] = causal_mask [
891+ :, :, :, :mask_length
892+ ].masked_fill (padding_mask , min_dtype )
893+ else :
894+ causal_mask = torch .zeros (
895+ (sequence_length , target_length ), dtype = dtype , device = device
896+ )
897+ causal_mask = causal_mask [None , None , :, :].expand (
898+ batch_size , 1 , - 1 , - 1
899+ )
900+
901+ if hasattr (self , "tree_mask" ) and self .tree_mask is not None :
902+ tree_mask = self .tree_mask
903+ tree_len = tree_mask .size (- 1 )
904+ causal_mask [:, :, - tree_len :, - tree_len :][
905+ tree_mask == 0
906+ ] = min_dtype
896907 return causal_mask
897908
898909
0 commit comments