Skip to content

Commit 0168ba1

Browse files
committed
Fix comments and format
Signed-off-by: Haowen Ning <hning@google.com>
1 parent 54aeb6b commit 0168ba1

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

tpu_inference/kernels/gdn/fused_gdn_decode_kernel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_default_block_sizes(
4545
) -> int:
4646
"""Choose bt to balance pipelining and VMEM utilization to minimize latency
4747
48-
Accounts for state scratch ``(bt, H_v, K, V)`` float32, optional
48+
Accounts for state scratch ``(bt, H_v, K, V)`` of ``state_dtype``, optional
4949
a_log / dt_bias, and bt-proportional tiles that ``emit_pipeline``
5050
double-buffers (q, k, v, g, b, o).
5151
"""
@@ -60,7 +60,7 @@ def get_default_block_sizes(
6060
fixed_bits += 2 * H_v * num_lanes * 32 # dt_bias: (H_v, num_lanes) f32
6161

6262
# bt-proportional (in bits):
63-
# state scratch: (2*bt, H_v, K, V) float32 (double buffer)
63+
# state scratch: (2*bt, H_v, K, V) state_dtype (double buffer)
6464
# pipeline tiles (×2 for emit_pipeline double buffering):
6565
# q(bt,H_qk,K) + k(bt,H_qk,K) -> 2·H_qk·K·ibits
6666
# g(bt,H_v,K) float32 -> H_v·K·32
@@ -401,7 +401,7 @@ def fused_decoding_gdn(
401401
k: jax.Array, # [T, H_qk, K]
402402
v: jax.Array, # [T, H_v, V]
403403
g: jax.Array, # [T, H_v, K] float32
404-
initial_state: jax.Array, # [num_states, H_v, K, V] float32
404+
initial_state: jax.Array, # [num_states, H_v, K, V]
405405
state_indices: jax.Array, # [max_num_req] int32
406406
distribution: jax.Array, # [2] int32
407407
b: jax.Array | None, # [T, H_v, num_lanes] or None
@@ -421,7 +421,7 @@ def fused_decoding_gdn(
421421
k: Keys ``[T, H_qk, K]``.
422422
v: Values ``[T, H_v, V]``.
423423
g: Per-key gating ``[T, H_v, K]``, float32.
424-
initial_state: State cache ``[num_states, H_v, K, V]`` float32.
424+
initial_state: State cache ``[num_states, H_v, K, V]``
425425
state_indices: ``i32[max_num_req]`` — indices into the state cache.
426426
distribution: ``i32[2]`` — ``(decode_end, total)``.
427427
b: Raw betas ``[T, H_v, num_lanes]`` (sigmoid applied inside kernel).

0 commit comments

Comments
 (0)