|
| 1 | +#ifdef ENABLE_FLASH_ATTN |
| 2 | +#pragma once |
| 3 | +#include "aten_adaptor.hpp" |
| 4 | + |
| 5 | +namespace flash { |
| 6 | +std::vector<at::Tensor> |
| 7 | +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) |
| 8 | + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) |
| 9 | + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) |
| 10 | + std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) |
| 11 | + std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads |
| 12 | + const float p_dropout, |
| 13 | + const float softmax_scale, |
| 14 | + bool is_causal, |
| 15 | + int window_size_left, |
| 16 | + int window_size_right, |
| 17 | + const float softcap, |
| 18 | + const bool return_softmax, |
| 19 | + std::optional<at::Generator> gen_); |
| 20 | + |
| 21 | +std::vector<at::Tensor> |
| 22 | +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i |
| 23 | + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. |
| 24 | + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. |
| 25 | + std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i |
| 26 | + const at::Tensor &cu_seqlens_q, // b+1 |
| 27 | + const at::Tensor &cu_seqlens_k, // b+1 |
| 28 | + std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used. |
| 29 | + std::optional<const at::Tensor> &leftpad_k_, // batch_size |
| 30 | + std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq |
| 31 | + std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads |
| 32 | + int max_seqlen_q, |
| 33 | + const int max_seqlen_k, |
| 34 | + const float p_dropout, |
| 35 | + const float softmax_scale, |
| 36 | + const bool zero_tensors, |
| 37 | + bool is_causal, |
| 38 | + int window_size_left, |
| 39 | + int window_size_right, |
| 40 | + const float softcap, |
| 41 | + const bool return_softmax, |
| 42 | + std::optional<at::Generator> gen_); |
| 43 | + |
| 44 | +std::vector<at::Tensor> |
| 45 | +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) |
| 46 | + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size |
| 47 | + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size |
| 48 | + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size |
| 49 | + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size |
| 50 | + const at::Tensor &softmax_lse, // b x h x seqlen_q |
| 51 | + std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size |
| 52 | + std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size |
| 53 | + std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size |
| 54 | + std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads |
| 55 | + const float p_dropout, // probability to drop |
| 56 | + const float softmax_scale, |
| 57 | + const bool is_causal, |
| 58 | + int window_size_left, |
| 59 | + int window_size_right, |
| 60 | + const float softcap, |
| 61 | + const bool deterministic, |
| 62 | + std::optional<at::Generator> gen_, |
| 63 | + std::optional<at::Tensor> &rng_state); |
| 64 | + |
| 65 | +std::vector<at::Tensor> |
| 66 | +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size |
| 67 | + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i |
| 68 | + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i |
| 69 | + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i |
| 70 | + const at::Tensor &out, // total_q x num_heads x head_size |
| 71 | + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp |
| 72 | + std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i |
| 73 | + std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i |
| 74 | + std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i |
| 75 | + const at::Tensor &cu_seqlens_q, // b+1 |
| 76 | + const at::Tensor &cu_seqlens_k, // b+1 |
| 77 | + std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads |
| 78 | + const int max_seqlen_q, |
| 79 | + const int max_seqlen_k, // max sequence length to choose the kernel |
| 80 | + const float p_dropout, // probability to drop |
| 81 | + const float softmax_scale, |
| 82 | + const bool zero_tensors, |
| 83 | + const bool is_causal, |
| 84 | + int window_size_left, |
| 85 | + int window_size_right, |
| 86 | + const float softcap, |
| 87 | + const bool deterministic, |
| 88 | + std::optional<at::Generator> gen_, |
| 89 | + std::optional<at::Tensor> &rng_state); |
| 90 | + |
| 91 | +std::vector<at::Tensor> |
| 92 | +mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size |
| 93 | + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. |
| 94 | + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. |
| 95 | + std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size |
| 96 | + std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size |
| 97 | + std::optional<const at::Tensor> &seqlens_k_, // batch_size |
| 98 | + std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2) |
| 99 | + std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2) |
| 100 | + std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache |
| 101 | + std::optional<const at::Tensor> &leftpad_k_, // batch_size |
| 102 | + std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq |
| 103 | + std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads |
| 104 | + std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size |
| 105 | + const float softmax_scale, |
| 106 | + bool is_causal, |
| 107 | + int window_size_left, |
| 108 | + int window_size_right, |
| 109 | + const float softcap, |
| 110 | + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 |
| 111 | + int num_splits); |
| 112 | + |
| 113 | +} // namespace flash |
| 114 | +#endif // ENABLE_FLASH_ATTN |
0 commit comments