diff --git a/examples/models/gemma4/e2e_runner.cpp b/examples/models/gemma4/e2e_runner.cpp
index dc9b32d6cc2..7ca5cf7c767 100644
--- a/examples/models/gemma4/e2e_runner.cpp
+++ b/examples/models/gemma4/e2e_runner.cpp
@@ -43,6 +43,11 @@ DEFINE_int32(max_new_tokens, 100, "Maximum tokens to generate.");
 DEFINE_int32(max_vision_tokens, 140, "Maximum soft tokens for vision encoder.");
 DEFINE_double(temperature, 0.0, "Sampling temperature (0.0 = greedy).");
 DEFINE_int32(cpu_threads, -1, "Number of CPU threads. -1 = auto-detect.");
+DEFINE_bool(
+    enable_workspace_sharing,
+    true,
+    "Enable XNNPACK PerModel workspace sharing + weight cache. "
+    "Pass --noenable_workspace_sharing to disable for debugging.");
 
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -62,7 +67,10 @@ int32_t main(int32_t argc, char** argv) {
   Gemma4Stats stats;
   stats.on_load_begin();
 
-  Gemma4Runner runner(FLAGS_model_path, FLAGS_tokenizer_path);
+  Gemma4Runner runner(
+      FLAGS_model_path,
+      FLAGS_tokenizer_path,
+      FLAGS_enable_workspace_sharing);
   auto err = runner.load();
   ET_CHECK_MSG(err == executorch::runtime::Error::Ok, "Failed to load model");
 
diff --git a/examples/models/gemma4/export_gemma4.py b/examples/models/gemma4/export_gemma4.py
index 237a8454dcd..d59d6c82615 100644
--- a/examples/models/gemma4/export_gemma4.py
+++ b/examples/models/gemma4/export_gemma4.py
@@ -451,6 +451,7 @@ def _export_text_decoder(
     tied_embedding: bool = False,
     variant: str = "e2b",
     quantize_kv_cache: bool = False,
+    use_custom_sdpa: bool = True,
 ):
     """Export text decoder. Returns ExportedProgram."""
     from executorch.examples.models.gemma4.quant_utils import (
@@ -467,6 +468,12 @@ def _export_text_decoder(
     config.use_kv_cache = True
     config.max_seq_len = max_seq_len
     config.enable_dynamic_shape = True
+    config.use_custom_sdpa = use_custom_sdpa
+
+    if use_custom_sdpa:
+        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+        logger.info("Custom SDPA enabled (tiled flash attention)")
 
     model_wrapper = Gemma4Model(
         config=config, checkpoint_path=checkpoint_path, dtype=torch.float32
@@ -507,7 +514,9 @@ def _export_text_decoder(
         )
         logger.info("Replacing KV cache with INT8 quantized KV cache...")
-        model = replace_kv_cache_with_quantized_kv_cache(model)
+        model = replace_kv_cache_with_quantized_kv_cache(
+            model, use_custom_sdpa=use_custom_sdpa
+        )
     model.eval()
 
     if linear_quant:
@@ -549,6 +558,7 @@ def _export_components(
     quantize_kv_cache: bool,
     include_audio: bool,
     include_vision: bool,
+    use_custom_sdpa: bool,
 ) -> dict:
     """Export each requested component to an ExportedProgram."""
     components = []
@@ -594,6 +604,7 @@ def _export_components(
             tied_embedding=tied_embedding,
             variant=variant,
             quantize_kv_cache=quantize_kv_cache,
+            use_custom_sdpa=use_custom_sdpa,
         )
 
     return programs
@@ -687,6 +698,7 @@ def export_single_pte(
     quantize_kv_cache: bool = False,
     include_audio: bool = True,
     include_vision: bool = True,
+    use_custom_sdpa: bool = True,
 ) -> Path:
     """Export components into a single PTE.
 
@@ -714,6 +726,7 @@ def export_single_pte(
         quantize_kv_cache=quantize_kv_cache,
         include_audio=include_audio,
         include_vision=include_vision,
+        use_custom_sdpa=use_custom_sdpa,
     )
 
     logger.info("Combining into single PTE...")
@@ -840,6 +853,13 @@ def main():
         default=False,
         help="Exclude vision_encoder method to reduce PTE size.",
     )
+    parser.add_argument(
+        "--use_custom_sdpa",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Route attention through llama::custom_sdpa (tiled flash attention). "
+        "Pass --no-use_custom_sdpa to fall back to matmul attention.",
+    )
 
     args = parser.parse_args()
 
     export_single_pte(
@@ -856,6 +876,7 @@ def main():
         quantize_kv_cache=args.quantize_kv_cache,
         include_audio=not args.no_audio,
         include_vision=not args.no_vision,
+        use_custom_sdpa=args.use_custom_sdpa,
     )
 
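Note on the flag pair above: the Python exporter mirrors the C++ side. gflags derives `--noenable_workspace_sharing` from `DEFINE_bool`, and `argparse.BooleanOptionalAction` (Python 3.9+) likewise derives `--no-use_custom_sdpa` from `--use_custom_sdpa`. A minimal standalone sketch of the flag behavior, using the same flag name as the diff:

```python
import argparse

# Sketch of the BooleanOptionalAction pattern used in export_gemma4.py;
# the negative form "--no-use_custom_sdpa" is generated automatically.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_custom_sdpa",
    action=argparse.BooleanOptionalAction,
    default=True,
)

assert parser.parse_args([]).use_custom_sdpa is True
assert parser.parse_args(["--use_custom_sdpa"]).use_custom_sdpa is True
assert parser.parse_args(["--no-use_custom_sdpa"]).use_custom_sdpa is False
```

Defaulting to `True` keeps the fast path on by default while leaving a one-flag escape hatch for A/B debugging of the attention path.
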
" + "Pass --no-use_custom_sdpa to fall back to matmul attention.", + ) args = parser.parse_args() export_single_pte( @@ -856,6 +876,7 @@ def main(): quantize_kv_cache=args.quantize_kv_cache, include_audio=not args.no_audio, include_vision=not args.no_vision, + use_custom_sdpa=args.use_custom_sdpa, ) diff --git a/examples/models/gemma4/runner/gemma4_runner.cpp b/examples/models/gemma4/runner/gemma4_runner.cpp index 4e90c8ee4dc..c457e51f064 100644 --- a/examples/models/gemma4/runner/gemma4_runner.cpp +++ b/examples/models/gemma4/runner/gemma4_runner.cpp @@ -10,9 +10,12 @@ #include +#include #include #include #include +#include +#include #include #include @@ -28,10 +31,36 @@ using ::executorch::runtime::EValue; Gemma4Runner::Gemma4Runner( const std::string& model_path, - const std::string& tokenizer_path) - : model_path_(model_path), tokenizer_path_(tokenizer_path) {} + const std::string& tokenizer_path, + bool enable_workspace_sharing) + : model_path_(model_path), + tokenizer_path_(tokenizer_path), + enable_workspace_sharing_(enable_workspace_sharing) {} Error Gemma4Runner::load() { + // Set XNNPACK workspace sharing explicitly. The compile-time default + // varies across build configurations, and set_option state is process- + // global, so always set it here to get the intended mode regardless of + // how the binary was built or what other code in the process did first. + { + auto mode = enable_workspace_sharing_ + ? ::executorch::backends::xnnpack::WorkspaceSharingMode::PerModel + : ::executorch::backends::xnnpack::WorkspaceSharingMode::Disabled; + ::executorch::runtime::BackendOptions<2> xnnpack_opts; + xnnpack_opts.set_option( + ::executorch::backends::xnnpack::weight_cache_option_key, + enable_workspace_sharing_); + xnnpack_opts.set_option( + ::executorch::backends::xnnpack::workspace_sharing_mode_option_key, + static_cast(mode)); + auto opts_status = ::executorch::runtime::set_option( + ::executorch::backends::xnnpack::xnnpack_backend_key, + xnnpack_opts.view()); + if (opts_status != Error::Ok) { + ET_LOG(Error, "Failed to set XNNPACK options"); + } + } + ET_LOG(Info, "Loading model: %s", model_path_.c_str()); module_ = std::make_unique(model_path_, Module::LoadMode::Mmap); diff --git a/examples/models/gemma4/runner/gemma4_runner.h b/examples/models/gemma4/runner/gemma4_runner.h index 33bdd72571d..7d858564616 100644 --- a/examples/models/gemma4/runner/gemma4_runner.h +++ b/examples/models/gemma4/runner/gemma4_runner.h @@ -39,7 +39,8 @@ class Gemma4Runner { public: Gemma4Runner( const std::string& model_path, - const std::string& tokenizer_path); + const std::string& tokenizer_path, + bool enable_workspace_sharing = true); Error load(); bool is_loaded() const; @@ -176,6 +177,7 @@ class Gemma4Runner { std::unique_ptr tokenizer_; std::string model_path_; std::string tokenizer_path_; + bool enable_workspace_sharing_; Error load_audio_methods(); Error load_vision_methods(); diff --git a/examples/models/gemma4/runner/gemma4_stats.h b/examples/models/gemma4/runner/gemma4_stats.h index 0dbad47b855..8194a791a2a 100644 --- a/examples/models/gemma4/runner/gemma4_stats.h +++ b/examples/models/gemma4/runner/gemma4_stats.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -198,7 +199,7 @@ struct Gemma4Stats { "\"prefill_ms\":%.1f,\"generation_ms\":%.1f," "\"num_prompt_tokens\":%d,\"num_generated_tokens\":%d," "\"prefill_tok_per_s\":%.1f,\"gen_tok_per_s\":%.1f," - "\"ttft_ms\":%.1f,\"total_ms\":%.1f," + "\"ttft_ms\":%.1f,\"total_ms\":%.1f,\"rtf\":%.2f," 
"\"rss_after_load_mb\":%.0f,\"rss_peak_gen_mb\":%.0f}", load_ms, speech_transform_ms, @@ -212,6 +213,7 @@ struct Gemma4Stats { tokens_per_second(), time_to_first_token_ms(), total_inference_ms(), + rtf(), rss_after_load_kb / 1024.0, rss_peak_gen_kb / 1024.0); return std::string(buf); @@ -226,7 +228,7 @@ struct Gemma4Stats { char line[256]; int64_t rss_kb = 0; while (fgets(line, sizeof(line), f)) { - if (sscanf(line, "VmRSS: %ld kB", &rss_kb) == 1) { + if (sscanf(line, "VmRSS: %" SCNd64 " kB", &rss_kb) == 1) { break; } } diff --git a/examples/models/gemma4/targets.bzl b/examples/models/gemma4/targets.bzl index e0af6d9c8ac..fd8179980a7 100644 --- a/examples/models/gemma4/targets.bzl +++ b/examples/models/gemma4/targets.bzl @@ -43,10 +43,9 @@ def define_common_targets(): "runner/gemma4_runner.cpp", ], visibility = ["PUBLIC"], - preprocessor_flags = [ - "-DENABLE_XNNPACK_SHARED_WORKSPACE", - ], deps = _KERNEL_BACKEND_DEPS + [ + "//executorch/backends/xnnpack:xnnpack_interface", + "//executorch/runtime/backend:interface", "//executorch/extension/llm/sampler:sampler", "//executorch/extension/module:module", "//executorch/extension/tensor:tensor", @@ -72,5 +71,5 @@ def define_common_targets(): ], visibility = ["PUBLIC"], compiler_flags = ["-Wno-global-constructors"], - preprocessor_flags = ["-DET_USE_THREADPOOL", "-DENABLE_XNNPACK_SHARED_WORKSPACE"], + preprocessor_flags = ["-DET_USE_THREADPOOL"], ) diff --git a/examples/models/gemma4/text_decoder/gemma4_attention.py b/examples/models/gemma4/text_decoder/gemma4_attention.py index 176f8442e18..77bd843d71c 100644 --- a/examples/models/gemma4/text_decoder/gemma4_attention.py +++ b/examples/models/gemma4/text_decoder/gemma4_attention.py @@ -28,48 +28,6 @@ from .gemma4_norm import RMSNorm, RMSNormNoWeight -def precompute_freqs_cis( - dim: int, - end: int, - theta: float = 10000.0, - freq_base_dim: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Precompute rotary position embeddings (RoPE). - - Uses HuggingFace-compatible format where cos/sin have full dim. - - Args: - dim: Dimension of the embeddings (rotary_dim, may be < head_dim for partial RoPE) - end: Maximum sequence length - theta: Base frequency for RoPE - freq_base_dim: Denominator for frequency computation. Defaults to dim. - For partial RoPE (Gemma 4 full attention), this should be the full - head_dim, not the rotary_dim, matching HF's proportional RoPE. 
- - Returns: - Tuple of (cos, sin) tensors of shape [end, dim] - """ - if freq_base_dim is None: - freq_base_dim = dim - rope_angles = dim // 2 - inv_freq_rotated = 1.0 / ( - theta ** (torch.arange(0, dim, 2).float() / freq_base_dim) - ) - # For partial RoPE: pad with zeros for non-rotated dims - nope_angles = freq_base_dim // 2 - rope_angles - if nope_angles > 0: - inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)]) - dim = freq_base_dim # Use full head_dim for cos/sin shape - else: - inv_freq = inv_freq_rotated - t = torch.arange(end, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - freqs_cos = torch.cos(emb) - freqs_sin = torch.sin(emb) - return freqs_cos, freqs_sin - - def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotates half the hidden dims of the input (HuggingFace style).""" x1 = x[..., : x.shape[-1] // 2] @@ -242,22 +200,24 @@ def __init__( else: self.rotary_dim = self.head_dim - # Precompute RoPE frequencies - # For partial RoPE, pass freq_base_dim=head_dim so zero-padded - # inv_freq produces full head_dim cos/sin matching HF's dimension pairing - freqs_cos, freqs_sin = precompute_freqs_cis( - self.rotary_dim, - config.max_seq_len, - theta=self.rope_theta, - freq_base_dim=self.head_dim, + # RoPE: store only inv_freq; cos/sin computed on the fly per forward. + # Partial RoPE pads with zeros for non-rotated dims so rotate_half pairs correctly. + rope_angles = self.rotary_dim // 2 + inv_freq_rotated = 1.0 / ( + self.rope_theta + ** (torch.arange(0, self.rotary_dim, 2).float() / self.head_dim) ) - self.register_buffer("freqs_cos", freqs_cos, persistent=False) - self.register_buffer("freqs_sin", freqs_sin, persistent=False) + nope_angles = self.head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)]) + else: + inv_freq = inv_freq_rotated + self.register_buffer("inv_freq", inv_freq, persistent=False) - # KV cache - self.use_index_copy = getattr(config, "use_index_copy_for_kv_cache", False) + # KV cache — skip allocation for shared layers (they use donor's KV) + self.use_index_copy = config.use_index_copy_for_kv_cache self.kv_cache: Optional[Gemma4KVCache] = None - if config.use_kv_cache: + if config.use_kv_cache and not self.is_kv_shared_layer: self.kv_cache = Gemma4KVCache( max_batch_size=config.max_batch_size, max_seq_len=config.max_seq_len, @@ -266,6 +226,10 @@ def __init__( use_index_copy=self.use_index_copy, ) + self.use_custom_sdpa = config.use_custom_sdpa + if self.use_custom_sdpa: + from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + # Mask buffers — shared across layers by Gemma4TextModel._share_masks() # to avoid duplicating identical [max_seq_len x max_seq_len] tensors. # Initialized here as fallback for standalone usage. 
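Note on the RoPE change above: registering only `inv_freq` trades a pair of persistent `[max_seq_len, head_dim]` cos/sin tables per layer for a single `[head_dim/2]` vector, and cos/sin are rebuilt per forward for just the requested positions (see `_get_rope_freqs` in the next hunk). A minimal sketch of the scheme with illustrative dimensions (not the model's real config), showing why the zero padding leaves the non-rotated dims untouched:

```python
import torch

# Illustrative sizes only; partial RoPE rotates rotary_dim of head_dim.
head_dim, rotary_dim, theta = 256, 128, 1_000_000.0

inv_freq_rotated = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2).float() / head_dim))
nope_angles = head_dim // 2 - rotary_dim // 2
inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)])

# Decode step: cos/sin are built only for the current positions, instead
# of indexing a persistent [max_seq_len, head_dim] table per layer.
input_pos = torch.tensor([41])
freqs = torch.outer(input_pos.float(), inv_freq)  # [1, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)           # [1, head_dim]
cos, sin = emb.cos(), emb.sin()

# Zero frequency gives cos=1, sin=0, so x*cos + rotate_half(x)*sin leaves
# the non-rotated ("NoPE") dimensions unchanged; that is what the padding is for.
assert torch.all(cos[:, rotary_dim // 2 : head_dim // 2] == 1.0)
assert torch.all(sin[:, rotary_dim // 2 : head_dim // 2] == 0.0)
```
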
@@ -300,6 +264,59 @@ def _apply_rope_single( """Apply RoPE to Q only (for shared KV layers).""" return apply_rotary_emb_single(q, freqs_cos, freqs_sin) + def _get_rope_freqs( + self, + input_pos: Optional[torch.Tensor], + seq_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute RoPE cos/sin from inv_freq for the current positions.""" + pos = input_pos if input_pos is not None else torch.arange(seq_len) + freqs = torch.outer(pos.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return torch.cos(emb), torch.sin(emb) + + def _slice_mask( + self, + base_mask: torch.Tensor, + input_pos: torch.Tensor, + seq_len: int, + kv_len: int, + ) -> torch.Tensor: + """Slice a [max_seq_len, max_seq_len] mask to current query positions x cache.""" + if self.use_index_copy: + return torch.index_select(base_mask, 0, input_pos).narrow(1, 0, kv_len) + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + torch._check(start_pos >= 0) + return base_mask.narrow(0, start_pos, seq_len).narrow(1, 0, kv_len) + + def _build_attn_mask( + self, + input_pos: Optional[torch.Tensor], + seq_len: int, + kv_len: int, + ) -> torch.Tensor: + """Combined causal + sliding-window mask for the current step.""" + using_cached_kv = ( + (self.kv_cache is not None or self.is_kv_shared_layer) + and input_pos is not None + and kv_len > seq_len + ) + if using_cached_kv: + mask = self._slice_mask(self.causal_mask, input_pos, seq_len, kv_len) + else: + mask = self.causal_mask[:seq_len, :seq_len] + + if self.sliding_window is not None and self.sliding_window_mask is not None: + if using_cached_kv: + sw_mask = self._slice_mask( + self.sliding_window_mask, input_pos, seq_len, kv_len + ) + else: + sw_mask = self.sliding_window_mask[:seq_len, :seq_len] + mask = mask + sw_mask + return mask + def forward( self, hidden_states: torch.Tensor, @@ -330,14 +347,7 @@ def forward( # For KV shared layers, use shared K/V from donor layer if self.is_kv_shared_layer and shared_kv is not None: k, v = shared_kv - - if input_pos is not None: - freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - else: - freqs_cos = self.freqs_cos[:seq_len] - freqs_sin = self.freqs_sin[:seq_len] - + freqs_cos, freqs_sin = self._get_rope_freqs(input_pos, seq_len) q = self._apply_rope_single(q, freqs_cos, freqs_sin) else: # Compute K, V projections @@ -356,17 +366,11 @@ def forward( v = self.v_norm(v) # Get RoPE frequencies - if input_pos is not None: - freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - else: - freqs_cos = self.freqs_cos[:seq_len] - freqs_sin = self.freqs_sin[:seq_len] + freqs_cos, freqs_sin = self._get_rope_freqs(input_pos, seq_len) # Apply RoPE (partial for full attention, full for sliding) q, k = self._apply_rope(q, k, freqs_cos, freqs_sin) - # Update KV cache if enabled (only for non-shared layers) if ( self.kv_cache is not None and input_pos is not None @@ -374,64 +378,62 @@ def forward( ): k, v = self.kv_cache.update(input_pos, k, v) - # Store K/V for sharing if this is a donor layer + # Lazy dequant for INT8 KV cache: do it once here, before both + # the donor-share path (cross-decoder layers can't see scales) and + # the custom_sdpa branch. Using basic torch ops keeps it inside + # the XNNPACK partition (no quantized_decomposed graph break). 
+ if ( + isinstance(self.kv_cache, Gemma4QuantizedKVCache) + and not self.kv_cache.return_float_values + ): + k = k.to(torch.float32) * self.kv_cache.k_cache_scales + v = v.to(torch.float32) * self.kv_cache.v_cache_scales + kv_to_share: Optional[Tuple[torch.Tensor, torch.Tensor]] = None if self.is_kv_donor_layer: - kv_to_share = (k.clone(), v.clone()) + kv_to_share = (k, v) - # Expand KV for MQA/GQA - k = self._repeat_kv(k) - v = self._repeat_kv(v) + if self.use_custom_sdpa and input_pos is not None: + # Custom SDPA handles GQA/MQA natively (skips 8x KV expansion) + # and tiles attention so the [seq x seq] matrix never materializes. + kv_len = k.size(2) + start_pos = 0 if self.use_index_copy else input_pos[0].item() + attn_mask = self._build_attn_mask(input_pos, seq_len, kv_len) + + # custom_sdpa expects [bs, seq_len, n_heads, head_dim] + q_sdpa = q.transpose(1, 2) + k_sdpa = k.transpose(1, 2) + v_sdpa = v.transpose(1, 2) + + # custom_sdpa positional args: (q, k, v, start_pos, attn_mask, dropout, is_causal, scale). + # The op schema has a typo (`drpout_p`); avoid kwargs. + attn_output = torch.ops.llama.custom_sdpa( + q_sdpa, + k_sdpa, + v_sdpa, + start_pos, + attn_mask, + 0.0, + False, + self.scaling, + ) + attn_output = attn_output.view(batch_size, seq_len, -1) + else: + k = self._repeat_kv(k) + v = self._repeat_kv(v) + attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling - # Compute attention - attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling + if mask is None: + mask = self._build_attn_mask(input_pos, seq_len, k.size(2)) + + attn_weights = attn_weights + mask.unsqueeze(0).unsqueeze(0) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( + q + ) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, -1) - # Apply mask - if mask is None: - kv_len = k.size(2) - using_cached_kv = ( - self.kv_cache is not None or self.is_kv_shared_layer - ) and input_pos is not None - - if using_cached_kv and kv_len > seq_len: - cache_len = kv_len - if self.use_index_copy: - mask = torch.index_select(self.causal_mask, 0, input_pos).narrow( - 1, 0, cache_len - ) - else: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - torch._check(start_pos >= 0) - mask = self.causal_mask.narrow(0, start_pos, seq_len).narrow( - 1, 0, cache_len - ) - else: - mask = self.causal_mask[:seq_len, :seq_len] - - if self.sliding_window is not None and self.sliding_window_mask is not None: - if using_cached_kv and kv_len > seq_len: - if self.use_index_copy: - sw_mask = torch.index_select( - self.sliding_window_mask, 0, input_pos - ).narrow(1, 0, cache_len) - else: - sw_mask = self.sliding_window_mask.narrow( - 0, start_pos, seq_len - ).narrow(1, 0, cache_len) - else: - sw_mask = self.sliding_window_mask[:seq_len, :seq_len] - mask = mask + sw_mask - - attn_weights = attn_weights + mask.unsqueeze(0).unsqueeze(0) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(q) - - # Apply attention to values - attn_output = torch.matmul(attn_weights, v) - - # Reshape and project output - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_len, -1) attn_output = self.o_proj(attn_output) return attn_output, kv_to_share @@ -440,13 +442,14 @@ def forward( class Gemma4QuantizedKVCache(nn.Module): """INT8 Quantized Key-Value cache for Gemma4. - Stores K and V tensors as int8 with per-token scales. 
+ Stores K and V tensors as int8 with symmetric per-token quantization. + Uses simple torch ops (abs/amax, div, mul) that XNNPACK can fuse, + avoiding quantized_decomposed ops that break graph partitioning. - Args: - max_batch_size: Maximum batch size - max_seq_len: Maximum sequence length - num_kv_heads: Number of key-value heads - head_dim: Dimension per head + When return_float_values=False, returns raw INT8 K/V + exposes scales + as attributes. The attention module does lazy inline dequant with + simple ops right before custom_sdpa, keeping everything in one + XNNPACK partition. """ def __init__( @@ -457,12 +460,14 @@ def __init__( head_dim: int, dtype: torch.dtype = torch.float32, use_index_copy: bool = False, + return_float_values: bool = True, ): super().__init__() self.max_seq_len = max_seq_len self.head_dim = head_dim self.dtype = dtype self.use_index_copy = use_index_copy + self.return_float_values = return_float_values cache_shape = (max_batch_size, num_kv_heads, max_seq_len, head_dim) scale_shape = (max_batch_size, num_kv_heads, max_seq_len, 1) @@ -476,26 +481,13 @@ def __init__( "v_cache_scales", torch.ones(scale_shape, dtype=torch.float32) ) - self.k_cache.requires_grad_(False) - self.v_cache.requires_grad_(False) - self.k_cache_scales.requires_grad_(False) - self.v_cache_scales.requires_grad_(False) - def _quantize(self, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize tensor to int8 with symmetric per-token quantization.""" + """Symmetric per-token quantization using basic torch ops.""" amax = value.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) scales = amax / 127.0 quantized = (value / scales).round().clamp(-128, 127).to(torch.int8) return quantized, scales - def _dequantize( - self, - quantized: torch.Tensor, - scales: torch.Tensor, - ) -> torch.Tensor: - """Dequantize int8 tensor back to float.""" - return (quantized.to(torch.float32) * scales).to(self.dtype) - def update( self, input_pos: torch.Tensor, @@ -505,7 +497,9 @@ def update( """Update cache with new K, V values. Returns: - Tuple of (full_k, full_v) as dequantized float tensors + If return_float_values=True: (full_k, full_v) as dequantized float. + If return_float_values=False: (full_k_int8, full_v_int8) — caller + dequantizes lazily using k_cache_scales/v_cache_scales attributes. 
""" quantized_k, k_scales = self._quantize(k_val) quantized_v, v_scales = self._quantize(v_val) @@ -525,8 +519,12 @@ def update( self.k_cache_scales.narrow(2, start_pos, seq_len).copy_(k_scales) self.v_cache_scales.narrow(2, start_pos, seq_len).copy_(v_scales) - k_out = self._dequantize(self.k_cache, self.k_cache_scales) - v_out = self._dequantize(self.v_cache, self.v_cache_scales) + if not self.return_float_values: + return self.k_cache, self.v_cache + + # Legacy path: full dequant + overwrite current pos with float original + k_out = (self.k_cache.to(torch.float32) * self.k_cache_scales).to(self.dtype) + v_out = (self.v_cache.to(torch.float32) * self.v_cache_scales).to(self.dtype) if self.use_index_copy: k_out.index_copy_(2, input_pos, k_val) @@ -538,7 +536,9 @@ def update( return k_out, v_out @classmethod - def from_float(cls, kv_cache: Gemma4KVCache) -> "Gemma4QuantizedKVCache": + def from_float( + cls, kv_cache: Gemma4KVCache, return_float_values: bool = True + ) -> "Gemma4QuantizedKVCache": """Create quantized KV cache from float KV cache.""" max_batch_size, num_kv_heads, max_seq_len, head_dim = kv_cache.k_cache.shape dtype = kv_cache.k_cache.dtype @@ -549,23 +549,37 @@ def from_float(cls, kv_cache: Gemma4KVCache) -> "Gemma4QuantizedKVCache": head_dim, dtype, use_index_copy=kv_cache.use_index_copy, + return_float_values=return_float_values, ) -def replace_kv_cache_with_quantized_kv_cache(model: nn.Module) -> nn.Module: - """Replace Gemma4KVCache with Gemma4QuantizedKVCache in the model.""" - return _replace_kv_cache_with_quantized_kv_cache(model) +def replace_kv_cache_with_quantized_kv_cache( + model: nn.Module, + use_custom_sdpa: bool = False, +) -> nn.Module: + """Replace Gemma4KVCache with Gemma4QuantizedKVCache in the model. + + When use_custom_sdpa=True, the quantized cache returns raw INT8 tensors + for use with custom_quantized_sdpa (avoids full-cache dequant overhead). + """ + return_float = not use_custom_sdpa + return _replace_kv_cache_with_quantized_kv_cache(model, return_float) -def _replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module: +def _replace_kv_cache_with_quantized_kv_cache( + module: nn.Module, + return_float_values: bool = True, +) -> nn.Module: """Recursively replace Gemma4KVCache with Gemma4QuantizedKVCache.""" for name, child in module.named_children(): if isinstance(child, Gemma4KVCache): setattr( module, name, - Gemma4QuantizedKVCache.from_float(child), + Gemma4QuantizedKVCache.from_float( + child, return_float_values=return_float_values + ), ) else: - _replace_kv_cache_with_quantized_kv_cache(child) + _replace_kv_cache_with_quantized_kv_cache(child, return_float_values) return module diff --git a/examples/models/gemma4/text_decoder/gemma4_config.py b/examples/models/gemma4/text_decoder/gemma4_config.py index b85dd0e441a..50ef78734f7 100644 --- a/examples/models/gemma4/text_decoder/gemma4_config.py +++ b/examples/models/gemma4/text_decoder/gemma4_config.py @@ -118,6 +118,9 @@ class Gemma4Config: enable_dynamic_shape: bool = False use_index_copy_for_kv_cache: bool = False + # Optimization flags + use_custom_sdpa: bool = True + @classmethod def from_json(cls, json_path: str) -> "Gemma4Config": """Load config from JSON file."""