Commit b9b1fdb

author
Zhuoming Chen
committed
comments and docs
1 parent af697ef commit b9b1fdb

5 files changed

Lines changed: 170 additions & 640 deletions

docs/_static/custom.css

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,20 @@
1+
/* Class signature: compact (smaller, not oversized). */
12
dl.py.class > dt.sig {
2-
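font-size: 1.4em; /* enlarge the font */ /* light blue/gray background; adjust the color value as needed */
3-
padding: 0.4.rem 0.8rem; /* add a little padding so the block looks more like a card */
4-
border-radius: 0.4rem; /* rounded corners */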
3+
font-size: 0.9em;
4+
padding: 0.3rem 0.6rem;
5+
border-radius: 0.4rem;
56
}
67

7-
88
dl.py.class .sig-name {
9-
font-weight: 900;
10-
font-size: 1.2em;
9+
font-weight: 700;
10+
font-size: 1em;
1111
}
1212

13-
/* Adjust the overall font size of class method headings (def xxx) */
13+
/* Method signatures: a touch smaller than body text. */
1414
dl.py.method > dt.sig {
15-
font-size: 1.2em !important; /* change to the size you want */
15+
font-size: 0.85em;
1616
}
1717

18-
/* Further enlarge the method name (def xxx) on its own */
1918
dl.py.method > dt.sig .sig-name {
20-
font-size: 1.1em !important;
21-
font-weight: 700; /* make the method name stand out more; optional */
19+
font-weight: 600;
2220
}
23-

vortex_torch/flow/algorithms.py

Lines changed: 37 additions & 36 deletions
@@ -102,11 +102,11 @@ class GQABlockSparseAttention(vFlow):
102102
r"""
103103
Grouped-query block-sparse routing with a **softmax over pages**.
104104
105-
Like :class:`BlockSparseAttention` it scores pages by query–centroid
106-
similarity, but keeps every grouped-query head separate: each head turns
107-
its per-page scores into a softmax distribution over pages, and a page's
108-
final score is the **max** of that probability across heads. (Design akin
109-
to the GQA sparse-attention formulation in arXiv:2502.11089.)
105+
Each page keeps a centroid (the mean of its keys). Every grouped-query
106+
head is scored against the centroids separately; each head's per-page
107+
scores are turned into a softmax distribution over pages, and a page's
108+
final score is the **max** of that probability across heads
109+
(cf. the GQA sparse-attention formulation in arXiv:2502.11089).
110110
111111
**Cache.** Per-page centroid :math:`c_p = \frac{1}{|p|}\sum_{k\in p} k`
112112
via :class:`CMean`.
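The routing rule in the GQABlockSparseAttention docstring above (per-head softmax over pages, then a max across heads) can be illustrated with a small plain-Python sketch. This is not the vortex_torch implementation: the helper names (`dot`, `centroid`, `softmax`, `gqa_page_scores`) are hypothetical, and the sketch assumes unscaled dot-product scores and nested lists in place of tensors.

```python
import math

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def centroid(keys):
    """Per-page centroid: the coordinate-wise mean of the page's keys."""
    n = len(keys)
    return [sum(col) / n for col in zip(*keys)]

def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def gqa_page_scores(head_queries, pages):
    """Final page score = max over heads of that head's softmax probability.

    head_queries: one query vector per grouped-query head (assumed layout).
    pages: list of pages, each a list of key vectors.
    """
    centroids = [centroid(keys) for keys in pages]
    per_head = [softmax([dot(q, c) for c in centroids]) for q in head_queries]
    return [max(probs[p] for probs in per_head) for p in range(len(pages))]

# Two heads that each prefer a different page: both pages end up with a high score.
scores = gqa_page_scores(
    [[1.0, 0.0], [0.0, 1.0]],
    [[[2.0, 0.0], [4.0, 0.0]], [[0.0, 3.0]]],
)
```

Taking the max across heads means a page survives if any single head assigns it high probability, even when the other heads ignore it.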
@@ -265,11 +265,12 @@ class LServeSparseAttention(vFlow):
265265
r"""
266266
LSERVE: QUEST envelopes at **sub-block** granularity.
267267
268-
Sharpens :class:`GQAQuestSparseAttention` by splitting each page into
269-
consecutive sub-blocks of :attr:`LSERVE_BLOCK_SIZE` tokens and keeping a
270-
separate max/min envelope **per sub-block**. A page is ranked by its
271-
single best-matching (head, sub-block) pair, so one relevant sub-region is
272-
enough to select the page — a tighter bound than one envelope per page.
268+
Each page is split into consecutive sub-blocks of :attr:`LSERVE_BLOCK_SIZE`
269+
tokens, and a coordinate-wise max/min key envelope is kept **per
270+
sub-block**. The envelopes give a cheap upper bound on the query–key dot
271+
product within a sub-block; a page is ranked by its single best-matching
272+
(head, sub-block) pair, so one relevant sub-region is enough to select the
273+
page.
273274
274275
**Cache.** :meth:`forward_cache` stores, for each of the
275276
:math:`n_b = \text{block\_size} / \text{LSERVE\_BLOCK\_SIZE}` sub-blocks
@@ -294,8 +295,8 @@ class LServeSparseAttention(vFlow):
294295
295296
**Shapes.** ``q`` is ``[B, H_q, D]``; ``cache["max"]`` / ``cache["min"]``
296297
are ``[S, n_b, D]`` (indexer) / ``[B, n_b, D]`` (cache). With
297-
``block_size == LSERVE_BLOCK_SIZE`` (:math:`n_b = 1`) this reduces to
298-
:class:`GQAQuestSparseAttention`.
298+
``block_size == LSERVE_BLOCK_SIZE`` there is one sub-block per page
299+
(:math:`n_b = 1`), i.e. a single envelope over the whole page.
299300
"""
300301
LSERVE_BLOCK_SIZE = 16
301302
def __init__(self):
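The sub-block envelope ranking described in the LServeSparseAttention docstring can be sketched in plain Python as follows. The function names are hypothetical and the bound used is the usual QUEST-style per-coordinate bound (sum over features of `max(q_d * M_d, q_d * m_d)`), which is assumed here rather than taken from the vortex_torch source.

```python
def envelope(keys):
    """Coordinate-wise (max, min) over a group of key vectors."""
    cols = list(zip(*keys))
    return [max(c) for c in cols], [min(c) for c in cols]

def quest_bound(q, hi, lo):
    """Upper bound on q . k for every key k inside the envelope [lo, hi]."""
    return sum(max(qd * h, qd * l) for qd, h, l in zip(q, hi, lo))

def lserve_page_score(head_queries, page_keys, sub_block_size):
    """Rank a page by its best (head, sub-block) envelope bound."""
    blocks = [
        page_keys[i:i + sub_block_size]
        for i in range(0, len(page_keys), sub_block_size)
    ]
    envelopes = [envelope(block) for block in blocks]
    return max(
        quest_bound(q, hi, lo) for q in head_queries for hi, lo in envelopes
    )

page = [[1.0, -1.0], [2.0, 0.0], [-3.0, 4.0], [0.0, 1.0]]
query = [1.0, 1.0]
fine = lserve_page_score([query], page, sub_block_size=2)    # two envelopes
coarse = lserve_page_score([query], page, sub_block_size=4)  # one envelope
true_max = max(sum(a * b for a, b in zip(query, k)) for k in page)
```

In this example the two-sub-block score is 4.0 and the single-envelope score is 6.0, while the largest real dot product is 2.0: both are valid upper bounds, and splitting the page can only make the bound tighter or leave it unchanged.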
@@ -354,13 +355,12 @@ def create_cache(self, block_size: int, head_dim: int):
354355
@register("lserve_centroid_sparse_attention")
355356
class LServeCentroidSparseAttention(vFlow):
356357
r"""
357-
Centroid routing at LSERVE **sub-block** granularity.
358+
Centroid routing at sub-block granularity.
358359
359-
Combines :class:`BlockSparseAttention`'s centroid routing with the
360-
sub-block idea of :class:`LServeSparseAttention`: each page is split into
361-
consecutive sub-blocks of :attr:`SUB_BLOCK_SIZE` tokens, a centroid is
362-
kept **per sub-block**, and a page is ranked by its best-matching
363-
sub-block — finer than collapsing the whole page into one centroid.
360+
Each page is split into consecutive sub-blocks of :attr:`SUB_BLOCK_SIZE`
361+
tokens, and a centroid (the mean of its keys) is kept **per sub-block**. A
362+
page is ranked by the query's best match against any of its sub-block
363+
centroids, so one relevant sub-region is enough to select the page.
364364
365365
**Cache.** :meth:`forward_cache` stores, for each of the
366366
:math:`n_b = \text{block\_size} / \text{SUB\_BLOCK\_SIZE}` sub-blocks
@@ -381,8 +381,8 @@ class LServeCentroidSparseAttention(vFlow):
381381
382382
**Shapes.** ``q`` is ``[B, H_q, D]``; ``cache["centroids"]`` is
383383
``[S, n_b, D]`` (indexer) / ``[B, n_b, D]`` (cache). With
384-
``block_size == SUB_BLOCK_SIZE`` (:math:`n_b = 1`) this reduces to
385-
:class:`BlockSparseAttention`.
384+
``block_size == SUB_BLOCK_SIZE`` there is one sub-block per page
385+
(:math:`n_b = 1`), i.e. a single centroid over the whole page.
386386
"""
387387
SUB_BLOCK_SIZE = 16
388388

@@ -437,12 +437,12 @@ class MaskedQuestSparseAttention(vFlow):
437437
r"""
438438
QUEST routing with a feature-axis mask that drops low-signal channels.
439439
440-
Identical to :class:`GQAQuestSparseAttention`, but a :class:`MaskSlice`
441-
zeroes the leading ``MASK_END`` feature coordinates of the QUEST bound
442-
before the feature sum — a cheap, position-only way to exclude
443-
low-signal channels (e.g. large-magnitude "sink" dimensions). Since
444-
:class:`MaskSlice` is a pure position writer, no extra state is threaded
445-
through ``ctx``.
440+
Each page keeps a coordinate-wise max/min key envelope; their combination
441+
upper-bounds the largest query–key dot product in the page. Before summing
442+
over features, a :class:`MaskSlice` zeroes the leading ``MASK_END`` feature
443+
coordinates of that bound — a cheap, position-only way to exclude
444+
low-signal channels (e.g. large-magnitude "sink" dimensions). The mask is a
445+
pure position writer, so no extra state is threaded through ``ctx``.
446446
447447
**Cache.** Per-page key envelopes :math:`M_p = \max_{k\in p} k` and
448448
:math:`m_p = \min_{k\in p} k` via :class:`CMax` / :class:`CMin`.
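As a rough illustration of the masking step in the MaskedQuestSparseAttention docstring, the sketch below zeroes the leading `mask_end` per-feature terms of the envelope bound before summing. The function names are hypothetical, and the per-feature term `max(q_d * M_d, q_d * m_d)` is the assumed form of the QUEST bound.

```python
def quest_bound_terms(q, hi, lo):
    """Per-feature terms of the envelope upper bound on q . k."""
    return [max(qd * h, qd * l) for qd, h, l in zip(q, hi, lo)]

def masked_quest_score(q, hi, lo, mask_end):
    """Zero the first mask_end feature terms, then sum over features."""
    terms = quest_bound_terms(q, hi, lo)
    masked = [0.0] * mask_end + terms[mask_end:]
    return sum(masked)

# Feature 0 plays the role of a large-magnitude "sink" channel.
q = [10.0, 1.0, 1.0]
hi = [100.0, 2.0, 3.0]
lo = [-100.0, -1.0, 0.0]
unmasked = masked_quest_score(q, hi, lo, mask_end=0)
masked = masked_quest_score(q, hi, lo, mask_end=1)
```

Without the mask, the sink channel contributes 1000.0 of the 1005.0 total and would dominate every page's score; with `mask_end=1` the score is 5.0 and reflects only the remaining channels.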
@@ -523,9 +523,9 @@ class CenteredBlockSparseAttention(vFlow):
523523
r"""
524524
Centroid block-sparse routing with per-request **mean-centering**.
525525
526-
Scores pages by query–centroid similarity like
527-
:class:`BlockSparseAttention`, then subtracts the per-request mean score
528-
across pages before selection — so a page competes by how far *above
526+
Each page keeps a centroid (the mean of its keys); pages are scored by
527+
query–centroid similarity, and the per-request mean score across pages is
528+
subtracted before selection — so a page competes by how far *above
529529
average* it is, not by raw similarity.
530530
531531
**Cache.** Per-page centroid :math:`c_p = \frac{1}{|p|}\sum_{k\in p} k`
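The mean-centering step in the CenteredBlockSparseAttention docstring amounts to subtracting one scalar per request. A minimal sketch with hypothetical helper names, assuming scores for a single request are held in a flat list:

```python
def center_scores(scores):
    """Subtract the per-request mean so scores read as 'above/below average'."""
    mean = sum(scores) / len(scores)
    return [s - mean for s in scores]

def top_k(scores, k):
    """Indices of the k highest-scoring pages, best first."""
    order = sorted(range(len(scores)), key=lambda p: scores[p], reverse=True)
    return order[:k]

raw = [4.0, 1.0, 7.0, 0.0]      # query-centroid scores for one request
centered = center_scores(raw)   # [1.0, -2.0, 4.0, -3.0]
```

Subtracting a constant leaves the ordering of pages within one request unchanged, so `top_k(raw, 2)` and `top_k(centered, 2)` agree here. The centered values differ from the raw ones only where scores are compared on a common scale, for example against a threshold or across requests; the source does not say which of these the selection step relies on.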
@@ -593,9 +593,10 @@ class RunningAvgBlockSparse(vFlow):
593593
Centroid block-sparse routing with a per-page **running score**
594594
(a :class:`Save` / :class:`Load` demo).
595595
596-
Like :class:`BlockSparseAttention`, but instead of scoring on the current
597-
step alone it keeps an exponentially-decayed running score per page across
598-
decode steps: pages that stay relevant accumulate, pages that fade decay.
596+
Each page keeps a centroid (the mean of its keys). Instead of scoring on
597+
the current step alone, the per-page query–centroid score is accumulated
598+
into an exponentially-decayed running score across decode steps: pages that
599+
stay relevant accumulate, pages that fade decay.
599600
600601
**Cache.** Per-page centroid :math:`c_p` via :class:`CMean`; the persistent
601602
``running_score`` is zero-initialised with :class:`CFill` when a page is
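The running score in the RunningAvgBlockSparse docstring can be pictured with the sketch below. The exact update rule is not shown in this hunk, so the form `running = decay * running + current` and the `decay` parameter are assumptions made for illustration, as is the function name.

```python
def update_running_scores(running, current, decay=0.5):
    """One decode step: decay the stored scores, then add the new ones.

    Assumed update rule; the real op may weight the terms differently.
    """
    return [decay * r + s for r, s in zip(running, current)]

running = [0.0, 0.0]                      # zero-initialised, one slot per page
for step_scores in ([1.0, 0.0], [1.0, 0.0], [0.0, 1.0]):
    running = update_running_scores(running, step_scores)
```

After the three steps, page 0 holds 0.75 (relevant twice, then fading) and page 1 holds 1.0 (relevant only in the latest step), which matches the "accumulate while relevant, decay when not" behaviour the docstring describes.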
@@ -675,10 +676,10 @@ class VEnergyGatedCentroid(vFlow):
675676
r"""
676677
Centroid routing **gated by value-block energy**.
677678
678-
Scores a page by the query–centroid dot product like
679-
:class:`BlockSparseAttention`, then multiplies by the page's mean value
680-
magnitude (its "energy"): pages whose values carry little energy are muted
681-
even when the key centroid aligns with the query.
679+
Each page keeps a key centroid (the mean of its keys); a page is scored by
680+
the query–centroid dot product multiplied by the page's mean value
681+
magnitude (its "energy"), so pages whose values carry little energy are
682+
muted even when the key centroid aligns with the query.
682683
683684
**Cache.** :meth:`forward_cache` stores a per-page key centroid
684685
:math:`c_p` (:class:`CMean`) and the value energy — the mean :math:`L_2`

vortex_torch/indexer/matmul.py

Lines changed: 28 additions & 163 deletions
@@ -7,69 +7,17 @@
77

88
class GeMV(vOp):
99
r"""
10-
General matrix-vector multiplication (GEMV) dispatcher.
11-
12-
This operator computes a *piecewise* batched matrix-vector product.
13-
Let
14-
15-
.. math::
16-
17-
X \in \mathbb{R}^{B \times 1 \times D}, \qquad
18-
Y \in \mathbb{R}^{S_{\text{pack}} \times 1 \times D},
19-
20-
where the ``S``-axis of :math:`Y` is a concatenation of batch-wise
21-
segments
22-
23-
.. math::
24-
25-
S_{\text{pack}} = \sum_{i=0}^{B-1} S_i, \qquad
26-
Y =
27-
\begin{bmatrix}
28-
Y_0 \\
29-
Y_1 \\
30-
\vdots \\
31-
Y_{B-1}
32-
\end{bmatrix},
33-
34-
with
35-
36-
.. math::
37-
38-
Y_i \in \mathbb{R}^{S_i \times 1 \times D}, \qquad
39-
X_i = X[i, 0, :] \in \mathbb{R}^{1 \times D}.
40-
41-
For each batch index :math:`i \in \{0,\dots,B-1\}`, we define
42-
43-
.. math::
44-
45-
O_i = Y_i X_i^{\mathsf{T}} \in \mathbb{R}^{S_i \times 1 \times 1},
46-
47-
and the overall output is the concatenation
48-
49-
.. math::
50-
51-
O =
52-
\begin{bmatrix}
53-
O_0 \\
54-
O_1 \\
55-
\vdots \\
56-
O_{B-1}
57-
\end{bmatrix}
58-
\in \mathbb{R}^{S_{\text{pack}} \times 1 \times 1}.
59-
60-
In the runtime, :math:`S_{\text{pack}}` is given by
61-
``ctx.max_num_pages``. Output format rule: ``BATCHED`` iff both
62-
inputs are ``BATCHED`` (both have their ``S`` axis already collapsed
63-
to 1), otherwise ``RAGGED``. Format compatibility is enforced by
64-
the compiler's per-workload kernel.
65-
66-
Attributes
67-
----------
68-
output_format : Optional[FORMAT]
69-
The output tensor format as determined in :meth:`profile`.
70-
71-
output_buffer : Optional[torch.Tensor]
72-
Preallocated output tensor buffer of shape ``[S_pack, 1, 1]``.
10+
Per-request batched matrix–vector product, :math:`O = Y X^{\top}`.
11+
12+
:Math: with a batched query :math:`X\in\mathbb{R}^{B\times 1\times D}` and
13+
packed pages :math:`Y\in\mathbb{R}^{S\times 1\times D}`, each page
14+
:math:`s` of request :math:`i` scores as
15+
:math:`O[s]=\langle Y[s],\,X[i]\rangle`, giving
16+
:math:`O\in\mathbb{R}^{S\times 1\times 1}`.
17+
:__init__: ``GeMV()`` — no arguments.
18+
:__call__: ``o = op(x, y, ctx=ctx)`` — ``x`` is ``[B, 1, D]``, ``y`` is
19+
``[S, 1, D]`` (matching ``D``); returns ``o`` ``[S, 1, 1]``. Output is
20+
``BATCHED`` iff both inputs are, else ``RAGGED``.
7321
"""
7422

7523
def __init__(self):
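The shortened GeMV docstring relies on the reader knowing which request each packed page belongs to. The reference sketch below spells that out in plain Python. It is not the dispatcher itself: `gemv_reference` and the `pages_per_request` argument are hypothetical, standing in for the per-request segment sizes (the S_i of the removed docstring) that the runtime tracks.

```python
def gemv_reference(x, y, pages_per_request):
    """Reference semantics for O[s] = <Y[s], X[i]> on packed pages.

    x: [B][1][D] nested lists, one query row per request.
    y: [S][1][D] nested lists, pages of all requests concatenated.
    pages_per_request: hypothetical list of B segment lengths summing to S.
    Returns a [S][1][1] nested list.
    """
    assert len(pages_per_request) == len(x)
    assert sum(pages_per_request) == len(y)
    out = []
    s = 0
    for i, n_pages in enumerate(pages_per_request):
        q = x[i][0]
        for _ in range(n_pages):
            k = y[s][0]
            assert len(k) == len(q), "last dimension D must match"
            out.append([[sum(a * b for a, b in zip(k, q))]])
            s += 1
    return out

x = [[[1, 2]], [[0, 1]]]              # B = 2, D = 2
y = [[[1, 0]], [[0, 1]], [[3, 4]]]    # S = 3: two pages for request 0, one for request 1
o = gemv_reference(x, y, pages_per_request=[2, 1])
```

Here `o` is `[[[1]], [[2]], [[4]]]`: the first two pages are scored against request 0's query and the last page against request 1's.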
@@ -79,19 +27,9 @@ def __init__(self):
7927
self.schedule = Schedule.W
8028
# ---------------- profile ----------------
8129
def profile(self, x: vTensor, y: vTensor, ctx: Context) -> vTensor:
82-
r"""
83-
Validate inputs, allocate the output buffer, and return a
84-
:class:`vTensor` view.
85-
86-
The method enforces the logical shapes
87-
88-
- ``x``: ``[B, 1, D]``
89-
- ``y``: ``[S_pack, 1, D]``
90-
91-
and checks that the last dimensions match. The output buffer is
92-
allocated with shape ``[S_pack, 1, 1]``, where ``S_pack`` is taken
93-
from the runtime context as ``ctx.max_num_pages``.
94-
"""
30+
r"""Trace-time: validate ``x`` ``[B, 1, D]`` / ``y`` ``[S, 1, D]``,
31+
register the op, and return a ``vTensor`` view of the ``[S, 1, 1]``
32+
output (see the class docstring)."""
9533
prefix = self._prefix()
9634

9735
# Type checks
@@ -137,51 +75,17 @@ def profile(self, x: vTensor, y: vTensor, ctx: Context) -> vTensor:
13775
# ------------------------------ GeMM ------------------------------ #
13876
class GeMM(vOp):
13977
r"""
140-
General matrix-matrix multiplication (GeMM) dispatcher.
141-
142-
Logically this computes, for each logical ``S``-slice, a matrix-matrix
143-
product
144-
145-
.. math::
146-
147-
O[s] = Y[s] X[s]^{\mathsf{T}}, \quad s = 0, \dots, S-1,
148-
149-
with slice-wise shapes
150-
151-
.. math::
152-
153-
X[s] \in \mathbb{R}^{N_x \times K}, \quad
154-
Y[s] \in \mathbb{R}^{N_y \times K}, \quad
155-
O[s] \in \mathbb{R}^{N_y \times N_x}.
156-
157-
In the packed 3D representation used by this dispatcher:
158-
159-
- ``Y`` has logical shape ``[S, N_y, K]``.
160-
- ``X`` has logical shape ``[L_x, N_x, K]``, where the leading
161-
dimension :math:`L_x` can represent **either**:
162-
163-
* a batch axis :math:`B` (when ``x_format == FORMAT.BATCHED``), or
164-
* the same ``S`` axis as ``Y`` (when ``x_format`` is ragged/paged and
165-
already laid out per-page).
166-
167-
This is why the code comments use ``X: [B/S, N_x, K]``: the first
168-
dimension is interpreted as either a batch size :math:`B` or an
169-
``S``-like logical page index, depending on the format.
170-
171-
- The output tensor ``O`` has logical shape ``[S, N_y, N_x]``.
172-
173-
At runtime, the logical ``S`` is taken from ``ctx.max_num_pages``.
174-
Output format rule: ``BATCHED`` iff both inputs are ``BATCHED``,
175-
otherwise ``RAGGED``. Format compatibility is enforced by the
176-
compiler's per-workload kernel.
177-
178-
Attributes
179-
----------
180-
output_format : Optional[FORMAT]
181-
The output tensor format as determined in :meth:`profile`.
182-
183-
output_buffer : Optional[torch.Tensor]
184-
Preallocated output tensor buffer of shape ``[S, N_y, N_x]``.
78+
Per-page matrix–matrix product, :math:`O[s] = Y[s]\,X[s]^{\top}`.
79+
80+
:Math: with :math:`Y\in\mathbb{R}^{S\times N_y\times K}` and
81+
:math:`X\in\mathbb{R}^{(B\,\text{or}\,S)\times N_x\times K}`, per page
82+
:math:`s`:
83+
:math:`O[s]=Y[s]\,X[s]^{\top}\in\mathbb{R}^{N_y\times N_x}` — i.e.
84+
``GeMM(x, y) = y xᵀ``. Output :math:`O\in\mathbb{R}^{S\times N_y\times N_x}`.
85+
:__init__: ``GeMM()`` — no arguments.
86+
:__call__: ``o = op(x, y, ctx=ctx)`` — ``x`` is ``[B|S, N_x, K]``, ``y`` is
87+
``[S, N_y, K]`` (matching ``K``); returns ``o`` ``[S, N_y, N_x]``.
88+
Output is ``BATCHED`` iff both inputs are, else ``RAGGED``.
18589
"""
18690

18791
def __init__(self):
@@ -192,48 +96,9 @@ def __init__(self):
19296

19397
# ---------------- profile ----------------
19498
def profile(self, x: vTensor, y: vTensor, ctx: Context) -> vTensor:
195-
r"""
196-
Validate inputs, allocate the output buffer, and return a
197-
:class:`vTensor` view.
198-
199-
The method enforces that both inputs are rank-3 tensors and that the
200-
inner dimension :math:`K` matches:
201-
202-
- ``x``: ``[B_or_S, N_x, K]``
203-
204-
*When* ``x_format == FORMAT.BATCHED``, the leading dimension is a
205-
batch size :math:`B`. For ragged/paged formats, it may conceptually
206-
coincide with :math:`S`.
207-
208-
- ``y``: ``[S, N_y, K]``
209-
210-
The output buffer is allocated with shape ``[S, N_y, N_x]``, where
211-
``S`` is taken from the runtime context as ``ctx.max_num_pages``.
212-
213-
Parameters
214-
----------
215-
x : vTensor
216-
Right-hand operand (transposed in the mathematical view), with
217-
shape ``[B_or_S, N_x, K]``.
218-
219-
y : vTensor
220-
Left-hand operand with shape ``[S, N_y, K]``.
221-
222-
ctx : Context
223-
Execution context providing ``ctx.max_num_pages`` for the logical
224-
``S`` dimension and tracking auxiliary memory.
225-
226-
Returns
227-
-------
228-
vTensor
229-
A ``vTensor`` view wrapping the allocated output buffer.
230-
231-
Raises
232-
------
233-
AssertionError
234-
If types are not ``vTensor``, ranks are not 3, or the inner
235-
dimensions :math:`K` do not match.
236-
"""
99+
r"""Trace-time: validate ``x`` ``[B|S, N_x, K]`` / ``y`` ``[S, N_y, K]``
100+
(matching ``K``), register the op, and return a ``vTensor`` view of the
101+
``[S, N_y, N_x]`` output (see the class docstring)."""
237102
prefix = self._prefix()
238103

239104
# Type checks
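For the GeMM docstring, the per-page product O[s] = Y[s] X[s]^T can be checked against a short plain-Python reference. The sketch covers only the case where `x` already has the same leading `S` axis as `y`; the batched (`B`-leading) case would need a page-to-request mapping that this hunk does not show. `gemm_reference` is a hypothetical name.

```python
def gemm_reference(x, y):
    """Reference semantics for O[s] = Y[s] X[s]^T, page by page.

    x: [S][N_x][K] nested lists, y: [S][N_y][K] nested lists.
    Returns a [S][N_y][N_x] nested list.
    """
    assert len(x) == len(y), "this sketch assumes x is laid out per page"
    out = []
    for x_s, y_s in zip(x, y):
        out.append([
            [sum(a * b for a, b in zip(y_row, x_row)) for x_row in x_s]
            for y_row in y_s
        ])
    return out

x = [[[1, 0], [0, 1], [1, 1]]]   # S = 1, N_x = 3, K = 2
y = [[[2, 3]]]                   # S = 1, N_y = 1, K = 2
o = gemm_reference(x, y)
```

The result is `[[[2, 3, 5]]]`, with shape `[S, N_y, N_x] = [1, 1, 3]`: rows index `y`, columns index `x`, matching the `y xᵀ` convention stated in the docstring.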
