@@ -105,7 +105,7 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`,
 | Method | Input | Output (sampled) |
 | -----------| ------------------------------------------------------------| ------------------|
 | `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float |
-| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float |
+| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[5, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float |
 
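As a quick illustration of the shapes in this table, a minimal driver sketch. It assumes the module is already loaded as `model` and exposes `prefill`/`decode` with the signatures above; the vocab size, prompt length, and loading step are placeholders, not taken from the repo.

```python
import torch

vocab_size = 32_000                                                # placeholder, not the model's real vocab size
prompt = torch.randint(0, vocab_size, (1, 16), dtype=torch.long)  # (1, T), T within the prefill range
input_pos = torch.arange(prompt.shape[1])                          # (T,)
temperature = torch.tensor([0.8])                                  # (1,)

# `model` stands for the loaded eager or exported module (assumed).
tok = model.prefill(prompt, input_pos, temperature)                # (1, 1) float: sampled token id
for step in range(prompt.shape[1], prompt.shape[1] + 32):
    tok = model.decode(tok.to(torch.long),                         # (1, 1) token fed back in
                       torch.tensor([step]),                       # (1,) position
                       temperature)                                # (1,)
```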
 Both methods share the same KV-cache buffers via
 `MemoryPlanningPass(share_mutable_buffers=True)` and
@@ -145,11 +145,11 @@ quantize_and_save.py export.py / inference.py
          |                                  |
    quantize_weight()                load (torchao safetensors)
          |                                  |
-   Int4Tensor / IntxUnpacked        Int4Tensor / IntxUnpacked
+   Int4Tensor / IntxUnpacked        Int4Tensor / IntxUnpacked (used directly)
          |                                  |
-   save (torchao safetensors)       pack_model()
+   save (torchao safetensors)       int4_dispatch routes to int4_plain_mm
          |                                  |
-   model.safetensors                Int4TilePackedTo4dTensor (runtime)
+   model.safetensors                dp4a decode / dequant+cuBLAS prefill
 ```
 
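To make the right-hand column concrete, a hedged sketch of the prefill-side fallback (dequantize once, then a dense matmul that cuBLAS serves). The packing convention (two 4-bit values per byte, zero-point 8, per-group scales) and the helper names are assumptions; the decode path would instead route to the dp4a kernel named in the diagram.

```python
import torch

def dequant_int4(packed: torch.Tensor, scales: torch.Tensor, group_size: int) -> torch.Tensor:
    # packed: (out, in//2) uint8, two 4-bit values per byte (assumed layout)
    # scales: (out, in//group_size); symmetric int4 with zero-point 8 (assumed)
    lo = (packed & 0x0F).to(torch.int8) - 8
    hi = (packed >> 4).to(torch.int8) - 8
    w = torch.stack((lo, hi), dim=-1).reshape(packed.shape[0], -1).to(scales.dtype)
    return w * scales.repeat_interleave(group_size, dim=1)

def int4_linear_prefill(x: torch.Tensor, packed: torch.Tensor, scales: torch.Tensor,
                        group_size: int = 32) -> torch.Tensor:
    # Prefill: dequantize, then let cuBLAS handle the wide GEMM.
    # Decode (single-token GEMV) would call an int4 dp4a kernel instead.
    return x @ dequant_int4(packed, scales, group_size).t()
```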
 `embed_tokens` and `lm_head` start tied; they are untied before
@@ -159,18 +159,17 @@ lossless for index lookup).
 
 ## Runtime buffer materialization
 
-After weight loading (via `pack_model()` or `from_hf_checkpoint()`), the
-model's KV caches, RoPE tables, and scalar constants are still on the meta
-device. `materialize_runtime_buffers(model, dtype, device)` in `model.py`
-replaces them with real tensors:
+After weight loading (via `from_hf_checkpoint()`), the model's KV caches,
+RoPE inv_freq buffers, and scalar constants are still on the meta device.
+`materialize_runtime_buffers(model, dtype, device)` in `model.py` replaces
+them with real tensors:
 
 - KV caches → zeros in `dtype` (bf16 for inference, bf16 for export)
-- RoPE tables → computed per-layer (sliding vs full, different θ and head_dim)
+- `inv_freq` → moved to target device (cos/sin computed on the fly per forward)
 - `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants
 
 Called by `export.py` (device="cpu" for tracing) and `inference.py`
-(device="cuda" for eager execution). Having one function avoids duplicating
-the RoPE computation and constant setup across scripts.
+(device="cuda" for eager execution).
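A minimal sketch of the zero-fill step for the KV-cache buffers described in the bullets; the helper name, per-module traversal, and `persistent` flag are illustrative, not the repo's actual `materialize_runtime_buffers`.

```python
import torch
import torch.nn as nn

def materialize_kv_buffers(module: nn.Module, dtype: torch.dtype, device: torch.device) -> None:
    # Meta tensors carry shape/dtype but no storage, so each one is re-created
    # as a real zero tensor and re-registered under the same buffer name.
    for name, buf in list(module.named_buffers(recurse=False)):
        if buf.is_meta:
            module.register_buffer(
                name, torch.zeros(buf.shape, dtype=dtype, device=device), persistent=False
            )
```

This only covers the first bullet; in the repo one function also moves `inv_freq` and sets the scalar constants, which need their own handling rather than a zero fill.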
 
 ## Customizations vs. vLLM / transformers reference
 
@@ -183,9 +182,10 @@ These exist solely to make the model exportable / efficient under ExecuTorch:
   via modulo and the attention mask reconstructs which slots are valid.
   Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`.
   Both use `index_copy_(dim=2, ...)` for trace-friendly updates.
-- **Per-layer RoPE tables** registered as `persistent=False` buffers (sliding
-  uses full RoPE, full uses proportional partial RoPE — head_dim and θ
-  differ, so the table is not shared).
+- **On-the-fly RoPE**: stores only `inv_freq` per layer, computes cos/sin
+  via `torch.outer(positions, inv_freq)` each forward. Saves memory vs
+  precomputed `[max_seq_len, head_dim]` tables (sliding uses full RoPE,
+  full uses proportional partial RoPE — head_dim and θ differ).
 - **On-device Gumbel-max sampling** so the exported program emits a token
   rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a
   single float per step.
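The on-the-fly RoPE bullet amounts to very little code; a sketch where only the `torch.outer(positions, inv_freq)` step comes from the text, and the θ / head_dim values are placeholders rather than the model's config:

```python
import torch

def rope_cos_sin(positions: torch.Tensor, inv_freq: torch.Tensor):
    # positions: (T,) integer positions; inv_freq: (head_dim // 2,) stored per layer
    freqs = torch.outer(positions.float(), inv_freq)   # (T, head_dim // 2)
    return freqs.cos(), freqs.sin()

# Standard RoPE inverse frequencies; 128 and 10_000.0 are illustrative values.
head_dim, theta = 128, 10_000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
cos, sin = rope_cos_sin(torch.arange(8), inv_freq)     # recomputed each forward, no [max_seq_len, head_dim] table
```

And the Gumbel-max trick in its standard formulation (the in-graph version in the repo may differ in detail):

```python
def gumbel_max_sample(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # argmax(logits / T + Gumbel noise) is a sample from softmax(logits / T),
    # so the exported program can emit one token id instead of a logits tensor.
    u = torch.rand_like(logits).clamp_min(1e-20)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(logits / temperature + gumbel, dim=-1)
```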
@@ -198,6 +198,6 @@ These exist solely to make the model exportable / efficient under ExecuTorch:
 The numerically-sensitive math primitives are imported from
 `examples.models.gemma4.text_decoder` and shared with the Gemma 4 E2B/E4B
 example: `RMSNorm`, `RMSNormNoWeight`, `Gemma4MLP`, `Gemma4KVCache`,
-`precompute_freqs_cis`, `apply_rotary_emb`. The 31B-specific pieces
-(attention with K=V branch, decoder layer, top-level model with softcap +
-sampling, checkpoint loader) live in `model.py`.
+`apply_rotary_emb`. The 31B-specific pieces (attention with K=V branch,
+decoder layer, top-level model with softcap + sampling, checkpoint loader)
+live in `model.py`.
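For orientation, the shared imports this implies, sketched from the names listed above (the exact import block in `model.py` may differ):

```python
from examples.models.gemma4.text_decoder import (
    Gemma4KVCache,
    Gemma4MLP,
    RMSNorm,
    RMSNormNoWeight,
    apply_rotary_emb,
)
```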