
Bug: float division instead of integer division in GQA reshape causes silent shape truncation #641

@tomohiro86


Summary

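In the GQA (Grouped Query Attention) reshape operations across multiple modules, the per-KV-head group size is computed as int(kg / self.num_kv_heads), where kg is the query-head dimension (num_query_heads). This uses true (float) division (/) followed by int() instead of integer division (//). When kg is not exactly divisible by num_kv_heads, int() silently truncates the quotient, producing an incorrect reshape dimension without any error at the point where it is computed.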
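The truncation step can be seen in plain Python. A minimal sketch, assuming illustrative head counts of 8 query heads and 3 KV heads (not values from any shipped config):

```python
# Illustrative (assumed) head counts: kg query heads split across num_kv_heads.
kg, num_kv_heads = 8, 3

ratio = kg / num_kv_heads  # true division always yields a float: 2.666...
group = int(ratio)         # int() truncates toward zero without complaint: 2

# K * G now accounts for only 6 of the 8 query heads.
covered_heads = num_kv_heads * group
print(type(ratio).__name__, group, covered_heads)  # float 2 6
```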

Affected files

| File | Lines |
|---|---|
| gemma/gm/nn/_modules.py | 244, 288 |
| gemma/gm/nn/gemma3n/_modules.py | 342, 387 |
| gemma/gm/nn/gemma4/_modules.py | 324, 363 |
| gemma/research/t5gemma/modules.py | 240, 272 |

Example (from gm/nn/_modules.py)

```python
# Current (buggy)
query_scaled = query_scaled.reshape(
    (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h)
)

# Fixed
query_scaled = query_scaled.reshape(
    (b, t, self.num_kv_heads, kg // self.num_kv_heads, h)
)
```
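For context, the reshape splits the query-head axis into (num_kv_heads, group size) so each group of query heads lines up with its shared KV head. A minimal NumPy sketch, assuming a (b, t, kg, h) query layout and toy sizes chosen for illustration; NumPy stands in for the JAX arrays used in the affected modules, whose reshape is also row-major by default:

```python
import numpy as np

# Assumed toy sizes: batch b, sequence t, query heads kg, head dim h.
b, t, kg, h = 2, 5, 8, 4
num_kv_heads = 2  # each KV head is shared by g = kg // num_kv_heads query heads

query_scaled = np.arange(b * t * kg * h, dtype=np.float32).reshape(b, t, kg, h)

# Integer division keeps the dimension an exact int with no float round-trip.
g = kg // num_kv_heads
grouped = query_scaled.reshape((b, t, num_kv_heads, g, h))
print(grouped.shape)  # (2, 5, 2, 4, 4)

# Row-major reshape maps query head (k * g + j) to grouped[:, :, k, j].
for k in range(num_kv_heads):
    for j in range(g):
        assert np.array_equal(grouped[:, :, k, j], query_scaled[:, :, k * g + j])
```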

Impact

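If kg % num_kv_heads != 0, the truncated group size makes the requested shape hold fewer elements (b * t * num_kv_heads * int(kg / num_kv_heads) * h) than the tensor being reshaped (b * t * kg * h). This leads to a reshape size-mismatch error at runtime from JAX/NumPy, reported at the reshape call rather than where the head counts are configured. A valid but wrong shape cannot arise by coincidence: num_kv_heads * int(kg / num_kv_heads) is strictly less than kg whenever the division is inexact, so the element counts match only for empty tensors (b, t, or h equal to 0).

In practice, standard model configurations keep num_query_heads a multiple of num_kv_heads, so this bug is latent. However, custom or experimental configurations that break this invariant will hit it unexpectedly.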
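The failure is easy to reproduce outside the model. A NumPy sketch with a hypothetical configuration of 8 query heads and 3 KV heads; NumPy is a stand-in for the JAX arrays in the affected modules, and JAX's exact exception type and message may differ:

```python
import numpy as np

# Hypothetical non-standard config: 8 query heads, 3 KV heads (8 % 3 != 0).
b, t, kg, h = 1, 2, 8, 4
num_kv_heads = 3

query_scaled = np.zeros((b, t, kg, h), dtype=np.float32)
g = int(kg / num_kv_heads)  # 8 / 3 = 2.666... truncated to 2

requested = (b, t, num_kv_heads, g, h)
print(query_scaled.size, b * t * num_kv_heads * g * h)  # 64 48

failed = False
try:
    query_scaled.reshape(requested)
except ValueError as err:
    # NumPy rejects the size mismatch; the message names shapes, not head counts.
    failed = True
    print("reshape failed:", err)
```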

Fix

Replace all 8 occurrences of int(kg / self.num_kv_heads) with kg // self.num_kv_heads.
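For positive head counts of realistic size, kg // self.num_kv_heads floors to the same value as int(kg / self.num_kv_heads), so the substitution makes the integer intent explicit and removes the float round-trip, while a non-divisible configuration still fails at the reshape. If a clearer error is wanted, a divisibility check could sit alongside the fix. The helper below is a hypothetical sketch, not existing gemma code:

```python
def gqa_group_size(num_query_heads: int, num_kv_heads: int) -> int:
    """Return G = num_query_heads // num_kv_heads, rejecting inexact splits.

    Hypothetical helper for illustration; not part of the gemma codebase.
    """
    if num_kv_heads <= 0:
        raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}")
    group, remainder = divmod(num_query_heads, num_kv_heads)
    if remainder:
        # Fail with the head counts in the message instead of a shape mismatch.
        raise ValueError(
            f"num_query_heads ({num_query_heads}) must be a multiple of "
            f"num_kv_heads ({num_kv_heads}) for grouped-query attention"
        )
    return group


print(gqa_group_size(8, 2))  # 4
try:
    gqa_group_size(8, 3)
except ValueError as err:
    print(err)
```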
