@@ -45,7 +45,7 @@ def get_default_block_sizes(
4545) -> int :
4646 """Choose bt to balance pipelining and VMEM utilization to minimize latency
4747
48- Accounts for state scratch ``(bt, H_v, K, V)`` float32 , optional
48+ Accounts for state scratch ``(bt, H_v, K, V)`` of ``state_dtype``, optional
4949 a_log / dt_bias, and bt-proportional tiles that ``emit_pipeline``
5050 double-buffers (q, k, v, g, b, o).
5151 """
@@ -60,7 +60,7 @@ def get_default_block_sizes(
6060 fixed_bits += 2 * H_v * num_lanes * 32 # dt_bias: (H_v, num_lanes) f32
6161
6262 # bt-proportional (in bits):
63- # state scratch: (2*bt, H_v, K, V) float32 (double buffer)
63+ # state scratch: (2*bt, H_v, K, V) state_dtype (double buffer)
6464 # pipeline tiles (×2 for emit_pipeline double buffering):
6565 # q(bt,H_qk,K) + k(bt,H_qk,K) -> 2·H_qk·K·ibits
6666 # g(bt,H_v,K) float32 -> H_v·K·32
@@ -401,7 +401,7 @@ def fused_decoding_gdn(
401401 k : jax .Array , # [T, H_qk, K]
402402 v : jax .Array , # [T, H_v, V]
403403 g : jax .Array , # [T, H_v, K] float32
404- initial_state : jax .Array , # [num_states, H_v, K, V] float32
404+ initial_state : jax .Array , # [num_states, H_v, K, V]
405405 state_indices : jax .Array , # [max_num_req] int32
406406 distribution : jax .Array , # [2] int32
407407 b : jax .Array | None , # [T, H_v, num_lanes] or None
@@ -421,7 +421,7 @@ def fused_decoding_gdn(
421421 k: Keys ``[T, H_qk, K]``.
422422 v: Values ``[T, H_v, V]``.
423423 g: Per-key gating ``[T, H_v, K]``, float32.
424- initial_state: State cache ``[num_states, H_v, K, V]`` float32.
424+ initial_state: State cache ``[num_states, H_v, K, V]``.
425425 state_indices: ``i32[max_num_req]`` — indices into the state cache.
426426 distribution: ``i32[2]`` — ``(decode_end, total)``.
427427 b: Raw betas ``[T, H_v, num_lanes]`` (sigmoid applied inside kernel).
0 commit comments