-
-
Notifications
You must be signed in to change notification settings - Fork 200
[Attention] feat: support PrefixLM #1526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -250,6 +250,21 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, | |
| return q_idx >= kv_idx | ||
|
|
||
|
|
||
| def prefixlm_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, | ||
| kv_idx: torch.Tensor, prefix_len: int): | ||
| """ | ||
| Mask function for PrefixLM (Prefix Language Modeling). | ||
|
|
||
| In PrefixLM: | ||
| - Tokens 0 to prefix_len-1 (prefix) can attend bidirectionally to each other | ||
| - Tokens prefix_len+ onwards (suffix) follow causal masking | ||
| - Prefix tokens can attend to suffix tokens, but suffix tokens cannot attend to prefix tokens | ||
| """ | ||
| return ((q_idx < prefix_len) | ||
| | ((q_idx >= prefix_len) & (kv_idx >= prefix_len) | ||
| & (q_idx >= kv_idx))) | ||
|
|
||
|
|
||
| @dataclass | ||
| class FlexAttentionMetadata: | ||
| causal: bool | ||
|
|
@@ -261,12 +276,6 @@ class FlexAttentionMetadata: | |
| block_table: torch.Tensor | ||
| slot_mapping: torch.Tensor | ||
|
|
||
| use_cascade: bool | ||
| common_prefix_len: int | ||
| cu_prefix_query_lens: Optional[torch.Tensor] | ||
| prefix_kv_lens: Optional[torch.Tensor] | ||
| suffix_kv_lens: Optional[torch.Tensor] | ||
|
|
||
| # Block info | ||
| total_cache_tokens: int | ||
| block_size: int | ||
|
|
@@ -276,6 +285,16 @@ class FlexAttentionMetadata: | |
| decode_offset: torch.Tensor | ||
| num_blocks_per_seq: torch.Tensor | ||
|
|
||
| use_cascade: bool | ||
| common_prefix_len: int | ||
| cu_prefix_query_lens: Optional[torch.Tensor] | ||
| prefix_kv_lens: Optional[torch.Tensor] | ||
| suffix_kv_lens: Optional[torch.Tensor] | ||
|
|
||
| # PrefixLM support | ||
| prefixlm: bool = False | ||
| prefix_len: int = 0 | ||
|
|
||
| # For logging. | ||
| num_input_tokens: int = 0 # Number of tokens including padding. | ||
|
|
||
|
|
@@ -377,6 +396,39 @@ def final_mask_mod( | |
|
|
||
| return final_mask_mod | ||
|
|
||
| def get_prefixlm_mask_mod(self) -> _mask_mod_signature: | ||
| """Creates the mask_mod function for PrefixLM. | ||
|
|
||
| This function creates the combined mask mod function that handles: | ||
| 1. The paged attention block mapping | ||
| 2. The mapping from packed query sequences to logical query entries | ||
| 3. PrefixLM masking logic | ||
|
|
||
| It also by defaults adds the decoding offset to the query indices. | ||
| With this info we create the "logical" indices that are passed to | ||
| mask_mod functions. This allows mask mod functions to be agnostic to | ||
| layout of the query and key/value tensors. | ||
| """ | ||
| assert self.doc_ids is not None | ||
|
|
||
| def final_mask_mod( | ||
| b: torch.Tensor, | ||
| h: torch.Tensor, | ||
| q_idx: torch.Tensor, | ||
| physical_kv_idx: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| (is_valid, logical_q_idx, | ||
| logical_kv_idx) = self._convert_physical_to_logical( | ||
| self.doc_ids, q_idx, physical_kv_idx) | ||
| # Apply mask modification only for valid indices | ||
| return torch.where( | ||
| is_valid, | ||
| prefixlm_mask_mod(b, h, logical_q_idx, logical_kv_idx, self.prefix_len), | ||
| False, | ||
| ) | ||
|
|
||
| return final_mask_mod | ||
|
Comment on lines
+399
to
+430
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: | ||
| """Creates the transformed score_mod function for FlexAttention. | ||
|
|
||
|
|
@@ -469,7 +521,10 @@ def _build_block_mask_direct(self) -> BlockMask: | |
| return BlockMask.from_kv_blocks(**block_mask_kwargs) | ||
|
|
||
| def build_block_mask(self) -> BlockMask: | ||
| if self.causal: | ||
| if self.prefixlm: | ||
| mask_mod = self.get_prefixlm_mask_mod() | ||
| kv_len = self.total_cache_tokens | ||
| elif self.causal: | ||
| mask_mod = self.get_causal_mask_mod() | ||
| kv_len = self.total_cache_tokens | ||
|
Comment on lines
+524
to
529
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for if self.prefixlm or self.causal:
if self.prefixlm:
mask_mod = self.get_prefixlm_mask_mod()
else:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens |
||
| else: | ||
|
|
@@ -495,14 +550,16 @@ def __post_init__(self): | |
| self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) | ||
| self.num_blocks = self.total_cache_tokens // self.block_size | ||
|
|
||
| if self.causal: | ||
| if self.prefixlm: | ||
| self.mask_mod = self.get_prefixlm_mask_mod() | ||
| elif self.causal: | ||
| self.mask_mod = self.get_causal_mask_mod() | ||
| else: | ||
| self.mask_mod = self.get_bidirectional_mask_mod() | ||
|
|
||
| self.transformed_score_mod = self.get_transformed_score_mod() | ||
|
|
||
| if self.direct_build and self.causal: | ||
| if self.direct_build and (self.causal or self.prefixlm): | ||
| self.block_mask = self._build_block_mask_direct() | ||
| else: | ||
| self.block_mask = self.build_block_mask() | ||
|
|
@@ -538,7 +595,9 @@ def reorder_batch(self, input_batch: "InputBatch", | |
| def build(self, | ||
| common_prefix_len: int, | ||
| common_attn_metadata: CommonAttentionMetadata, | ||
| fast_build: bool = False) -> FlexAttentionMetadata: | ||
| fast_build: bool = False, | ||
| prefixlm: bool = False, | ||
| prefix_len: int = 0) -> FlexAttentionMetadata: | ||
| num_reqs = common_attn_metadata.num_reqs | ||
| num_actual_tokens = common_attn_metadata.num_actual_tokens | ||
| max_query_len = common_attn_metadata.max_query_len | ||
|
|
@@ -585,6 +644,8 @@ def build(self, | |
| cu_prefix_query_lens=cu_prefix_query_lens, | ||
| prefix_kv_lens=prefix_kv_lens, | ||
| suffix_kv_lens=suffix_kv_lens, | ||
| prefixlm=prefixlm, | ||
| prefix_len=prefix_len, | ||
| block_size=block_size, | ||
| max_possible_sequence_length=max_possible_seq_len, | ||
| num_reqs=num_reqs, | ||
|
|
@@ -708,7 +769,7 @@ def forward( | |
|
|
||
| num_actual_tokens = attn_metadata.num_actual_tokens | ||
|
|
||
| if not attn_metadata.causal: | ||
| if not attn_metadata.causal and not attn_metadata.prefixlm: | ||
| assert self.attn_type == AttentionType.ENCODER_ONLY | ||
|
|
||
| query, key_tensor, value_tensor = map( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The boolean expression for the mask can be simplified. The
(q_idx >= prefix_len)check in the second part of the|operation is redundant. If the first part(q_idx < prefix_len)is false, thenq_idx >= prefix_lenis implicitly true. Removing this redundant check will make the code slightly more efficient and easier to read.