fix(gqa): replace float division with integer division in GQA reshape #668

Open
Sumu004 wants to merge 1 commit into google-deepmind:main from Sumu004:fix/gqa-integer-division

@Sumu004 Sumu004 commented May 29, 2026

Problem

In the GQA (Grouped Query Attention) reshape operations, the number of query heads per KV head group is computed as:

int(kg / self.num_kv_heads)   # float division → int()

This is semantically wrong for two reasons:

  1. Silent truncation: When kg % num_kv_heads != 0, float division produces a non-integer (e.g. 7 / 4 = 1.75), and int() silently truncates it to 1. The reshape then proceeds with a wrong dimension, either silently corrupting the tensor shape or crashing with a confusing size-mismatch error.

  2. Semantic mismatch: Tensor dimensions are integers by definition. Using float division here obscures the intent and masks invalid configurations.

Standard Gemma model configs, where num_query_heads is always a multiple of num_kv_heads, happen to produce exact divisions, so this bug is latent in production. It surfaces immediately with custom or experimental head configurations.
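To make the arithmetic concrete, here is a minimal sketch comparing the two expressions on plain Python integers. The helper names `group_size_float` and `group_size_int` are hypothetical and exist only for this illustration. For small positive integers both forms agree, including the non-divisible case, so the difference is mainly about intent and about staying in integer arithmetic. They diverge once values exceed what a float can represent exactly.

```python
def group_size_float(kg: int, num_kv_heads: int) -> int:
    """Old form: float division followed by truncation."""
    return int(kg / num_kv_heads)


def group_size_int(kg: int, num_kv_heads: int) -> int:
    """New form: floor division, stays in integer arithmetic."""
    return kg // num_kv_heads


# Divisible case (typical configs): both forms agree.
assert group_size_float(8, 4) == group_size_int(8, 4) == 2

# Non-divisible case: both forms yield 1, so 4 * 1 no longer equals 7.
assert group_size_float(7, 4) == 1
assert group_size_int(7, 4) == 1

# Beyond 2**53 a float cannot represent every integer, so the forms diverge.
big = 2**53 + 1
assert group_size_float(big, 1) == 2**53
assert group_size_int(big, 1) == big
```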

Fix

Replace all 8 occurrences of int(kg / self.num_kv_heads) with kg // self.num_kv_heads:

# Before
query_scaled = query_scaled.reshape(
    (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h)
)

# After
query_scaled = query_scaled.reshape(
    (b, t, self.num_kv_heads, kg // self.num_kv_heads, h)
)
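As a rough illustration of why a non-divisible head count cannot reshape cleanly, the sketch below computes the target shape and compares element counts. The function `gqa_query_shape` and its explicit divisibility check are assumptions made for this example. They are not part of this PR or of the Gemma codebase.

```python
import math


def gqa_query_shape(b, t, kg, h, num_kv_heads):
    """Target shape for the grouped query reshape, with an explicit check.

    The divisibility check is an illustrative assumption, not part of this PR.
    """
    if kg % num_kv_heads != 0:
        raise ValueError(
            f"kg={kg} is not divisible by num_kv_heads={num_kv_heads}"
        )
    return (b, t, num_kv_heads, kg // num_kv_heads, h)


# Valid config: element count is preserved, so a reshape to this shape is legal.
shape = gqa_query_shape(b=2, t=3, kg=8, h=16, num_kv_heads=4)
assert shape == (2, 3, 4, 2, 16)
assert math.prod(shape) == 2 * 3 * 8 * 16

# Invalid config: without the check, the target shape would hold fewer
# elements than the input (4 * 1 groups instead of 7 heads).
unchecked = (2, 3, 4, 7 // 4, 16)
assert math.prod(unchecked) != 2 * 3 * 7 * 16
```

An explicit check like this would turn a size-mismatch error into a message that names the offending configuration. Whether to add one is a separate design decision from the change proposed here.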

Affected files

| File | Lines |
| --- | --- |
| gemma/gm/nn/_modules.py | 244, 288 |
| gemma/gm/nn/gemma4/_modules.py | 326, 365 |
| gemma/gm/nn/gemma3n/_modules.py | 342, 387 |
| gemma/research/t5gemma/modules.py | 240, 272 |

Verification

Confirmed that all 8 occurrences are replaced and that no instances of int(kg / self.num_kv_heads) remain in the codebase.
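The PR does not say how this check was performed. One possible way to repeat it is a small regex scan, sketched below. The pattern and the helper `find_float_div_casts` are hypothetical and only cover the exact expression named above, allowing for whitespace variations.

```python
import re

# Matches int(kg / self.num_kv_heads) with flexible whitespace, but not the // form.
FLOAT_DIV_CAST = re.compile(r"int\(\s*kg\s*/\s*self\.num_kv_heads\s*\)")


def find_float_div_casts(source: str) -> list[int]:
    """Return 1-based line numbers that still contain the old expression."""
    return [
        lineno
        for lineno, line in enumerate(source.splitlines(), start=1)
        if FLOAT_DIV_CAST.search(line)
    ]


before = "x = x.reshape((b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h))"
after = "x = x.reshape((b, t, self.num_kv_heads, kg // self.num_kv_heads, h))"
assert find_float_div_casts(before) == [1]
assert find_float_div_casts(after) == []
```

Running the helper over the contents of each affected file should return an empty list once the change is applied.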

Fixes #641

In grouped query attention (GQA), the number of query heads per KV head
was computed as int(kg / self.num_kv_heads) using float division + int().

When kg is not exactly divisible by num_kv_heads, float division produces
a non-integer result that int() silently truncates. This yields an
incorrect reshape dimension with no error, causing silent shape corruption
or an unexpected crash in non-standard head configurations.

Replaced all 8 occurrences across 4 files with integer division (//),
which is semantically correct for integer tensor dimensions and makes the
intent explicit.

Affected files:
- gemma/gm/nn/_modules.py (lines 244, 288)
- gemma/gm/nn/gemma4/_modules.py (lines 326, 365)
- gemma/gm/nn/gemma3n/_modules.py (lines 342, 387)
- gemma/research/t5gemma/modules.py (lines 240, 272)

Fixes: google-deepmind#641

Development

Successfully merging this pull request may close these issues.

Bug: float division instead of integer division in GQA reshape causes silent shape truncation
