Summary
In the GQA (Grouped Query Attention) reshape operations across multiple modules, int(kg / self.num_kv_heads) uses float division (/) instead of integer division (//). When kg is not exactly divisible by num_kv_heads, Python's int() silently truncates the quotient, so the reshape dimension is computed incorrectly and no error is raised at the point of division.
Affected files
| File | Lines |
| --- | --- |
| gemma/gm/nn/_modules.py | 244, 288 |
| gemma/gm/nn/gemma3n/_modules.py | 342, 387 |
| gemma/gm/nn/gemma4/_modules.py | 324, 363 |
| gemma/research/t5gemma/modules.py | 240, 272 |
Example (from gm/nn/_modules.py)
# Current (buggy)
query_scaled = query_scaled.reshape(
    (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h)
)

# Fixed
query_scaled = query_scaled.reshape(
    (b, t, self.num_kv_heads, kg // self.num_kv_heads, h)
)
Impact
If kg % num_kv_heads != 0, the truncated dimension gives the reshape a target shape whose total element count does not match the original. This leads to:
- A size-mismatch error at runtime when JAX/NumPy rejects the reshape (NumPy raises ValueError), or
- Silent incorrect computation, should the truncated value ever yield a valid (but wrong) shape
In practice, standard model configurations keep num_query_heads a multiple of num_kv_heads, so this bug is latent. However, custom or experimental configurations will hit it unexpectedly.
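The element-count mismatch can be checked without any array library. The following is a minimal sketch in plain Python. The helper name `gqa_target_shape` and the example dimensions are hypothetical, chosen only for illustration, and are not taken from the Gemma code.

```python
import math


def gqa_target_shape(b, t, kg, h, num_kv_heads):
    """Hypothetical helper mirroring the reshape target used in the GQA path."""
    # Same expression as the current code: float division, then truncation.
    return (b, t, num_kv_heads, int(kg / num_kv_heads), h)


# Illustrative sizes only: 10 query heads cannot be split evenly over 4 KV heads.
original = (2, 3, 10, 8)  # (b, t, kg, h)
target = gqa_target_shape(*original, num_kv_heads=4)

original_count = math.prod(original)  # 2 * 3 * 10 * 8
target_count = math.prod(target)      # 2 * 3 * 4 * 2 * 8

print(target)                        # (2, 3, 4, 2, 8)
print(original_count, target_count)  # 480 384
```

In this example the two counts differ, so a reshape to the target shape cannot succeed and the array library is expected to reject it.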
Fix
Replace all 8 occurrences of int(kg / self.num_kv_heads) with kg // self.num_kv_heads.
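For realistic positive head counts, kg // self.num_kv_heads floors to the same value as int(kg / self.num_kv_heads), so a non-divisible configuration would still reach the reshape after this change. One possible complement, sketched here under the assumption that failing early with a descriptive message is acceptable, is an explicit divisibility check. The function name `heads_per_kv_group` is hypothetical and not part of the Gemma code base.

```python
def heads_per_kv_group(kg: int, num_kv_heads: int) -> int:
    """Return the number of query heads per KV head, or raise if uneven.

    Hypothetical helper for illustration; not part of the Gemma code base.
    """
    if num_kv_heads <= 0:
        raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}")
    # divmod keeps everything in integer arithmetic (no float round trip).
    groups, remainder = divmod(kg, num_kv_heads)
    if remainder != 0:
        raise ValueError(
            f"kg={kg} is not divisible by num_kv_heads={num_kv_heads}"
        )
    return groups


print(heads_per_kv_group(8, 4))  # 2
```

With a check like this, an uneven configuration such as kg=10 and num_kv_heads=4 would fail with a message naming both values rather than a generic reshape error.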