@@ -10,13 +10,17 @@ class Reduce(vOp):
1010
1111 :Math:
1212 For input :math:`X\in\mathbb{R}^{N\times D_0\times D_1}` and a
13- reduction :math:`\rho` (mean / max / min / L2-norm / sum, fixed by the
14- subclass): ``dim=1`` →
15- :math:`\text{out}[n,0,d_1]=\rho_{d_0}\,X[n,d_0,d_1]` (shape
16- ``[N, 1, D_1]``); ``dim=2`` →
17- :math:`\text{out}[n,d_0,0]=\rho_{d_1}\,X[n,d_0,d_1]` (shape
18- ``[N, D_0, 1]``); ``dim=0`` collapses the packed leading axis to one
19- row per ``(batch, kv_head)``.
13+ per-axis reduction :math:`\rho` (mean / max / min / L2-norm / sum,
14+ fixed by the subclass):
15+
16+ .. math::
17+
18+ Y_{n,0,d} = \rho_{\,0 \le i < D_0}\, X_{n,i,d}\ \ (\text{dim}=1),
19+ \qquad
20+ Y_{n,d,0} = \rho_{\,0 \le j < D_1}\, X_{n,d,j}\ \ (\text{dim}=2).
21+
22+ ``dim=0`` collapses the packed leading axis to one row per
23+ ``(batch, kv\_head)``.
2024 :__init__:
2125 ``Reduce(dim=1)`` — logical axis to reduce, one of ``0`` / ``1`` / ``2``.
2226 :__call__:
@@ -101,8 +105,12 @@ class Max(Reduce):
101105 r"""
102106 Max reduction over one logical axis (a :class:`Reduce`).
103107
104- :Math: ``dim=1``: :math:`\text{out}[n,0,d_1]=\max_{d_0}X[n,d_0,d_1]`;
105- ``dim=2``: :math:`\text{out}[n,d_0,0]=\max_{d_1}X[n,d_0,d_1]`.
108+ :Math:
109+ .. math::
110+
111+ Y_{n,0,d} = \max_{0 \le i < D_0} X_{n,i,d}\ \ (\text{dim}=1),
112+ \qquad
113+ Y_{n,d,0} = \max_{0 \le j < D_1} X_{n,d,j}\ \ (\text{dim}=2).
106114 :__init__: ``Max(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
107115 ``2`` → :math:`D_1`).
108116 :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
@@ -117,8 +125,12 @@ class Min(Reduce):
117125 r"""
118126 Min reduction over one logical axis (a :class:`Reduce`).
119127
120- :Math: ``dim=1``: :math:`\text{out}[n,0,d_1]=\min_{d_0}X[n,d_0,d_1]`;
121- ``dim=2``: :math:`\text{out}[n,d_0,0]=\min_{d_1}X[n,d_0,d_1]`.
128+ :Math:
129+ .. math::
130+
131+ Y_{n,0,d} = \min_{0 \le i < D_0} X_{n,i,d}\ \ (\text{dim}=1),
132+ \qquad
133+ Y_{n,d,0} = \min_{0 \le j < D_1} X_{n,d,j}\ \ (\text{dim}=2).
122134 :__init__: ``Min(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
123135 ``2`` → :math:`D_1`).
124136 :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
@@ -133,8 +145,12 @@ class Mean(Reduce):
133145 r"""
134146 Mean reduction over one logical axis (a :class:`Reduce`).
135147
136- :Math: ``dim=1``: :math:`\text{out}[n,0,d_1]=\frac{1}{D_0}\sum_{d_0}X[n,d_0,d_1]`;
137- ``dim=2``: :math:`\text{out}[n,d_0,0]=\frac{1}{D_1}\sum_{d_1}X[n,d_0,d_1]`.
148+ :Math:
149+ .. math::
150+
151+ Y_{n,0,d} = \frac{1}{D_0}\sum_{i=0}^{D_0-1} X_{n,i,d}\ \ (\text{dim}=1),
152+ \qquad
153+ Y_{n,d,0} = \frac{1}{D_1}\sum_{j=0}^{D_1-1} X_{n,d,j}\ \ (\text{dim}=2).
138154 :__init__: ``Mean(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
139155 ``2`` → :math:`D_1`).
140156 :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
@@ -149,8 +165,12 @@ class L2Norm(Reduce):
149165 r"""
150166 L2-norm reduction over one logical axis (a :class:`Reduce`).
151167
152- :Math: ``dim=1``: :math:`\text{out}[n,0,d_1]=\sqrt{\sum_{d_0}X[n,d_0,d_1]^2}`;
153- ``dim=2``: :math:`\text{out}[n,d_0,0]=\sqrt{\sum_{d_1}X[n,d_0,d_1]^2}`.
168+ :Math:
169+ .. math::
170+
171+ Y_{n,0,d} = \Big(\sum_{i=0}^{D_0-1} X_{n,i,d}^2\Big)^{1/2}\ \ (\text{dim}=1),
172+ \qquad
173+ Y_{n,d,0} = \Big(\sum_{j=0}^{D_1-1} X_{n,d,j}^2\Big)^{1/2}\ \ (\text{dim}=2).
154174 :__init__: ``L2Norm(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
155175 ``2`` → :math:`D_1`).
156176 :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
@@ -165,8 +185,12 @@ class Sum(Reduce):
165185 r"""
166186 Sum reduction over one logical axis (a :class:`Reduce`).
167187
168- :Math: ``dim=1``: :math:`\text{out}[n,0,d_1]=\sum_{d_0}X[n,d_0,d_1]`;
169- ``dim=2``: :math:`\text{out}[n,d_0,0]=\sum_{d_1}X[n,d_0,d_1]`.
188+ :Math:
189+ .. math::
190+
191+ Y_{n,0,d} = \sum_{i=0}^{D_0-1} X_{n,i,d}\ \ (\text{dim}=1),
192+ \qquad
193+ Y_{n,d,0} = \sum_{j=0}^{D_1-1} X_{n,d,j}\ \ (\text{dim}=2).
170194 :__init__: ``Sum(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
171195 ``2`` → :math:`D_1`).
172196 :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
0 commit comments