77
88class GeMV (vOp ):
99 r"""
10- General matrix-vector multiplication (GEMV) dispatcher.
11-
12- This operator computes a *piecewise* batched matrix-vector product.
13- Let
14-
15- .. math::
16-
17- X \in \mathbb{R}^{B \times 1 \times D}, \qquad
18- Y \in \mathbb{R}^{S_{\text{pack}} \times 1 \times D},
19-
20- where the ``S``-axis of :math:`Y` is a concatenation of batch-wise
21- segments
22-
23- .. math::
24-
25- S_{\text{pack}} = \sum_{i=0}^{B-1} S_i, \qquad
26- Y =
27- \begin{bmatrix}
28- Y_0 \\
29- Y_1 \\
30- \vdots \\
31- Y_{B-1}
32- \end{bmatrix},
33-
34- with
35-
36- .. math::
37-
38- Y_i \in \mathbb{R}^{S_i \times 1 \times D}, \qquad
39- X_i = X[i, 0, :] \in \mathbb{R}^{1 \times D}.
40-
41- For each batch index :math:`i \in \{0,\dots,B-1\}`, we define
42-
43- .. math::
44-
45- O_i = Y_i X_i^{\mathsf{T}} \in \mathbb{R}^{S_i \times 1 \times 1},
46-
47- and the overall output is the concatenation
48-
49- .. math::
50-
51- O =
52- \begin{bmatrix}
53- O_0 \\
54- O_1 \\
55- \vdots \\
56- O_{B-1}
57- \end{bmatrix}
58- \in \mathbb{R}^{S_{\text{pack}} \times 1 \times 1}.
59-
60- In the runtime, :math:`S_{\text{pack}}` is given by
61- ``ctx.max_num_pages``. Output format rule: ``BATCHED`` iff both
62- inputs are ``BATCHED`` (both have their ``S`` axis already collapsed
63- to 1), otherwise ``RAGGED``. Format compatibility is enforced by
64- the compiler's per-workload kernel.
65-
66- Attributes
67- ----------
68- output_format : Optional[FORMAT]
69- The output tensor format as determined in :meth:`profile`.
70-
71- output_buffer : Optional[torch.Tensor]
72- Preallocated output tensor buffer of shape ``[S_pack, 1, 1]``.
10+ Per-request batched matrix–vector product, :math:`O = Y X^{\top}`.
11+
12+ :Math: with a batched query :math:`X\in\mathbb{R}^{B\times 1\times D}` and
13+ packed pages :math:`Y\in\mathbb{R}^{S\times 1\times D}`, each page
14+ :math:`s` of request :math:`i` scores as
15+ :math:`O[s]=\langle Y[s],\,X[i]\rangle`, giving
16+ :math:`O\in\mathbb{R}^{S\times 1\times 1}`.
17+ :__init__: ``GeMV()`` — no arguments.
18+ :__call__: ``o = op(x, y, ctx=ctx)`` — ``x`` is ``[B, 1, D]``, ``y`` is
19+ ``[S, 1, D]`` (matching ``D``); returns ``o`` ``[S, 1, 1]``. Output is
20+ ``BATCHED`` iff both inputs are, else ``RAGGED``.
7321 """
7422
7523 def __init__ (self ):
@@ -79,19 +27,9 @@ def __init__(self):
7927 self .schedule = Schedule .W
8028 # ---------------- profile ----------------
8129 def profile (self , x : vTensor , y : vTensor , ctx : Context ) -> vTensor :
82- r"""
83- Validate inputs, allocate the output buffer, and return a
84- :class:`vTensor` view.
85-
86- The method enforces the logical shapes
87-
88- - ``x``: ``[B, 1, D]``
89- - ``y``: ``[S_pack, 1, D]``
90-
91- and checks that the last dimensions match. The output buffer is
92- allocated with shape ``[S_pack, 1, 1]``, where ``S_pack`` is taken
93- from the runtime context as ``ctx.max_num_pages``.
94- """
30+ r"""Trace-time: validate ``x`` ``[B, 1, D]`` / ``y`` ``[S, 1, D]``,
31+ register the op, and return a ``vTensor`` view of the ``[S, 1, 1]``
32+ output (see the class docstring)."""
9533 prefix = self ._prefix ()
9634
9735 # Type checks
@@ -137,51 +75,17 @@ def profile(self, x: vTensor, y: vTensor, ctx: Context) -> vTensor:
13775# ------------------------------ GeMM ------------------------------ #
13876class GeMM (vOp ):
13977 r"""
140- General matrix-matrix multiplication (GeMM) dispatcher.
141-
142- Logically this computes, for each logical ``S``-slice, a matrix-matrix
143- product
144-
145- .. math::
146-
147- O[s] = Y[s] X[s]^{\mathsf{T}}, \quad s = 0, \dots, S-1,
148-
149- with slice-wise shapes
150-
151- .. math::
152-
153- X[s] \in \mathbb{R}^{N_x \times K}, \quad
154- Y[s] \in \mathbb{R}^{N_y \times K}, \quad
155- O[s] \in \mathbb{R}^{N_y \times N_x}.
156-
157- In the packed 3D representation used by this dispatcher:
158-
159- - ``Y`` has logical shape ``[S, N_y, K]``.
160- - ``X`` has logical shape ``[L_x, N_x, K]``, where the leading
161- dimension :math:`L_x` can represent **either**:
162-
163- * a batch axis :math:`B` (when ``x_format == FORMAT.BATCHED``), or
164- * the same ``S`` axis as ``Y`` (when ``x_format`` is ragged/paged and
165- already laid out per-page).
166-
167- This is why the code comments use ``X: [B/S, N_x, K]``: the first
168- dimension is interpreted as either a batch size :math:`B` or an
169- ``S``-like logical page index, depending on the format.
170-
171- - The output tensor ``O`` has logical shape ``[S, N_y, N_x]``.
172-
173- At runtime, the logical ``S`` is taken from ``ctx.max_num_pages``.
174- Output format rule: ``BATCHED`` iff both inputs are ``BATCHED``,
175- otherwise ``RAGGED``. Format compatibility is enforced by the
176- compiler's per-workload kernel.
177-
178- Attributes
179- ----------
180- output_format : Optional[FORMAT]
181- The output tensor format as determined in :meth:`profile`.
182-
183- output_buffer : Optional[torch.Tensor]
184- Preallocated output tensor buffer of shape ``[S, N_y, N_x]``.
78+ Per-page matrix–matrix product, :math:`O[s] = Y[s]\,X[s]^{\top}`.
79+
80+ :Math: with :math:`Y\in\mathbb{R}^{S\times N_y\times K}` and
81+ :math:`X\in\mathbb{R}^{(B\,\text{or}\,S)\times N_x\times K}`, per page
82+ :math:`s`:
83+ :math:`O[s]=Y[s]\,X[s]^{\top}\in\mathbb{R}^{N_y\times N_x}` — i.e.
84+ ``GeMM(x, y) = y xᵀ``. Output :math:`O\in\mathbb{R}^{S\times N_y\times N_x}`.
85+ :__init__: ``GeMM()`` — no arguments.
86+ :__call__: ``o = op(x, y, ctx=ctx)`` — ``x`` is ``[B|S, N_x, K]``, ``y`` is
87+ ``[S, N_y, K]`` (matching ``K``); returns ``o`` ``[S, N_y, N_x]``.
88+ Output is ``BATCHED`` iff both inputs are, else ``RAGGED``.
18589 """
18690
18791 def __init__ (self ):
@@ -192,48 +96,9 @@ def __init__(self):
19296
19397 # ---------------- profile ----------------
19498 def profile (self , x : vTensor , y : vTensor , ctx : Context ) -> vTensor :
195- r"""
196- Validate inputs, allocate the output buffer, and return a
197- :class:`vTensor` view.
198-
199- The method enforces that both inputs are rank-3 tensors and that the
200- inner dimension :math:`K` matches:
201-
202- - ``x``: ``[B_or_S, N_x, K]``
203-
204- *When* ``x_format == FORMAT.BATCHED``, the leading dimension is a
205- batch size :math:`B`. For ragged/paged formats, it may conceptually
206- coincide with :math:`S`.
207-
208- - ``y``: ``[S, N_y, K]``
209-
210- The output buffer is allocated with shape ``[S, N_y, N_x]``, where
211- ``S`` is taken from the runtime context as ``ctx.max_num_pages``.
212-
213- Parameters
214- ----------
215- x : vTensor
216- Right-hand operand (transposed in the mathematical view), with
217- shape ``[B_or_S, N_x, K]``.
218-
219- y : vTensor
220- Left-hand operand with shape ``[S, N_y, K]``.
221-
222- ctx : Context
223- Execution context providing ``ctx.max_num_pages`` for the logical
224- ``S`` dimension and tracking auxiliary memory.
225-
226- Returns
227- -------
228- vTensor
229- A ``vTensor`` view wrapping the allocated output buffer.
230-
231- Raises
232- ------
233- AssertionError
234- If types are not ``vTensor``, ranks are not 3, or the inner
235- dimensions :math:`K` do not match.
236- """
99+ r"""Trace-time: validate ``x`` ``[B|S, N_x, K]`` / ``y`` ``[S, N_y, K]``
100+ (matching ``K``), register the op, and return a ``vTensor`` view of the
101+ ``[S, N_y, N_x]`` output (see the class docstring)."""
237102 prefix = self ._prefix ()
238103
239104 # Type checks
0 commit comments