|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "../device.hpp" |
| 4 | +#include "common/op.hpp" |
| 5 | +#include <optional> |
| 6 | + |
| 7 | +namespace infinicore::op { |
| 8 | + |
| 9 | +// Variable-length (varlen) InfLLM-V2 attention over unpadded Q/K/V. Only the declaration appears in this header. |
| 10 | +// |
| 11 | +// Shapes follow the FlashAttn-style varlen convention: |
| 12 | +// q : [total_q, nheads, head_dim] |
| 13 | +// k, v : [total_k, nheads_k, head_dim] |
| 14 | +// cu_seqlens_q: [batch_size + 1] (int32) -- presumably cumulative per-sequence offsets into the total_q axis; TODO confirm |
| 15 | +// cu_seqlens_k: [batch_size + 1] (int32) -- presumably cumulative per-sequence offsets into the total_k axis; TODO confirm |
| 16 | +// NOTE(review): max_seqlen_q/k, scale, causal and window_size_left/right are not documented here; window sizes default to -1 (presumably "no window limit", as in FlashAttention -- confirm). |
| 17 | +// Returns: |
| 18 | +// [total_q, nheads, head_dim] |
| 19 | +Tensor infllmv2_varlen(const Tensor &q, |
| 20 | + const Tensor &k, |
| 21 | + const Tensor &v, |
| 22 | + const Tensor &cu_seqlens_q, |
| 23 | + const Tensor &cu_seqlens_k, |
| 24 | + int max_seqlen_q, |
| 25 | + int max_seqlen_k, |
| 26 | + float scale, |
| 27 | + bool causal, |
| 28 | + int window_size_left = -1, |
| 29 | + int window_size_right = -1); |
| 30 | + |
| 31 | +// Decode-time InfLLM-V2 attention with KV cache. Only the declaration appears in this header. |
| 32 | +// scale, causal and window_size_left/right are not documented here; the window sizes default to -1. |
| 33 | +// Shapes: |
| 34 | +// q : [batch, seqlen_q, nheads, head_dim] |
| 35 | +// k_cache : [num_blocks, block_size, nheads_k, head_dim] or [batch, seqlen_cache, nheads_k, head_dim] |
| 36 | +// v_cache : same as k_cache |
| 37 | +// cache_lens : [batch] (int32) total KV length per sequence |
| 38 | +// NOTE(review): the [num_blocks, block_size, ...] layout is listed, but this signature takes no block-table argument -- confirm how blocks are mapped to sequences. |
| 39 | +// Returns: |
| 40 | +// [batch, seqlen_q, nheads, head_dim] |
| 41 | +Tensor infllmv2_kvcache(const Tensor &q, |
| 42 | + const Tensor &k_cache, |
| 43 | + const Tensor &v_cache, |
| 44 | + const Tensor &cache_lens, |
| 45 | + float scale, |
| 46 | + bool causal, |
| 47 | + int window_size_left = -1, |
| 48 | + int window_size_right = -1); |
| 49 | + |
| 50 | +// Decode-time InfLLM-V2 attention with KV cache, updating cache in-place. |
| 51 | +// NOTE(review): k_cache/v_cache are taken as const Tensor& although the cache is described as updated in place -- presumably Tensor is a handle to shared storage; confirm. |
| 52 | +// Shapes: |
| 53 | +// q : [batch, seqlen_q, nheads, head_dim] |
| 54 | +// k_cache : [batch, seqlen_cache, nheads_k, head_dim] (dense cache) |
| 55 | +// v_cache : same as k_cache |
| 56 | +// k_new/v_new: [batch, seqlen_new, nheads_k, head_dim] (new KV to append at cache_lens offsets) |
| 57 | +// cache_lens : [batch] (int32) current KV length per sequence BEFORE appending |
| 58 | +// scale, causal and window_size_left/right are not documented in this header; the window sizes default to -1. |
| 59 | +// Returns: |
| 60 | +// [batch, seqlen_q, nheads, head_dim] |
| 61 | +Tensor infllmv2_kvcache_update(const Tensor &q, |
| 62 | + const Tensor &k_cache, |
| 63 | + const Tensor &v_cache, |
| 64 | + const Tensor &k_new, |
| 65 | + const Tensor &v_new, |
| 66 | + const Tensor &cache_lens, |
| 67 | + float scale, |
| 68 | + bool causal, |
| 69 | + int window_size_left = -1, |
| 70 | + int window_size_right = -1); |
| 71 | + |
| 72 | +} // namespace infinicore::op |
| 73 | + |
0 commit comments