Skip to content

Commit e532c21

Browse files
authored
linear attention signature (#27842)
Proposal for CausalConvWithState and LinearAttention onnxruntime custom operator. This follows the proposal in onnx/onnx#7767.
1 parent edd9f58 commit e532c21

5 files changed

Lines changed: 2242 additions & 0 deletions

File tree

docs/ContribOperators.md

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Do not modify directly.*
1515
* <a href="#com.microsoft.BitmaskBiasDropout">com.microsoft.BitmaskBiasDropout</a>
1616
* <a href="#com.microsoft.BitmaskDropout">com.microsoft.BitmaskDropout</a>
1717
* <a href="#com.microsoft.CDist">com.microsoft.CDist</a>
18+
* <a href="#com.microsoft.CausalConvWithState">com.microsoft.CausalConvWithState</a>
1819
* <a href="#com.microsoft.ComplexMul">com.microsoft.ComplexMul</a>
1920
* <a href="#com.microsoft.ComplexMulConj">com.microsoft.ComplexMulConj</a>
2021
* <a href="#com.microsoft.ConvTransposeWithDynamicPads">com.microsoft.ConvTransposeWithDynamicPads</a>
@@ -49,6 +50,7 @@ Do not modify directly.*
4950
* <a href="#com.microsoft.GroupQueryAttention">com.microsoft.GroupQueryAttention</a>
5051
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
5152
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
53+
* <a href="#com.microsoft.LinearAttention">com.microsoft.LinearAttention</a>
5254
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
5355
* <a href="#com.microsoft.MatMulBnb4">com.microsoft.MatMulBnb4</a>
5456
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
@@ -900,6 +902,68 @@ This version of the operator has been available since version 1 of the 'com.micr
900902
</dl>
901903

902904

905+
### <a name="com.microsoft.CausalConvWithState"></a><a name="com.microsoft.causalconvwithstate">**com.microsoft.CausalConvWithState**</a>
906+
907+
Stateful causal depthwise convolution, generalized to N spatial dimensions.
908+
909+
Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step.
910+
Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation.
911+
912+
The convolution is causal (looks only at current and past positions along the last
913+
spatial dimension) and depthwise (each channel is convolved independently with its own kernel).
914+
915+
Input layout is channels-first: (batch_size, channels, ...).
916+
Weight layout: (channels, 1, k_1, ...) for depthwise convolution.
917+
The carry state stores the last (k-1) positions along the causal axis for incremental decode.
918+
919+
The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is
920+
enforced on the last spatial dimension only.
921+
922+
The optional activation attribute supports fused SiLU/Swish activation.
923+
924+
#### Version
925+
926+
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
927+
928+
#### Attributes
929+
930+
<dl>
931+
<dt><tt>activation</tt> : string</dt>
932+
<dd>Fused activation function. One of: 'silu', 'swish', 'none'. Default is 'none'.</dd>
933+
<dt><tt>ndim</tt> : int</dt>
934+
<dd>Spatial dimensionality: 1, 2, or 3. Default is 1.</dd>
935+
</dl>
936+
937+
#### Inputs (2 - 4)
938+
939+
<dl>
940+
<dt><tt>input</tt> : T</dt>
941+
<dd>Input tensor with shape (batch_size, channels, ...). Channels-first layout. Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).</dd>
942+
<dt><tt>weight</tt> : T</dt>
943+
<dd>Depthwise convolution kernel with shape (channels, 1, k_1, ...). Spatial kernel sizes: (k_1, ..., k_ndim).</dd>
944+
<dt><tt>bias</tt> (optional) : T</dt>
945+
<dd>Optional per-channel bias with shape (channels).</dd>
946+
<dt><tt>past_state</tt> (optional) : T</dt>
947+
<dd>Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). If not provided, padding is zero.</dd>
948+
</dl>
949+
950+
#### Outputs
951+
952+
<dl>
953+
<dt><tt>output</tt> : T</dt>
954+
<dd>Convolution output with same shape as input.</dd>
955+
<dt><tt>present_state</tt> : T</dt>
956+
<dd>Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). Contains the last (k-1) values from the virtual input along the causal axis.</dd>
957+
</dl>
958+
959+
#### Type Constraints
960+
961+
<dl>
962+
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
963+
<dd>Constrain input and output types to float tensors.</dd>
964+
</dl>
965+
966+
903967
### <a name="com.microsoft.ComplexMul"></a><a name="com.microsoft.complexmul">**com.microsoft.ComplexMul**</a>
904968

