Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 72 additions & 11 deletions aphrodite/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,21 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
return q_idx >= kv_idx


def prefixlm_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
kv_idx: torch.Tensor, prefix_len: int):
"""
Mask function for PrefixLM (Prefix Language Modeling).

In PrefixLM:
- Tokens 0 to prefix_len-1 (prefix) can attend bidirectionally to each other
- Tokens prefix_len+ onwards (suffix) follow causal masking
- Prefix tokens can attend to suffix tokens, but suffix tokens cannot attend to prefix tokens
"""
return ((q_idx < prefix_len)
| ((q_idx >= prefix_len) & (kv_idx >= prefix_len)
& (q_idx >= kv_idx)))
Comment on lines +263 to +265

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The boolean expression for the mask can be simplified. The (q_idx >= prefix_len) check in the second part of the | operation is redundant. If the first part (q_idx < prefix_len) is false, then q_idx >= prefix_len is implicitly true. Removing this redundant check will make the code slightly more efficient and easier to read.

    return ((q_idx < prefix_len)
            | ((kv_idx >= prefix_len) & (q_idx >= kv_idx)))



@dataclass
class FlexAttentionMetadata:
causal: bool
Expand All @@ -261,12 +276,6 @@ class FlexAttentionMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor

use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]

# Block info
total_cache_tokens: int
block_size: int
Expand All @@ -276,6 +285,16 @@ class FlexAttentionMetadata:
decode_offset: torch.Tensor
num_blocks_per_seq: torch.Tensor

use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]

# PrefixLM support
prefixlm: bool = False
prefix_len: int = 0

# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.

Expand Down Expand Up @@ -377,6 +396,39 @@ def final_mask_mod(

return final_mask_mod

def get_prefixlm_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for PrefixLM.

This function creates the combined mask mod function that handles:
1. The paged attention block mapping
2. The mapping from packed query sequences to logical query entries
3. PrefixLM masking logic

It also by defaults adds the decoding offset to the query indices.
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
layout of the query and key/value tensors.
"""
assert self.doc_ids is not None

def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
self.doc_ids, q_idx, physical_kv_idx)
# Apply mask modification only for valid indices
return torch.where(
is_valid,
prefixlm_mask_mod(b, h, logical_q_idx, logical_kv_idx, self.prefix_len),
False,
)

return final_mask_mod
Comment on lines +399 to +430

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new method get_prefixlm_mask_mod is almost identical to get_causal_mask_mod. This code duplication can make maintenance harder. Consider refactoring the common logic into a shared helper method to improve code quality and reduce redundancy.


def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
"""Creates the transformed score_mod function for FlexAttention.

Expand Down Expand Up @@ -469,7 +521,10 @@ def _build_block_mask_direct(self) -> BlockMask:
return BlockMask.from_kv_blocks(**block_mask_kwargs)

def build_block_mask(self) -> BlockMask:
if self.causal:
if self.prefixlm:
mask_mod = self.get_prefixlm_mask_mod()
kv_len = self.total_cache_tokens
elif self.causal:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens
Comment on lines +524 to 529

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for prefixlm and causal attention is very similar, leading to duplicated code. You can combine these two branches to make the code more concise and maintainable.

        if self.prefixlm or self.causal:
            if self.prefixlm:
                mask_mod = self.get_prefixlm_mask_mod()
            else:
                mask_mod = self.get_causal_mask_mod()
            kv_len = self.total_cache_tokens

else:
Expand All @@ -495,14 +550,16 @@ def __post_init__(self):
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
self.num_blocks = self.total_cache_tokens // self.block_size

if self.causal:
if self.prefixlm:
self.mask_mod = self.get_prefixlm_mask_mod()
elif self.causal:
self.mask_mod = self.get_causal_mask_mod()
else:
self.mask_mod = self.get_bidirectional_mask_mod()

self.transformed_score_mod = self.get_transformed_score_mod()

if self.direct_build and self.causal:
if self.direct_build and (self.causal or self.prefixlm):
self.block_mask = self._build_block_mask_direct()
else:
self.block_mask = self.build_block_mask()
Expand Down Expand Up @@ -538,7 +595,9 @@ def reorder_batch(self, input_batch: "InputBatch",
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlexAttentionMetadata:
fast_build: bool = False,
prefixlm: bool = False,
prefix_len: int = 0) -> FlexAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
Expand Down Expand Up @@ -585,6 +644,8 @@ def build(self,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefixlm=prefixlm,
prefix_len=prefix_len,
block_size=block_size,
max_possible_sequence_length=max_possible_seq_len,
num_reqs=num_reqs,
Expand Down Expand Up @@ -708,7 +769,7 @@ def forward(

num_actual_tokens = attn_metadata.num_actual_tokens

if not attn_metadata.causal:
if not attn_metadata.causal and not attn_metadata.prefixlm:
assert self.attn_type == AttentionType.ENCODER_ONLY

query, key_tensor, value_tensor = map(
Expand Down
Loading