You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<dd>Optional per-channel bias with shape (channels).</dd>
946
+
<dt><tt>past_state</tt> (optional) : T</dt>
947
+
<dd>Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). If not provided, padding is zero.</dd>
948
+
</dl>
949
+
950
+
#### Outputs
951
+
952
+
<dl>
953
+
<dt><tt>output</tt> : T</dt>
954
+
<dd>Convolution output with same shape as input.</dd>
955
+
<dt><tt>present_state</tt> : T</dt>
956
+
<dd>Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). Contains the last (k-1) values from the virtual input along the causal axis.</dd>
where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product.
2784
+
2785
+
Semantics: Equivalent to running the recurrent update sequentially for each token,
2786
+
but may be implemented using chunk-parallel algorithms for GPU efficiency.
2787
+
2788
+
#### Version
2789
+
2790
+
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
2791
+
2792
+
#### Attributes
2793
+
2794
+
<dl>
2795
+
<dt><tt>chunk_size</tt> : int</dt>
2796
+
<dd>Chunk size for the chunk-parallel WY decomposition during prefill (T>1). Tuning hint; does not affect output correctness.</dd>
2797
+
<dt><tt>kv_num_heads</tt> : int (required)</dt>
2798
+
<dd>Number of key/value heads. Always required.</dd>
2799
+
<dt><tt>q_num_heads</tt> : int (required)</dt>
2800
+
<dd>Number of query heads. Always required.</dd>
2801
+
<dt><tt>scale</tt> : float</dt>
2802
+
<dd>Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set explicitly to override.</dd>
2803
+
<dt><tt>update_rule</tt> : string</dt>
2804
+
<dd>The update rule for the linear attention recurrence. One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.</dd>
2805
+
</dl>
2806
+
2807
+
#### Inputs (3 - 6)
2808
+
2809
+
<dl>
2810
+
<dt><tt>query</tt> : T</dt>
2811
+
<dd>Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.</dd>
2812
+
<dt><tt>key</tt> : T</dt>
2813
+
<dd>Key vectors with 3D packed shape (B, T, H_kv * d_k). Should be L2-normalized for delta/gated_delta modes.</dd>
2814
+
<dt><tt>value</tt> : T</dt>
2815
+
<dd>Value vectors with 3D packed shape (B, T, H_kv * d_v).</dd>
2816
+
<dt><tt>past_state</tt> (optional) : S</dt>
2817
+
<dd>Recurrent state from previous step with shape (B, H_kv, d_k, d_v). Always 4D. If not provided, defaults to zeros.</dd>
2818
+
<dt><tt>decay</tt> (optional) : T</dt>
2819
+
<dd>Exponential decay gate in log-space. 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for 'gated' and 'gated_delta' modes.</dd>
2820
+
<dt><tt>beta</tt> (optional) : T</dt>
2821
+
<dd>Update rate (sigmoid output). 3D packed shape: (B, T, H_kv) or (B, T, 1). Required for 'delta' and 'gated_delta' modes.</dd>
2822
+
</dl>
2823
+
2824
+
#### Outputs
2825
+
2826
+
<dl>
2827
+
<dt><tt>output</tt> : T</dt>
2828
+
<dd>Attention output with 3D packed shape (B, T, H_q * d_v).</dd>
2829
+
<dt><tt>present_state</tt> : S</dt>
2830
+
<dd>Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.</dd>
0 commit comments