Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def _update_causal_mask(
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
else past_seen_tokens + sequence_length
)

# In case the provided `attention` mask is 2D, we generate a
Expand Down Expand Up @@ -849,50 +849,61 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
diagonal_attend_mask = torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond
# sliding window length, so we mask them out also the check is needed
# to verify is current checkpoint was trained with sliding window or not
if (
not isinstance(past_key_values, SlidingWindowCache)
or sequence_length > target_length
):
sliding_attend_mask = torch.arange(
target_length, device=device
) <= (cache_position.reshape(-1, 1) - config.sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)

if hasattr(self, "tree_mask") and self.tree_mask is not None:
tree_mask = self.tree_mask
tree_len = tree_mask.size(-1)
causal_mask[:, :, -tree_len:, -tree_len:][
tree_mask == 0
] = causal_mask.min()
if sequence_length == target_length:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
diagonal_attend_mask = torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens
# beyond sliding window length, so we mask them out also the
# check is needed to verify is current checkpoint was
# trained with sliding window or not
if (
not isinstance(past_key_values, SlidingWindowCache)
or sequence_length > target_length
):
sliding_attend_mask = torch.arange(
target_length, device=device
) <= (cache_position.reshape(-1, 1) - config.sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(
batch_size, 1, -1, -1
)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
else:
causal_mask = torch.zeros(
(sequence_length, target_length), dtype=dtype, device=device
)
causal_mask = causal_mask[None, None, :, :].expand(
batch_size, 1, -1, -1
)

if hasattr(self, "tree_mask") and self.tree_mask is not None:
tree_mask = self.tree_mask
tree_len = tree_mask.size(-1)
causal_mask[:, :, -tree_len:, -tree_len:][
tree_mask == 0
] = min_dtype
return causal_mask


Expand Down