905969
#### Version
@@ -2703,6 +2767,79 @@ This version of the operator has been available since version 1 of the 'com.micr
27032767
</dl>
27042768

27052769

2770+
### <a name="com.microsoft.LinearAttention"></a><a name="com.microsoft.linearattention">**com.microsoft.LinearAttention**</a>
2771+
2772+
Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1).
2773+
2774+
All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always
2775+
required. The op internally unpacks to 4D for computation.
2776+
2777+
The update_rule attribute selects the recurrence type:
2778+
- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
2779+
- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
2780+
- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t
2781+
- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t
2782+
2783+
where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product.
2784+
2785+
Semantics: Equivalent to running the recurrent update sequentially for each token,
2786+
but may be implemented using chunk-parallel algorithms for GPU efficiency.
2787+
2788+
#### Version
2789+
2790+
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
2791+
2792+
#### Attributes
2793+
2794+
<dl>
2795+
<dt><tt>chunk_size</tt> : int</dt>
2796+
<dd>Chunk size for the chunk-parallel WY decomposition during prefill (T>1). Tuning hint; does not affect output correctness.</dd>
2797+
<dt><tt>kv_num_heads</tt> : int (required)</dt>
2798+
<dd>Number of key/value heads. Always required.</dd>
2799+
<dt><tt>q_num_heads</tt> : int (required)</dt>
2800+
<dd>Number of query heads. Always required.</dd>
2801+
<dt><tt>scale</tt> : float</dt>
2802+
<dd>Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set explicitly to override.</dd>
2803+
<dt><tt>update_rule</tt> : string</dt>
2804+
<dd>The update rule for the linear attention recurrence. One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.</dd>
2805+
</dl>
2806+
2807+
#### Inputs (3 - 6)
2808+
2809+
<dl>
2810+
<dt><tt>query</tt> : T</dt>
2811+
<dd>Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.</dd>
2812+
<dt><tt>key</tt> : T</dt>
2813+
<dd>Key vectors with 3D packed shape (B, T, H_kv * d_k). Should be L2-normalized for delta/gated_delta modes.</dd>
2814+
<dt><tt>value</tt> : T</dt>
2815+
<dd>Value vectors with 3D packed shape (B, T, H_kv * d_v).</dd>
2816+
<dt><tt>past_state</tt> (optional) : S</dt>
2817+
<dd>Recurrent state from previous step with shape (B, H_kv, d_k, d_v). Always 4D. If not provided, defaults to zeros.</dd>
2818+
<dt><tt>decay</tt> (optional) : T</dt>
2819+
<dd>Exponential decay gate in log-space. 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for 'gated' and 'gated_delta' modes.</dd>
2820+
<dt><tt>beta</tt> (optional) : T</dt>
2821+
<dd>Update rate (sigmoid output). 3D packed shape: (B, T, H_kv) or (B, T, 1). Required for 'delta' and 'gated_delta' modes.</dd>
2822+
</dl>
2823+
2824+
#### Outputs
2825+
2826+
<dl>
2827+
<dt><tt>output</tt> : T</dt>
2828+
<dd>Attention output with 3D packed shape (B, T, H_q * d_v).</dd>
2829+
<dt><tt>present_state</tt> : S</dt>
2830+
<dd>Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.</dd>
2831+
</dl>
2832+
2833+
#### Type Constraints
2834+
2835+
<dl>
2836+
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
2837+
<dd>Constrain input and output types to float tensors.</dd>
2838+
<dt><tt>S</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
2839+
<dd>Constrain state types to float tensors.</dd>
2840+
</dl>
2841+
2842+
27062843
### <a name="com.microsoft.LongformerAttention"></a><a name="com.microsoft.longformerattention">**com.microsoft.LongformerAttention**</a>
27072844

27082845
Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token

0 commit comments

Comments
 (0)