Skip to content

Commit 5d2aa76

Browse files
author
Zhuoming Chen
committed
comments and docs
1 parent b7dc024 commit 5d2aa76

2 files changed

Lines changed: 46 additions & 22 deletions

File tree

README.md

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,39 +76,51 @@ Note: Some operators are not yet fused or fully optimized, which may lead to inc
7676
## 🧩 Quick Example: Custom Sparse Attention
7777

7878
```python
79+
from typing import Dict
80+
import torch
81+
82+
from vortex_torch.flow import vFlow, register
83+
from vortex_torch.indexer import GeMM, Mean, topK
84+
from vortex_torch.cache import Mean as CMean
85+
from vortex_torch.abs import ContextBase
86+
87+
7988
@register("custom_sparse_attention")
8089
class CustomSparseAttention(vFlow):
8190

8291
def __init__(self):
8392
super().__init__()
84-
# Indexer-side ops
85-
self.gemv = GeMV()
86-
self.output_func = topK()
93+
# Indexer-side ops (run every decode step)
94+
self.mean = Mean(dim=1) # average over the query heads
95+
self.gemm = GeMM() # GeMM(x, y) = y @ xᵀ
96+
self.output_func = topK() # must end in topK / approxTopK
8797

88-
# Cache-side ops
89-
self.reduction = CMean(dim=1)
98+
# Cache-side ops (run once per finished page)
99+
self.reduction = CMean(dim=1) # one centroid (mean key) per page
90100

91101
def forward_indexer(
92102
self,
93-
q: torch.Tensor, # viewed as [1, H_q, D]
103+
q: torch.Tensor, # viewed as [B, H_q, D]
94104
o: torch.Tensor,
95-
cache: Dict[str, torch.Tensor], # viewed as [S, r, c] depending on create_cache()
105+
cache: Dict[str, torch.Tensor], # viewed as [S, r, c] per create_cache()
96106
ctx: ContextBase,
97107
):
98-
q_mean = q.mean(dim=1, keepdim=True)
99-
score = self.gemv(q_mean, cache["centroids"], ctx=ctx)
100-
self.output_func(score, o, ctx=ctx)
108+
# No native torch ops here — every tensor flows through vortex ops.
109+
q_mean = self.mean(q, ctx=ctx) # [B, 1, D]
110+
score = self.gemm(q_mean, cache["centroids"], ctx=ctx) # [S, 1, 1]
111+
self.output_func(score, o, ctx=ctx) # selected pages -> o
101112

102113
def forward_cache(
103114
self,
104-
cache: Dict[str, torch.Tensor], # viewed as [B, r, c] depending on create_cache()
115+
cache: Dict[str, torch.Tensor], # viewed as [B, r, c] per create_cache()
105116
loc: torch.Tensor,
106117
ctx: ContextBase,
107118
):
108119
# triggered only when a page is finished
109120
self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
110121

111-
def create_cache(self, page_size: int, head_dim: int):
122+
def create_cache(self, block_size: int, head_dim: int):
123+
# "k" and "v" are provided automatically — do not declare them
112124
return {
113125
"centroids": (1, head_dim),
114126
}
@@ -127,8 +139,8 @@ llm = sgl.Engine(
127139
disable_overlap_schedule=True, # Mandatory
128140
attention_backend="flashinfer", # Mandatory
129141
enable_vortex_sparsity=True, # Otherwise full attention is used
130-
vortex_page_reserved_bos=1,
131-
vortex_page_reserved_eos=1,
142+
vortex_block_reserved_bos=1,
143+
vortex_block_reserved_eos=1,
132144
vortex_layers_skip=list(range(1)), # Full attention for layer 0
133145
vortex_module_path="path/to/custom_sparse_attention.py",
134146
vortex_module_name="custom_sparse_attention", # the registered name for your algorithm

docs/index.rst

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,26 @@ Define a custom flow — centroid-based block-sparse routing in a dozen lines:
3838

3939
.. code-block:: python
4040
41+
from typing import Dict
42+
import torch
43+
44+
from vortex_torch.flow import vFlow, register
45+
from vortex_torch.indexer import GeMM, Mean, topK
46+
from vortex_torch.cache import Mean as CMean
47+
from vortex_torch.abs import ContextBase
48+
49+
4150
@register("custom_sparse_attention")
4251
class CustomSparseAttention(vFlow):
4352
4453
def __init__(self):
4554
super().__init__()
4655
# Indexer-side ops (run every decode step)
47-
self.gemv = GeMV()
48-
self.output_func = topK()
56+
self.mean = Mean(dim=1) # average over the query heads
57+
self.gemm = GeMM() # GeMM(x, y) = y @ xᵀ
58+
self.output_func = topK() # must end in topK / approxTopK
4959
# Cache-side ops (run once per finished page)
50-
self.reduction = CMean(dim=1)
60+
self.reduction = CMean(dim=1) # one centroid (mean key) per page
5161
5262
def forward_indexer(
5363
self,
@@ -56,9 +66,10 @@ Define a custom flow — centroid-based block-sparse routing in a dozen lines:
5666
cache: Dict[str, torch.Tensor], # viewed as [S, r, c] per create_cache()
5767
ctx: ContextBase,
5868
):
59-
q_mean = self.mean(q, ctx=ctx)
60-
score = self.gemv(q_mean, cache["centroids"], ctx=ctx)
61-
self.output_func(score, o, ctx=ctx) # must end in topK / approxTopK
69+
# No native torch ops here — every tensor flows through vortex ops.
70+
q_mean = self.mean(q, ctx=ctx) # [B, 1, D]
71+
score = self.gemm(q_mean, cache["centroids"], ctx=ctx) # [S, 1, 1]
72+
self.output_func(score, o, ctx=ctx) # selected pages -> o
6273
6374
def forward_cache(
6475
self,
@@ -69,7 +80,7 @@ Define a custom flow — centroid-based block-sparse routing in a dozen lines:
6980
# triggered only when a page is finished
7081
self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
7182
72-
def create_cache(self, page_size: int, head_dim: int):
83+
def create_cache(self, block_size: int, head_dim: int):
7384
# "k" and "v" are provided automatically — do not declare them
7485
return {"centroids": (1, head_dim)}
7586
@@ -80,7 +91,8 @@ Then run it through an SGLang engine:
8091
llm = sgl.Engine(
8192
model_path="Qwen/Qwen3-0.6B",
8293
page_size=16,
83-
attention_backend="flashinfer", # SGLang's base backend
94+
attention_backend="flashinfer", # mandatory: SGLang's base backend
95+
disable_overlap_schedule=True, # mandatory for vortex sparsity
8496
enable_vortex_sparsity=True, # otherwise computes full attention
8597
vortex_topk_val=30, # pages kept per request
8698
vortex_block_reserved_bos=1, # always-attended prefix blocks

0 commit comments

Comments
 (0)