@@ -76,39 +76,51 @@ Note: Some operators are not yet fused or fully optimized, which may lead to inc
7676## 🧩 Quick Example: Custom Sparse Attention
7777
7878``` python
79+ from typing import Dict
80+ import torch
81+
82+ from vortex_torch.flow import vFlow, register
83+ from vortex_torch.indexer import GeMM, Mean, topK
84+ from vortex_torch.cache import Mean as CMean
85+ from vortex_torch.abs import ContextBase
86+
87+
7988@register (" custom_sparse_attention" )
8089class CustomSparseAttention (vFlow ):
8190
8291 def __init__ (self ):
8392 super ().__init__ ()
84- # Indexer-side ops
85- self .gemv = GeMV()
86- self .output_func = topK()
93+ # Indexer-side ops (run every decode step)
94+ self .mean = Mean(dim = 1 ) # average over the query heads
95+ self .gemm = GeMM() # GeMM(x, y) = y @ xᵀ
96+ self .output_func = topK() # must end in topK / approxTopK
8797
88- # Cache-side ops
89- self .reduction = CMean(dim = 1 )
98+ # Cache-side ops (run once per finished page)
99+ self .reduction = CMean(dim = 1 ) # one centroid (mean key) per page
90100
91101 def forward_indexer (
92102 self ,
93- q : torch.Tensor, # viewed as [1 , H_q, D]
103+ q : torch.Tensor, # viewed as [B, H_q, D]
94104 o : torch.Tensor,
95- cache : Dict[str , torch.Tensor], # viewed as [S, r, c] depending on create_cache()
105+ cache : Dict[str , torch.Tensor], # viewed as [S, r, c] per create_cache()
96106 ctx : ContextBase,
97107 ):
98- q_mean = q.mean(dim = 1 , keepdim = True )
99- score = self .gemv(q_mean, cache[" centroids" ], ctx = ctx)
100- self .output_func(score, o, ctx = ctx)
108+ # No native torch ops here — every tensor flows through vortex ops.
109+ q_mean = self .mean(q, ctx = ctx) # [B, 1, D]
110+ score = self .gemm(q_mean, cache[" centroids" ], ctx = ctx) # [S, 1, 1]
111+ self .output_func(score, o, ctx = ctx) # selected pages -> o
101112
102113 def forward_cache (
103114 self ,
104- cache : Dict[str , torch.Tensor], # viewed as [B, r, c] depending on create_cache()
115+ cache : Dict[str , torch.Tensor], # viewed as [B, r, c] per create_cache()
105116 loc : torch.Tensor,
106117 ctx : ContextBase,
107118 ):
108119 # triggered only when a page is finished
109120 self .reduction(cache[" k" ], cache[" centroids" ], loc = loc, ctx = ctx)
110121
111- def create_cache (self , page_size : int , head_dim : int ):
122+ def create_cache (self , block_size : int , head_dim : int ):
123+ # "k" and "v" are provided automatically — do not declare them
112124 return {
113125 " centroids" : (1 , head_dim),
114126 }
@@ -127,8 +139,8 @@ llm = sgl.Engine(
127139 disable_overlap_schedule = True , # Mandatory
128140 attention_backend = " flashinfer" , # Mandatory
129141 enable_vortex_sparsity = True , # Otherwise full attention is used
130- vortex_page_reserved_bos = 1 ,
131- vortex_page_reserved_eos = 1 ,
142+ vortex_block_reserved_bos = 1 ,
143+ vortex_block_reserved_eos = 1 ,
132144 vortex_layers_skip = list (range (1 )), # Full attention for layer 0
133145 vortex_module_path = " path/to/custom_sparse_attention.py" ,
134146 vortex_module_name = " custom_sparse_attention" , # the registered name for your algorithm
0 commit comments