66
77class Elementwise (vOp ):
88 r"""
9- Unary elementwise op (e.g. ReLU/Sigmoid/SiLU/Abs/Affine).
10-
11- Operates on rank-3 tensors
12-
13- .. math::
14-
15- X \in \mathbb{R}^{B \times N \times D},
16-
17- where:
18-
19- - :math:`B` is a leading batch-like axis (for example,
20- ``max_new_tokens_per_batch * head_num`` coming from the runtime
21- context),
22- - :math:`N` is a sequence or position dimension, and
23- - :math:`D` is a feature/channel dimension.
24-
25- The operation is applied pointwise:
26-
27- .. math::
28-
29- Y[b, n, d] = f(X[b, n, d]; \alpha, \beta, \text{op_type}),
30-
31- where the actual function :math:`f` is selected by :attr:`op_type`.
32-
33- Output format rule: if a caller-provided ``output`` is supplied with
34- ``PAGED`` format, the output is ``PAGED``; in every other case
35- (``output is None``, or ``output._format == RAGGED``) the output is
36- ``RAGGED``. Format compatibility is enforced by the compiler's
37- per-block kernel.
38-
39- Attributes
40- ----------
41- alpha : float
42- Scalar parameter used by certain unary ops.
43- beta : float
44- Scalar parameter used by certain unary ops.
45- op_type : Optional[ElementwiseOpType]
46- Runtime-set enum/int describing the specific elementwise operation.
47- output_format : Optional[FORMAT]
48- The output tensor format as determined in :meth:`profile`.
49- output_buffer : Optional[vTensor]
50- Pure-metadata vTensor descriptor for the output (graph node).
9+ Unary elementwise op — applies a scalar function pointwise.
10+
11+ :Math:
12+ .. math::
13+
14+ Y_{b,n,d} = f(X_{b,n,d};\, \alpha, \beta),
15+
16+ where :math:`f` is fixed by the subclass (ReLU / SiLU / Sigmoid /
17+ affine / abs / log / exp).
18+ :__init__: ``Elementwise(alpha=0.0, beta=1.0)`` — scalar parameters
19+ :math:`\alpha`, :math:`\beta` consumed by :math:`f`.
20+ :__call__: ``y = op(x, output, loc=loc, ctx=ctx)`` — ``x`` is ``[B, N, D]``;
21+ ``output`` is an optional preallocated buffer (``None`` → fresh
22+ ``RAGGED``). Returns the same ``[B, N, D]`` shape; ``PAGED`` iff a
23+ ``PAGED`` ``output`` is supplied, else ``RAGGED``.
24+ :Note: use a concrete subclass — :class:`Relu`, :class:`Silu`,
25+ :class:`Sigmoid`, :class:`Add_Mul`, :class:`Abs`, :class:`Log`,
26+ :class:`Exp`.
5127 """
5228
5329 def __init__ (self , alpha : float = 0.0 , beta : float = 1.0 ):
@@ -67,63 +43,9 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):
6743 def profile (
6844 self , x : vTensor , output : Optional [vTensor ], loc : torch .Tensor , ctx : Context
6945 ) -> vTensor :
70- r"""
71- Validate inputs and optionally allocate an internal output buffer.
72-
73- The input tensor ``x`` is expected to have logical shape
74- ``[B, N, D]``.
75-
76- Two modes:
77-
78- - **No output provided** (``output is None``):
79-
80- - Allocate an internal RAGGED buffer with shape ``[B, N, D]``,
81- where
82-
83- .. math::
84-
85- B = \text{ctx.max_new_tokens_per_batch} \times \text{ctx.head_num}.
86-
87- - **Output provided** (``output is not None``):
88-
89- - Take the format directly from ``output._format`` (must be
90- ``PAGED`` or ``RAGGED``).
91- - Validate that ``output`` has rank 3 and preserves the
92- ``(N, D)`` dimensions of ``x``.
93- - Validate device consistency between ``x`` and ``output``.
94-
95- Parameters
96- ----------
97- x : vTensor
98- Input tensor with logical shape ``[B, N, D]``.
99-
100- output : Optional[vTensor]
101- Optional preallocated output tensor. If ``None``, an internal
102- RAGGED buffer is allocated; otherwise, this tensor must have
103- shape ``[B_out, N, D]`` for some ``B_out`` and a format in
104- ``{PAGED, RAGGED}``.
105-
106- loc : torch.Tensor
107- Auxiliary tensor carrying per-position metadata required by
108- the implementation (e.g., location/segment indices).
109-
110- ctx : Context
111- Execution context that provides the runtime value of ``B``
112- (via ``ctx.max_new_tokens_per_batch`` and ``ctx.head_num``)
113- and is used for auxiliary memory accounting.
114-
115- Returns
116- -------
117- vTensor
118- A :class:`vTensor` view representing the resolved output:
119- either the provided ``output`` or an internally allocated
120- buffer.
121-
122- Raises
123- ------
124- AssertionError
125- If types, ranks, shapes, or devices are incompatible.
126- """
46+ r"""Trace-time: validate ``x`` ``[B, N, D]`` (and ``output`` if given),
47+ register the op, and return a ``vTensor`` view of the same-shape output
48+ (a fresh ``RAGGED`` buffer when ``output is None``)."""
12749 prefix = self ._prefix ()
12850
12951 # --- type & rank checks ---
@@ -189,34 +111,14 @@ def profile(
189111
190112class Relu (Elementwise ):
191113 r"""
192- Piecewise ReLU-like activation.
193-
194- This operator applies, elementwise, the scalar function
114+ ReLU-like activation with threshold/fallback (an :class:`Elementwise`).
195115
196- .. math::
116+ :Math:
117+ .. math::
197118
198- f(x; \alpha, \beta) =
199- \begin{cases}
200- x, & x \ge \alpha, \\
201- \beta, & x < \alpha.
202- \end{cases}
203-
204- Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
205- the output is defined by
206-
207- .. math::
208-
209- Y[b, n, d] = f\bigl(X[b, n, d]; \alpha, \beta\bigr).
210-
211- Parameters
212- ----------
213- alpha : float, optional
214- Threshold value :math:`\alpha`. Inputs greater than or equal to
215- this threshold are passed through unchanged. Default is ``0.0``.
216-
217- beta : float, optional
218- Fallback value :math:`\beta` used when :math:`x < \alpha`.
219- Default is ``0.0``.
119+ f(x;\alpha,\beta) = \begin{cases} x, & x \ge \alpha, \\ \beta, & x < \alpha. \end{cases}
120+ :__init__: ``Relu(alpha=0.0, beta=0.0)`` — threshold :math:`\alpha`,
121+ fallback value :math:`\beta` (used when :math:`x<\alpha`).
220122 """
221123 def __init__ (self , alpha : float = 0.0 , beta : float = 0.0 ):
222124 super ().__init__ (alpha , beta )
@@ -225,32 +127,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 0.0):
225127
226128class Silu (Elementwise ):
227129 r"""
228- SiLU-like activation with configurable shift and slope.
229-
230- This operator applies, elementwise, the scalar function
231-
232- .. math::
233-
234- \operatorname{SiLU}(x; \alpha, \beta)
235- = \frac{x}{1 + \exp(\beta x + \alpha)}.
130+ SiLU-like activation with configurable shift/slope (an :class:`Elementwise`).
236131
237- Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
238- the output is
132+ :Math:
133+ .. math::
239134
240- .. math::
241-
242- Y[b, n, d]
243- = \operatorname{SiLU}\bigl(X[b, n, d]; \alpha, \beta\bigr).
244-
245- Parameters
246- ----------
247- alpha : float, optional
248- Bias term :math:`\alpha` added inside the exponential. Default is
249- ``0.0``.
250-
251- beta : float, optional
252- Slope :math:`\beta` multiplying :math:`x` inside the exponential.
253- Default is ``0.0``.
135+ f(x;\alpha,\beta) = \frac{x}{1 + \exp(\beta x + \alpha)}.
136+ :__init__: ``Silu(alpha=0.0, beta=0.0)`` — bias :math:`\alpha`, slope
137+ :math:`\beta` inside the exponential.
254138 """
255139 def __init__ (self , alpha : float = 0.0 , beta : float = 0.0 ):
256140 super ().__init__ (alpha , beta )
@@ -259,32 +143,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 0.0):
259143
260144class Sigmoid (Elementwise ):
261145 r"""
262- Sigmoid activation with configurable shift and slope.
263-
264- This operator applies, elementwise, the scalar function
265-
266- .. math::
267-
268- \sigma(x; \alpha, \beta)
269- = \frac{1}{1 + \exp(\beta x + \alpha)}.
270-
271- Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
272- the output is
273-
274- .. math::
275-
276- Y[b, n, d]
277- = \sigma\bigl(X[b, n, d]; \alpha, \beta\bigr).
146+ Sigmoid activation with configurable shift/slope (an :class:`Elementwise`).
278147
279- Parameters
280- ----------
281- alpha : float, optional
282- Bias term :math:`\alpha` added inside the exponential. Default is
283- ``0.0``.
148+ :Math:
149+ .. math::
284150
285- beta : float, optional
286- Slope :math:`\ beta` multiplying :math:`x` inside the exponential.
287- Default is ``0.0`` .
151+ f(x;\alpha,\beta) = \frac{1}{1 + \exp(\beta x + \alpha)}.
152+ :__init__: ``Sigmoid(alpha=0.0, beta=0.0)`` — bias :math:`\alpha`, slope
153+ :math:`\beta` inside the exponential .
288154 """
289155 def __init__ (self , alpha : float = 0.0 , beta : float = 0.0 ):
290156 super ().__init__ (alpha , beta )
@@ -293,31 +159,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 0.0):
293159
294160class Add_Mul (Elementwise ):
295161 r"""
296- Affine transformation :math:`y = \beta x + \alpha`.
162+ Affine transform :math:`\beta x + \alpha` (an :class:`Elementwise`) .
297163
298- This operator applies, elementwise, the scalar function
164+ :Math:
165+ .. math::
299166
300- .. math::
301-
302- f(x; \alpha, \beta) = \beta x + \alpha.
303-
304- For an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
305- the output is
306-
307- .. math::
308-
309- Y[b, n, d]
310- = \beta \, X[b, n, d] + \alpha.
311-
312- Parameters
313- ----------
314- alpha : float, optional
315- Additive term :math:`\alpha` in the affine transform. Default is
316- ``0.0``.
317-
318- beta : float, optional
319- Multiplicative term :math:`\beta` in the affine transform.
320- Default is ``1.0``.
167+ f(x;\alpha,\beta) = \beta x + \alpha.
168+ :__init__: ``Add_Mul(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
169+ multiplicative :math:`\beta`.
321170 """
322171 def __init__ (self , alpha : float = 0.0 , beta : float = 1.0 ):
323172 super ().__init__ (alpha , beta )
@@ -326,31 +175,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):
326175
327176class Abs (Elementwise ):
328177 r"""
329- Absolute-value transform of an affine argument.
330-
331- This operator applies, elementwise, the scalar function
332-
333- .. math::
178+ Absolute value of an affine transform (an :class:`Elementwise`).
334179
335- f(x; \alpha, \beta) = \bigl|\beta x + \alpha\bigr|.
180+ :Math:
181+ .. math::
336182
337- For an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
338- the output is
339-
340- .. math::
341-
342- Y[b, n, d]
343- = \bigl|\beta \, X[b, n, d] + \alpha\bigr|.
344-
345- Parameters
346- ----------
347- alpha : float, optional
348- Additive term :math:`\alpha` inside the absolute value. Default is
349- ``0.0``.
350-
351- beta : float, optional
352- Multiplicative term :math:`\beta` inside the absolute value.
353- Default is ``1.0``.
183+ f(x;\alpha,\beta) = \lvert \beta x + \alpha \rvert.
184+ :__init__: ``Abs(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
185+ multiplicative :math:`\beta` inside the absolute value.
354186 """
355187 def __init__ (self , alpha : float = 0.0 , beta : float = 1.0 ):
356188 super ().__init__ (alpha , beta )
@@ -359,28 +191,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):
359191
360192class Log (Elementwise ):
361193 r"""
362- Natural logarithm of an affine transform.
363-
364- This operator applies, elementwise, the scalar function
365-
366- .. math::
367-
368- f(x; \alpha, \beta) = \log(\beta x + \alpha).
369-
370- Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
371- the output is
194+ Natural logarithm of an affine transform (an :class:`Elementwise`).
372195
373- .. math::
196+ :Math:
197+ .. math::
374198
375- Y[b, n, d] = \log(\beta X[b, n, d] + \alpha).
376-
377- Parameters
378- ----------
379- alpha : float, optional
380- Additive bias term inside the logarithm. Default is ``0.0``.
381-
382- beta : float, optional
383- Multiplicative scale term inside the logarithm. Default is ``1.0``.
199+ f(x;\alpha,\beta) = \log(\beta x + \alpha).
200+ :__init__: ``Log(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
201+ multiplicative :math:`\beta` inside the logarithm.
384202 """
385203 def __init__ (self , alpha : float = 0.0 , beta : float = 1.0 ):
386204 super ().__init__ (alpha , beta )
@@ -389,28 +207,14 @@ def __init__(self, alpha: float = 0.0, beta: float = 1.0):
389207
390208class Exp (Elementwise ):
391209 r"""
392- Exponential of an affine transform.
393-
394- This operator applies, elementwise, the scalar function
395-
396- .. math::
397-
398- f(x; \alpha, \beta) = \exp(\beta x + \alpha).
399-
400- Given an input tensor :math:`X \in \mathbb{R}^{B \times N \times D}`,
401- the output is
402-
403- .. math::
404-
405- Y[b, n, d] = \exp(\beta X[b, n, d] + \alpha).
210+ Exponential of an affine transform (an :class:`Elementwise`).
406211
407- Parameters
408- ----------
409- alpha : float, optional
410- Additive bias term inside the exponential. Default is ``0.0``.
212+ :Math:
213+ .. math::
411214
412- beta : float, optional
413- Multiplicative scale term inside the exponential. Default is ``1.0``.
215+ f(x;\alpha,\beta) = \exp(\beta x + \alpha).
216+ :__init__: ``Exp(alpha=0.0, beta=1.0)`` — additive :math:`\alpha`,
217+ multiplicative :math:`\beta` inside the exponential.
414218 """
415219 def __init__ (self , alpha : float = 0.0 , beta : float = 1.0 ):
416220 super ().__init__ (alpha , beta )
0 commit comments