
Commit 6273bb2

On-the-fly RoPE and fix imports after gemma4 upstream rebase
Adapt gemma4_31b to upstream gemma4 changes (33419c0) that removed
precompute_freqs_cis in favor of on-the-fly RoPE computation:

- Store an inv_freq buffer instead of precomputed [max_seq_len, head_dim]
  cos/sin tables; saves memory and matches qwen3_5_moe and gemma4 E2B
- Compute cos/sin per forward via torch.outer(positions, inv_freq)
- Fix gemma4/text_decoder/__init__.py to remove the stale
  precompute_freqs_cis re-export
- Update model.md to reflect the current architecture
1 parent 644ec6e
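For context, a minimal sketch of what the change stores versus what it computes (standalone; `head_dim`, `theta`, and the positions are illustrative values, not the repo's config):

```python
import torch

# Per layer, only the inverse frequencies are kept as a buffer now.
head_dim, theta = 128, 1_000_000.0  # illustrative values
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
print(inv_freq.shape)  # torch.Size([64]), vs. two [max_seq_len, 128] tables before

# Each forward computes cos/sin only for the positions it actually sees.
input_pos = torch.arange(4)
freqs = torch.outer(input_pos.float(), inv_freq)  # [4, 64]
emb = torch.cat((freqs, freqs), dim=-1)           # [4, 128]
cos, sin = emb.cos(), emb.sin()
```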

3 files changed

Lines changed: 35 additions & 48 deletions

examples/models/gemma4/text_decoder/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
     apply_rotary_emb,
     apply_rotary_emb_single,
     Gemma4KVCache,
-    precompute_freqs_cis,
     rotate_half,
 )
 from .gemma4_config import Gemma4Config  # noqa: F401

examples/models/gemma4_31b/model.md

Lines changed: 17 additions & 17 deletions
@@ -105,7 +105,7 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`,
 | Method | Input | Output (sampled) |
 |-----------|------------------------------------------------------------|------------------|
 | `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float |
-| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float |
+| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[5, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float |

 Both methods share the same KV-cache buffers via
 `MemoryPlanningPass(share_mutable_buffers=True)` and
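The tightened lower bound (T ≥ 5 instead of T ≥ 2) would surface at export time as a dynamic-shape constraint. A hedged sketch of how such a bound is typically declared with `torch.export` (names and values are illustrative; the repo's `export.py` may differ):

```python
from torch.export import Dim

max_seq_len, sliding_window = 8192, 512  # illustrative config values
t = Dim("t", min=5, max=min(max_seq_len - 1, 2 * sliding_window))

# Tie the same Dim to every input that carries the prefill length T.
dynamic_shapes = {
    "tokens": {1: t},       # tokens: (1, T)
    "input_pos": {0: t},    # input_pos: (T,)
    "temperature": None,    # temperature: static (1,)
}
```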
@@ -145,11 +145,11 @@ quantize_and_save.py                 export.py / inference.py
          |                                   |
   quantize_weight()                  load (torchao safetensors)
          |                                   |
-  Int4Tensor / IntxUnpacked          Int4Tensor / IntxUnpacked
+  Int4Tensor / IntxUnpacked          Int4Tensor / IntxUnpacked (used directly)
          |                                   |
-  save (torchao safetensors)         pack_model()
+  save (torchao safetensors)         int4_dispatch routes to int4_plain_mm
          |                                   |
-  model.safetensors                  Int4TilePackedTo4dTensor (runtime)
+  model.safetensors                  dp4a decode / dequant+cuBLAS prefill
 ```

 `embed_tokens` and `lm_head` start tied; they are untied before
@@ -159,18 +159,17 @@ lossless for index lookup).

 ## Runtime buffer materialization

-After weight loading (via `pack_model()` or `from_hf_checkpoint()`), the
-model's KV caches, RoPE tables, and scalar constants are still on the meta
-device. `materialize_runtime_buffers(model, dtype, device)` in `model.py`
-replaces them with real tensors:
+After weight loading (via `from_hf_checkpoint()`), the model's KV caches,
+RoPE inv_freq buffers, and scalar constants are still on the meta device.
+`materialize_runtime_buffers(model, dtype, device)` in `model.py` replaces
+them with real tensors:

 - KV caches → zeros in `dtype` (bf16 for inference, bf16 for export)
-- RoPE tables → computed per-layer (sliding vs full, different θ and head_dim)
+- `inv_freq` → moved to target device (cos/sin computed on the fly per forward)
 - `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants

 Called by `export.py` (device="cpu" for tracing) and `inference.py`
-(device="cuda" for eager execution). Having one function avoids duplicating
-the RoPE computation and constant setup across scripts.
+(device="cuda" for eager execution).

 ## Customizations vs. vLLM / transformers reference

@@ -183,9 +182,10 @@ These exist solely to make the model exportable / efficient under ExecuTorch:
   via modulo and the attention mask reconstructs which slots are valid.
   Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`.
   Both use `index_copy_(dim=2, ...)` for trace-friendly updates.
-- **Per-layer RoPE tables** registered as `persistent=False` buffers (sliding
-  uses full RoPE, full uses proportional partial RoPE — head_dim and θ
-  differ, so the table is not shared).
+- **On-the-fly RoPE**: stores only `inv_freq` per layer, computes cos/sin
+  via `torch.outer(positions, inv_freq)` each forward. Saves memory vs
+  precomputed `[max_seq_len, head_dim]` tables (sliding uses full RoPE,
+  full uses proportional partial RoPE — head_dim and θ differ).
 - **On-device Gumbel-max sampling** so the exported program emits a token
   rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a
   single float per step.
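The Gumbel-max bullet is compact enough to illustrate. The trick: `argmax(logits / T + g)` with i.i.d. Gumbel noise `g` is an exact sample from `softmax(logits / T)`, so the graph can emit one token id instead of a vocab-sized logits tensor. A generic sketch, not necessarily the repo's exact code:

```python
import torch

def gumbel_max_sample(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # g = -log(-log(u)), u ~ Uniform(0, 1): standard Gumbel noise.
    u = torch.rand_like(logits).clamp_min(1e-10)  # avoid log(0)
    g = -torch.log(-torch.log(u))
    # argmax over (logits/T + g) is distributed as softmax(logits/T).
    return torch.argmax(logits / temperature + g, dim=-1)
```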
@@ -198,6 +198,6 @@ These exist solely to make the model exportable / efficient under ExecuTorch:
 The numerically-sensitive math primitives are imported from
 `examples.models.gemma4.text_decoder` and shared with the Gemma 4 E2B/E4B
 example: `RMSNorm`, `RMSNormNoWeight`, `Gemma4MLP`, `Gemma4KVCache`,
-`precompute_freqs_cis`, `apply_rotary_emb`. The 31B-specific pieces
-(attention with K=V branch, decoder layer, top-level model with softcap +
-sampling, checkpoint loader) live in `model.py`.
+`apply_rotary_emb`. The 31B-specific pieces (attention with K=V branch,
+decoder layer, top-level model with softcap + sampling, checkpoint loader)
+live in `model.py`.
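For reference, the shared rotary primitives follow the usual rotate-half convention; a sketch consistent with how `apply_rotary_emb` is used in `model.py` (broadcasting details in the repo may differ):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) along the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_emb(q, k, cos, sin):
    # cos/sin: [T, head_dim], broadcast against q/k: [B, heads, T, head_dim].
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```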

examples/models/gemma4_31b/model.py

Lines changed: 18 additions & 30 deletions
@@ -49,7 +49,6 @@
     apply_rotary_emb,
     Gemma4KVCache,
     Gemma4MLP,
-    precompute_freqs_cis,
     RMSNorm,
     RMSNormNoWeight,
 )
@@ -255,21 +254,22 @@ def __init__(self, config: Gemma4_31BConfig, layer_idx: int):
         # Precomputed RoPE table for this layer (per-layer because head_dim
         # and theta differ between sliding and full attention). For full
         # attention layers we pass freq_base_dim=head_dim so the zero-padded
-        # inv_freq matches HF's "proportional" partial RoPE.
+        # On-the-fly RoPE: store only inv_freq, compute cos/sin per forward.
+        # Saves memory vs precomputed [max_seq_len, head_dim] tables.
         if self.is_sliding:
             rotary_dim = self.head_dim
-            freq_base_dim = None
         else:
             rotary_dim = int(self.head_dim * self.partial_rotary)
-            freq_base_dim = self.head_dim
-        freqs_cos, freqs_sin = precompute_freqs_cis(
-            rotary_dim,
-            config.max_seq_len,
-            theta=self.rope_theta,
-            freq_base_dim=freq_base_dim,
+        rope_angles = rotary_dim // 2
+        inv_freq_rotated = 1.0 / (
+            self.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / self.head_dim)
         )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        nope_angles = self.head_dim // 2 - rope_angles
+        if nope_angles > 0:
+            inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)])
+        else:
+            inv_freq = inv_freq_rotated
+        self.register_buffer("inv_freq", inv_freq, persistent=False)

         # KV cache. Sliding layers use a ring buffer (2x window) to save
         # memory; full layers use a flat buffer (max_seq_len).
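Why the zero padding is safe: a zero inverse frequency gives angle 0, hence cos = 1 and sin = 0, so the rotation is the identity on the padded ("NoPE") dimensions. A quick standalone check with illustrative sizes:

```python
import torch

head_dim, rotary_dim, theta = 256, 128, 1_000_000.0  # illustrative
rope_angles = rotary_dim // 2
inv_freq = torch.cat([
    1.0 / (theta ** (torch.arange(0, rotary_dim, 2).float() / head_dim)),
    torch.zeros(head_dim // 2 - rope_angles),  # zero-padded NoPE angles
])

freqs = torch.outer(torch.arange(7).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

# Padded columns rotate by angle 0: q*1 + rotate_half(q)*0 == q there.
assert torch.all(emb.cos()[:, rope_angles : head_dim // 2] == 1)
assert torch.all(emb.sin()[:, rope_angles : head_dim // 2] == 0)
```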
@@ -316,10 +316,11 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

-        # RoPE on Q and K only (V is not rotated). cos/sin are gathered for
-        # the current positions to avoid baking the full table into the graph.
-        cos = self.freqs_cos[input_pos]
-        sin = self.freqs_sin[input_pos]
+        # RoPE on Q and K only (V is not rotated). cos/sin computed on the fly.
+        freqs = torch.outer(input_pos.float(), self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = torch.cos(emb)
+        sin = torch.sin(emb)
         q, k = apply_rotary_emb(q, k, cos, sin)

         # Update cache and read back full K/V.
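The on-the-fly path matches gathering rows from the old precomputed table; a quick equivalence check (illustrative sizes):

```python
import torch

head_dim, max_seq_len = 128, 4096  # illustrative
inv_freq = 1.0 / (1_000_000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Old path: full [max_seq_len, head_dim] table, indexed by position.
table = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
table = torch.cat((table, table), dim=-1)
cos_table, sin_table = table.cos(), table.sin()

# New path: compute only the rows for the current input positions.
input_pos = torch.tensor([5, 6, 7, 8])
emb = torch.outer(input_pos.float(), inv_freq)
emb = torch.cat((emb, emb), dim=-1)

assert torch.allclose(emb.cos(), cos_table[input_pos])
assert torch.allclose(emb.sin(), sin_table[input_pos])
```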
@@ -533,8 +534,7 @@ def from_hf_checkpoint(
         # and not in the checkpoint — those are the "expected" missing keys.
         runtime_prefixes = (
             ".kv_cache.",
-            ".freqs_cos",
-            ".freqs_sin",
+            ".inv_freq",
             "embed_normalizer",
             "logit_softcap",
             "cache_positions",
@@ -675,19 +675,7 @@ def materialize_runtime_buffers(

     for layer in model.layers:
         attn = layer.self_attn
-        if attn.is_sliding:
-            rotary_dim, freq_base_dim = attn.head_dim, None
-        else:
-            rotary_dim = int(attn.head_dim * attn.partial_rotary)
-            freq_base_dim = attn.head_dim
-        cos, sin = precompute_freqs_cis(
-            rotary_dim,
-            config.max_seq_len,
-            theta=attn.rope_theta,
-            freq_base_dim=freq_base_dim,
-        )
-        attn.register_buffer("freqs_cos", cos.to(device), persistent=False)
-        attn.register_buffer("freqs_sin", sin.to(device), persistent=False)
+        attn.inv_freq = attn.inv_freq.to(device)

     model.register_buffer(
         "embed_normalizer",
