Skip to content

Commit f83cfb9

Browse files
author
zhangyue
committed
docs(base): add vLLM interface references to FlashAttention, ReshapeAndCache, RotaryEmbedding
1 parent 1654d1d commit f83cfb9

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

src/base/flash_attention.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99

1010
namespace infini::ops {
1111

12+
// Fused multi-head / grouped-query attention.
13+
//
14+
// Interface follows vLLM v1 `AttentionImpl.forward()`:
15+
// `vllm.v1.attention.backends.abstract.AttentionImpl`
16+
//
17+
// Layout: `query` / `key` / `value` are `[T, N, D]` (TND).
18+
// Prefill uses `cu_seqlens_q` / `cu_seqlens_kv` for variable-length packing.
19+
// Decode uses `block_table` for paged KV cache lookup.
1220
class FlashAttention : public Operator<FlashAttention> {
1321
public:
1422
FlashAttention(const Tensor query, const Tensor key, const Tensor value,

src/base/reshape_and_cache.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88

99
namespace infini::ops {
1010

11+
// Scatter `key` / `value` tokens into a paged KV cache.
12+
//
13+
// Interface follows vLLM's `reshape_and_cache` kernel:
14+
// `vllm._custom_ops.reshape_and_cache_flash`
15+
//
16+
// `kv_cache` layout: `[2, num_blocks, block_size, num_kv_heads, head_size]`.
17+
// `slot_mapping`: 1D `[num_tokens]`, each entry is the linear slot index
18+
// into the cache. Padding tokens must be filtered by the caller (no
19+
// negative indices).
1120
class ReshapeAndCache : public Operator<ReshapeAndCache> {
1221
public:
1322
ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache,

src/base/rotary_embedding.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88

99
namespace infini::ops {
1010

11+
// Rotary position embedding (RoPE) applied in-place to Q and K.
12+
//
13+
// Interface follows vLLM's `RotaryEmbedding.forward_oot()`:
14+
// `vllm.model_executor.layers.rotary_embedding.RotaryEmbedding`
15+
//
16+
// `positions`: `[T]` token position indices.
17+
// `cos_sin_cache`: precomputed `[max_seq_len, rotary_dim]` table.
18+
// `query` / `key`: `[T, N, D]` (TND layout); results are written to
19+
// `query_out` / `key_out`, which may alias the inputs for in-place update.
1120
class RotaryEmbedding : public Operator<RotaryEmbedding> {
1221
public:
1322
RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key,

0 commit comments

Comments (0)