Why FlashAttention only on encoder?  #7

@karioth

Hi, thank you for providing this FlashAttention implementation of T5. I am wondering, however, why the code is set up so that the attention variants are used only in the encoder and not in the decoder. See the relevant lines below:

class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        if config.is_decoder:
            # decoder always uses T5Attention
            self.SelfAttention = T5Attention(config, ...)
        else:
            # encoder uses one of {T5FlashAttention, T5TritonBasicAttention, etc.}
            self.SelfAttention = T5ATTENTION_TYPES[config.attention_type](config, ...)
        ...

I am planning to use only the decoder in an architecture, so I am curious why the attention variants were not integrated there.
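
To make the question concrete, here is a minimal, self-contained sketch of the dispatch behavior described above. It is not the repository's actual code: the classes are empty stand-ins, and the `Config` dataclass, the `"flash"` registry key, and the `select_self_attention` helper are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass


class T5Attention:
    """Stand-in for the baseline attention module (no real attention math)."""

    def __init__(self, config, has_relative_attention_bias=False):
        self.config = config
        self.has_relative_attention_bias = has_relative_attention_bias


class T5FlashAttention(T5Attention):
    """Stand-in for one of the alternative attention variants."""


# Hypothetical registry key; the real keys are not shown in this issue.
T5ATTENTION_TYPES = {"flash": T5FlashAttention}


@dataclass
class Config:
    """Hypothetical config holding only the two fields the dispatch reads."""

    is_decoder: bool
    attention_type: str = "flash"


def select_self_attention(config, has_relative_attention_bias=False):
    """Mirror the branch quoted above: decoders bypass the registry."""
    if config.is_decoder:
        # The decoder always gets the baseline module, whatever attention_type says.
        return T5Attention(config, has_relative_attention_bias)
    # The encoder looks up the configured variant in the registry.
    return T5ATTENTION_TYPES[config.attention_type](config, has_relative_attention_bias)


encoder_attn = select_self_attention(Config(is_decoder=False))
decoder_attn = select_self_attention(Config(is_decoder=True))
print(type(encoder_attn).__name__)  # T5FlashAttention
print(type(decoder_attn).__name__)  # T5Attention
```

With both configs set to `attention_type="flash"`, only the encoder ends up with the variant, which is the behavior I am asking about.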
