Commit f376f13

comments and docs

Zhuoming Chen committed
1 parent 1448a78 commit f376f13

18 files changed

Lines changed: 706 additions & 2221 deletions

vortex_torch/cache/elementwise.py

Lines changed: 63 additions & 259 deletions
@@ -6,48 +6,24 @@

 class Elementwise(vOp):
     r"""
-    Unary elementwise op (e.g. ReLU/Sigmoid/SiLU/Abs/Affine).
-
-    Operates on rank-3 tensors
-
-    .. math::
-
-        X \in \mathbb{R}^{B \times N \times D},
-
-    where:
-
-    - :math:`B` is a leading batch-like axis (for example,
-      ``max_new_tokens_per_batch * head_num`` coming from the runtime
-      context),
-    - :math:`N` is a sequence or position dimension, and
-    - :math:`D` is a feature/channel dimension.
-
-    The operation is applied pointwise:
-
-    .. math::
-
-        Y[b, n, d] = f(X[b, n, d]; \alpha, \beta, \text{op_type}),
-
-    where the actual function :math:`f` is selected by :attr:`op_type`.
-
-    Output format rule: if a caller-provided ``output`` is supplied with
-    ``PAGED`` format, the output is ``PAGED``; in every other case
-    (``output is None``, or ``output._format == RAGGED``) the output is
-    ``RAGGED``. Format compatibility is enforced by the compiler's
-    per-block kernel.
-
-    Attributes
-    ----------
-    alpha : float
-        Scalar parameter used by certain unary ops.
-    beta : float
-        Scalar parameter used by certain unary ops.
-    op_type : Optional[ElementwiseOpType]
-        Runtime-set enum/int describing the specific elementwise operation.
-    output_format : Optional[FORMAT]
-        The output tensor format as determined in :meth:`profile`.
-    output_buffer : Optional[vTensor]
-        Pure-metadata vTensor descriptor for the output (graph node).
+    Unary elementwise op — applies a scalar function pointwise.
+
+    :Math:
+        .. math::
+
+            Y_{b,n,d} = f(X_{b,n,d};\, \alpha, \beta),
+
+        where :math:`f` is fixed by the subclass (ReLU / SiLU / Sigmoid /
+        affine / abs / log / exp).
+    :__init__: ``Elementwise(alpha=0.0, beta=1.0)`` — scalar parameters
+        :math:`\alpha`, :math:`\beta` consumed by :math:`f`.
+    :__call__: ``y = op(x, output, loc=loc, ctx=ctx)`` — ``x`` is ``[B, N, D]``;
+        ``output`` is an optional preallocated buffer (``None`` → fresh
+        ``RAGGED``). Returns the same ``[B, N, D]`` shape; ``PAGED`` iff a
+        ``PAGED`` ``output`` is supplied, else ``RAGGED``.
+    :Note: use a concrete subclass — :class:`Relu`, :class:`Silu`,
+        :class:`Sigmoid`, :class:`Add_Mul`, :class:`Abs`, :class:`Log`,
+        :class:`Exp`.
     """

     def __init__(self, alpha: float = 0.0, beta: float = 1.0):
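
The output-format rule stated in both versions of the `Elementwise` docstring (PAGED only when a PAGED `output` is supplied, RAGGED otherwise) can be sketched in isolation. This is a minimal illustration, not the library's code: the `Format` enum and `resolve_output_format` are hypothetical stand-ins for the project's `FORMAT` type and whatever logic `profile` actually runs.

```python
from enum import Enum
from typing import Optional


class Format(Enum):
    """Hypothetical stand-in for the library's FORMAT enum."""
    RAGGED = "ragged"
    PAGED = "paged"


def resolve_output_format(output_format: Optional[Format]) -> Format:
    """Apply the documented rule.

    `output_format` is the format of the caller-supplied output buffer,
    or None when no buffer was supplied.
    """
    if output_format is Format.PAGED:
        return Format.PAGED
    # Covers both `output is None` and a RAGGED output buffer.
    return Format.RAGGED


if __name__ == "__main__":
    for case in (None, Format.RAGGED, Format.PAGED):
        print(case, "->", resolve_output_format(case))
```

Only the PAGED case needs an explicit branch, which is why the rewritten docstring can state the rule in a single clause.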
@@ -67,63 +43,9 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):
     def profile(
         self, x: vTensor, output: Optional[vTensor], loc: torch.Tensor, ctx: Context
     ) -> vTensor:
-        r"""
-        Validate inputs and optionally allocate an internal output buffer.
-
-        The input tensor ``x`` is expected to have logical shape
-        ``[B, N, D]``.
-
-        Two modes:
-
-        - **No output provided** (``output is None``):
-
-          - Allocate an internal RAGGED buffer with shape ``[B, N, D]``,
-            where
-
-            .. math::
-
-                B = \text{ctx.max_new_tokens_per_batch} \times \text{ctx.head_num}.
-
-        - **Output provided** (``output is not None``):
-
-          - Take the format directly from ``output._format`` (must be
-            ``PAGED`` or ``RAGGED``).
-          - Validate that ``output`` has rank 3 and preserves the
-            ``(N, D)`` dimensions of ``x``.
-          - Validate device consistency between ``x`` and ``output``.
-
-        Parameters
-        ----------
-        x : vTensor
-            Input tensor with logical shape ``[B, N, D]``.
-
-        output : Optional[vTensor]
-            Optional preallocated output tensor. If ``None``, an internal
-            RAGGED buffer is allocated; otherwise, this tensor must have
-            shape ``[B_out, N, D]`` for some ``B_out`` and a format in
-            ``{PAGED, RAGGED}``.
-
-        loc : torch.Tensor
-            Auxiliary tensor carrying per-position metadata required by
-            the implementation (e.g., location/segment indices).
-
-        ctx : Context
-            Execution context that provides the runtime value of ``B``
-            (via ``ctx.max_new_tokens_per_batch`` and ``ctx.head_num``)
-            and is used for auxiliary memory accounting.
-
-        Returns
-        -------
-        vTensor
-            A :class:`vTensor` view representing the resolved output:
-            either the provided ``output`` or an internally allocated
-            buffer.
-
-        Raises
-        ------
-        AssertionError
-            If types, ranks, shapes, or devices are incompatible.
-        """
+        r"""Trace-time: validate ``x`` ``[B, N, D]`` (and ``output`` if given),
+        register the op, and return a ``vTensor`` view of the same-shape output
+        (a fresh ``RAGGED`` buffer when ``output is None``)."""
         prefix = self._prefix()

         # --- type & rank checks ---
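
The shape contract spelled out in the removed `profile` docstring (rank-3 input, and a supplied `output` of shape `[B_out, N, D]` that preserves `(N, D)`) can be illustrated with plain tuples. This is a sketch under stated assumptions: `resolve_output_shape` is a hypothetical helper, and the real method works on `vTensor` objects and also checks types and devices, which are omitted here.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def resolve_output_shape(x_shape: Shape, output_shape: Optional[Shape] = None) -> Shape:
    """Return the output shape implied by the documented checks.

    Raises AssertionError on a rank or (N, D) mismatch, mirroring the
    'Raises' entry of the removed docstring.
    """
    assert len(x_shape) == 3, f"x must be rank 3, got rank {len(x_shape)}"
    if output_shape is None:
        # No buffer supplied: the output has the same [B, N, D] shape as x.
        return tuple(x_shape)
    assert len(output_shape) == 3, f"output must be rank 3, got rank {len(output_shape)}"
    assert tuple(output_shape[1:]) == tuple(x_shape[1:]), (
        f"output must preserve (N, D)={tuple(x_shape[1:])}, got {tuple(output_shape[1:])}"
    )
    # The leading axis B_out is left unconstrained, as in the removed text.
    return tuple(output_shape)


if __name__ == "__main__":
    print(resolve_output_shape((8, 16, 64)))
    print(resolve_output_shape((8, 16, 64), (4, 16, 64)))
```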
@@ -189,34 +111,14 @@ def profile(

 class Relu(Elementwise):
     r"""
-    Piecewise ReLU-like activation.
-
-    This operator applies, elementwise, the scalar function
+    ReLU-like activation with threshold/fallback (an :class:`Elementwise`).

-    .. math::
+    :Math:
+        .. math::

-        f(x; \alpha, \beta) =
-        \begin{cases}
-        x, & x \ge \alpha, \\
-        \beta, & x < \alpha.
-        \end{cases}
-
-    Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is defined by
-
-    .. math::
-
-        Y[b, n, d] = f\bigl(X[b, n, d]; \alpha, \beta\bigr).
-
-    Parameters
-    ----------
-    alpha : float, optional
-        Threshold value :math:`\alpha`. Inputs greater than or equal to
-        this threshold are passed through unchanged. Default is ``0.0``.
-
-    beta : float, optional
-        Fallback value :math:`\beta` used when :math:`x < \alpha`.
-        Default is ``0.0``.
+            f(x;\alpha,\beta) = \begin{cases} x, & x \ge \alpha, \\ \beta, & x < \alpha. \end{cases}
+    :__init__: ``Relu(alpha=0.0, beta=0.0)`` — threshold :math:`\alpha`,
+        fallback value :math:`\beta` (used when :math:`x<\alpha`).
     """
     def __init__(self, alpha: float = 0.0, beta: float = 0.0):
         super().__init__(alpha, beta)
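
A scalar reference for the `Relu` formula in the hunk above may help when checking the threshold/fallback semantics. This is an illustrative sketch only: `relu_like` is a hypothetical pure-Python function, not the kernel the library compiles.

```python
def relu_like(x: float, alpha: float = 0.0, beta: float = 0.0) -> float:
    """Pass x through when x >= alpha; otherwise return the fallback beta."""
    return x if x >= alpha else beta


if __name__ == "__main__":
    values = [-2.0, -0.5, 0.0, 1.5]
    # Defaults (alpha=0, beta=0) reduce to the ordinary ReLU.
    print([relu_like(v) for v in values])
    # A raised threshold with a non-zero fallback.
    print([relu_like(v, alpha=1.0, beta=-1.0) for v in values])
```

Note that the comparison is inclusive, so an input exactly equal to the threshold is passed through rather than replaced.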
@@ -225,32 +127,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 0.0):

 class Silu(Elementwise):
     r"""
-    SiLU-like activation with configurable shift and slope.
-
-    This operator applies, elementwise, the scalar function
-
-    .. math::
-
-        \operatorname{SiLU}(x; \alpha, \beta)
-        = \frac{x}{1 + \exp(\beta x + \alpha)}.
+    SiLU-like activation with configurable shift/slope (an :class:`Elementwise`).

-    Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is
+    :Math:
+        .. math::

-    .. math::
-
-        Y[b, n, d]
-        = \operatorname{SiLU}\bigl(X[b, n, d]; \alpha, \beta\bigr).
-
-    Parameters
-    ----------
-    alpha : float, optional
-        Bias term :math:`\alpha` added inside the exponential. Default is
-        ``0.0``.
-
-    beta : float, optional
-        Slope :math:`\beta` multiplying :math:`x` inside the exponential.
-        Default is ``0.0``.
+            f(x;\alpha,\beta) = \frac{x}{1 + \exp(\beta x + \alpha)}.
+    :__init__: ``Silu(alpha=0.0, beta=0.0)`` — bias :math:`\alpha`, slope
+        :math:`\beta` inside the exponential.
     """
     def __init__(self, alpha: float = 0.0, beta: float = 0.0):
         super().__init__(alpha, beta)
@@ -259,32 +143,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 0.0):

 class Sigmoid(Elementwise):
     r"""
-    Sigmoid activation with configurable shift and slope.
-
-    This operator applies, elementwise, the scalar function
-
-    .. math::
-
-        \sigma(x; \alpha, \beta)
-        = \frac{1}{1 + \exp(\beta x + \alpha)}.
-
-    Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is
-
-    .. math::
-
-        Y[b, n, d]
-        = \sigma\bigl(X[b, n, d]; \alpha, \beta\bigr).
+    Sigmoid activation with configurable shift/slope (an :class:`Elementwise`).

-    Parameters
-    ----------
-    alpha : float, optional
-        Bias term :math:`\alpha` added inside the exponential. Default is
-        ``0.0``.
+    :Math:
+        .. math::

-    beta : float, optional
-        Slope :math:`\beta` multiplying :math:`x` inside the exponential.
-        Default is ``0.0``.
+            f(x;\alpha,\beta) = \frac{1}{1 + \exp(\beta x + \alpha)}.
+    :__init__: ``Sigmoid(alpha=0.0, beta=0.0)`` — bias :math:`\alpha`, slope
+        :math:`\beta` inside the exponential.
     """
     def __init__(self, alpha: float = 0.0, beta: float = 0.0):
         super().__init__(alpha, beta)
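
The `Silu` and `Sigmoid` docstrings place `+ (beta * x + alpha)` inside the exponential, which differs in sign from the textbook definitions. The sketch below evaluates the documented formulas directly to show what follows from them: with the documented defaults (`alpha=0.0`, `beta=0.0`) the sigmoid is the constant 1/2 and the SiLU is `x / 2`, while `beta=-1.0` recovers the standard logistic sigmoid and SiLU. The function names are hypothetical, and whether the compiled kernels match the docstrings term for term is not confirmed by this diff.

```python
import math


def sigmoid_like(x: float, alpha: float = 0.0, beta: float = 0.0) -> float:
    """Documented formula: 1 / (1 + exp(beta * x + alpha))."""
    return 1.0 / (1.0 + math.exp(beta * x + alpha))


def silu_like(x: float, alpha: float = 0.0, beta: float = 0.0) -> float:
    """Documented formula: x / (1 + exp(beta * x + alpha))."""
    return x / (1.0 + math.exp(beta * x + alpha))


if __name__ == "__main__":
    x = 2.0
    # Documented defaults: the exponent is 0, so the denominator is 2.
    print(sigmoid_like(x), silu_like(x))
    # beta = -1 gives the textbook sigmoid and SiLU.
    print(sigmoid_like(x, beta=-1.0), silu_like(x, beta=-1.0))
```

Callers who want the conventional activations would therefore pass `beta=-1.0` explicitly, assuming the kernels follow the docstring.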
@@ -293,31 +159,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 0.0):

 class Add_Mul(Elementwise):
     r"""
-    Affine transformation :math:`y = \beta x + \alpha`.
+    Affine transform :math:`\beta x + \alpha` (an :class:`Elementwise`).

-    This operator applies, elementwise, the scalar function
+    :Math:
+        .. math::

-    .. math::
-
-        f(x; \alpha, \beta) = \beta x + \alpha.
-
-    For an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is
-
-    .. math::
-
-        Y[b, n, d]
-        = \beta \, X[b, n, d] + \alpha.
-
-    Parameters
-    ----------
-    alpha : float, optional
-        Additive term :math:`\alpha` in the affine transform. Default is
-        ``0.0``.
-
-    beta : float, optional
-        Multiplicative term :math:`\beta` in the affine transform.
-        Default is ``1.0``.
+            f(x;\alpha,\beta) = \beta x + \alpha.
+    :__init__: ``Add_Mul(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
+        multiplicative :math:`\beta`.
     """
     def __init__(self, alpha: float = 0.0, beta: float = 1.0):
         super().__init__(alpha, beta)
@@ -326,31 +175,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):

 class Abs(Elementwise):
     r"""
-    Absolute-value transform of an affine argument.
-
-    This operator applies, elementwise, the scalar function
-
-    .. math::
+    Absolute value of an affine transform (an :class:`Elementwise`).

-        f(x; \alpha, \beta) = \bigl|\beta x + \alpha\bigr|.
+    :Math:
+        .. math::

-    For an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is
-
-    .. math::
-
-        Y[b, n, d]
-        = \bigl|\beta \, X[b, n, d] + \alpha\bigr|.
-
-    Parameters
-    ----------
-    alpha : float, optional
-        Additive term :math:`\alpha` inside the absolute value. Default is
-        ``0.0``.
-
-    beta : float, optional
-        Multiplicative term :math:`\beta` inside the absolute value.
-        Default is ``1.0``.
+            f(x;\alpha,\beta) = \lvert \beta x + \alpha \rvert.
+    :__init__: ``Abs(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
+        multiplicative :math:`\beta` inside the absolute value.
     """
     def __init__(self, alpha: float = 0.0, beta: float = 1.0):
         super().__init__(alpha, beta)
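
The `Add_Mul` and `Abs` docstrings share the same affine argument `beta * x + alpha`, with defaults that make `Add_Mul` the identity and `Abs` the plain absolute value. A scalar sketch of both follows; `add_mul` and `abs_affine` are hypothetical reference functions written from the docstring formulas, not the library's kernels.

```python
def add_mul(x: float, alpha: float = 0.0, beta: float = 1.0) -> float:
    """Affine transform: beta * x + alpha."""
    return beta * x + alpha


def abs_affine(x: float, alpha: float = 0.0, beta: float = 1.0) -> float:
    """Absolute value of the affine transform: |beta * x + alpha|."""
    return abs(add_mul(x, alpha, beta))


if __name__ == "__main__":
    print(add_mul(3.0))                       # identity with the defaults
    print(add_mul(3.0, alpha=1.0, beta=2.0))  # 2 * 3 + 1
    print(abs_affine(-3.0))                   # plain |x| with the defaults
    print(abs_affine(1.0, alpha=-5.0, beta=2.0))
```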
@@ -359,28 +191,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):

 class Log(Elementwise):
     r"""
-    Natural logarithm of an affine transform.
-
-    This operator applies, elementwise, the scalar function
-
-    .. math::
-
-        f(x; \alpha, \beta) = \log(\beta x + \alpha).
-
-    Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is
+    Natural logarithm of an affine transform (an :class:`Elementwise`).

-    .. math::
+    :Math:
+        .. math::

-        Y[b, n, d] = \log(\beta X[b, n, d] + \alpha).
-
-    Parameters
-    ----------
-    alpha : float, optional
-        Additive bias term inside the logarithm. Default is ``0.0``.
-
-    beta : float, optional
-        Multiplicative scale term inside the logarithm. Default is ``1.0``.
+            f(x;\alpha,\beta) = \log(\beta x + \alpha).
+    :__init__: ``Log(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
+        multiplicative :math:`\beta` inside the logarithm.
     """
     def __init__(self, alpha: float = 0.0, beta: float = 1.0):
         super().__init__(alpha, beta)
@@ -389,28 +207,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):

 class Exp(Elementwise):
     r"""
-    Exponential of an affine transform.
-
-    This operator applies, elementwise, the scalar function
-
-    .. math::
-
-        f(x; \alpha, \beta) = \exp(\beta x + \alpha).
-
-    Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
-    the output is
-
-    .. math::
-
-        Y[b, n, d] = \exp(\beta X[b, n, d] + \alpha).
+    Exponential of an affine transform (an :class:`Elementwise`).

-    Parameters
-    ----------
-    alpha : float, optional
-        Additive bias term inside the exponential. Default is ``0.0``.
+    :Math:
+        .. math::

-    beta : float, optional
-        Multiplicative scale term inside the exponential. Default is ``1.0``.
+            f(x;\alpha,\beta) = \exp(\beta x + \alpha).
+    :__init__: ``Exp(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
+        multiplicative :math:`\beta` inside the exponential.
     """
     def __init__(self, alpha: float = 0.0, beta: float = 1.0):
         super().__init__(alpha, beta)
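
The `Log` and `Exp` docstrings can be cross-checked with a scalar sketch. With the documented defaults (`alpha=0.0`, `beta=1.0`) the two are inverses, and `Log` is only defined where `beta * x + alpha > 0`. The functions below are hypothetical references written from the docstring formulas. Raising `ValueError` outside the domain is a choice made for this sketch, since the diff does not say how the real kernel treats non-positive arguments.

```python
import math


def log_affine(x: float, alpha: float = 0.0, beta: float = 1.0) -> float:
    """Natural log of the affine argument: log(beta * x + alpha)."""
    arg = beta * x + alpha
    if arg <= 0.0:
        # Sketch-only behaviour; the library's handling is not documented here.
        raise ValueError(f"log argument must be positive, got {arg}")
    return math.log(arg)


def exp_affine(x: float, alpha: float = 0.0, beta: float = 1.0) -> float:
    """Exponential of the affine argument: exp(beta * x + alpha)."""
    return math.exp(beta * x + alpha)


if __name__ == "__main__":
    x = 0.7
    # With the defaults, log(exp(x)) returns x up to rounding error.
    print(log_affine(exp_affine(x)))
    # A positive shift keeps the log defined at x = 0.
    print(log_affine(0.0, alpha=1.0))
```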
