From 6743b36d6b1fffc6832dda2a0afac2728ebd2cc5 Mon Sep 17 00:00:00 2001 From: YzXiao101 Date: Sun, 19 Apr 2026 09:24:15 -0400 Subject: [PATCH 1/5] feat: init gemma3 --- python/minisgl/attention/base.py | 30 +++++- python/minisgl/attention/fa.py | 13 ++- python/minisgl/attention/fi.py | 13 ++- python/minisgl/attention/trtllm.py | 13 ++- python/minisgl/layers/__init__.py | 6 +- python/minisgl/layers/activation.py | 8 +- python/minisgl/layers/attention.py | 17 ++- python/minisgl/layers/linear.py | 8 +- python/minisgl/layers/norm.py | 19 ++++ python/minisgl/layers/rotary.py | 14 ++- python/minisgl/models/config.py | 68 ++++++++++-- python/minisgl/models/gemma3.py | 161 ++++++++++++++++++++++++++++ python/minisgl/models/register.py | 3 +- python/minisgl/models/utils.py | 7 +- python/minisgl/models/weight.py | 9 +- 15 files changed, 362 insertions(+), 27 deletions(-) create mode 100644 python/minisgl/models/gemma3.py diff --git a/python/minisgl/attention/base.py b/python/minisgl/attention/base.py index 53f7d5dd..dbac1339 100644 --- a/python/minisgl/attention/base.py +++ b/python/minisgl/attention/base.py @@ -18,7 +18,15 @@ def get_last_indices(self, bs: int) -> torch.Tensor: ... class BaseAttnBackend(ABC): @abstractmethod def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + batch: Batch, + *, + window_size: tuple[int, int] = (-1, -1), + softmax_scale: float | None = None, ) -> torch.Tensor: ... @abstractmethod @@ -44,10 +52,26 @@ def __init__( self.decode_backend = decode_backend def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + batch: Batch, + *, + window_size: tuple[int, int] = (-1, -1), + softmax_scale: float | None = None, ) -> torch.Tensor: backend = self.prefill_backend if batch.is_prefill else self.decode_backend - return backend.forward(q, k, v, layer_id, batch) + return backend.forward( + q, + k, + v, + layer_id, + batch, + window_size=window_size, + softmax_scale=softmax_scale, + ) def prepare_metadata(self, batch: Batch) -> None: backend = self.prefill_backend if batch.is_prefill else self.decode_backend diff --git a/python/minisgl/attention/fa.py b/python/minisgl/attention/fa.py index ec17f37a..b6be5b8c 100644 --- a/python/minisgl/attention/fa.py +++ b/python/minisgl/attention/fa.py @@ -46,7 +46,15 @@ def __init__(self, config: ModelConfig): self.version = 4 if is_sm100_supported() else 3 def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + batch: Batch, + *, + window_size: tuple[int, int] = (-1, -1), + softmax_scale: float | None = None, ) -> torch.Tensor: metadata = batch.attn_metadata assert isinstance(metadata, FAMetadata) @@ -60,8 +68,9 @@ def forward( cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k=metadata.cu_seqlens_k, max_seqlen_q=metadata.max_seqlen_q, - softmax_scale=self.scale, + softmax_scale=self.scale if softmax_scale is None else softmax_scale, version=self.version, + window_size=window_size, ) def prepare_metadata(self, batch: Batch) -> None: diff --git a/python/minisgl/attention/fi.py b/python/minisgl/attention/fi.py index f390137a..0b846095 100644 --- a/python/minisgl/attention/fi.py +++ b/python/minisgl/attention/fi.py @@ -174,11 +174,22 @@ def _get_ones_cpu(self, bs: int) -> torch.Tensor: return self.cached_ones_cpu[:bs] def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + batch: Batch, + *, + window_size: tuple[int, int] = (-1, -1), + softmax_scale: float | None = None, ) -> torch.Tensor: def _flatten_cache(cache: torch.Tensor) -> torch.Tensor: # treat page = 1 return cache.view(-1, 1, cache.shape[2], cache.shape[3]) + if window_size != (-1, -1) or softmax_scale is not None: + raise NotImplementedError + metadata = batch.attn_metadata assert isinstance(metadata, FIMetadata) self._initialize_metadata_once(metadata) diff --git a/python/minisgl/attention/trtllm.py b/python/minisgl/attention/trtllm.py index e780d548..4703d4d6 100644 --- a/python/minisgl/attention/trtllm.py +++ b/python/minisgl/attention/trtllm.py @@ -47,11 +47,22 @@ def __init__(self, config: ModelConfig): ) def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, batch: Batch + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + batch: Batch, + *, + window_size: tuple[int, int] = (-1, -1), + softmax_scale: float | None = None, ) -> torch.Tensor: from flashinfer.decode import trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache + if window_size != (-1, -1) or softmax_scale is not None: + raise NotImplementedError + metadata = batch.attn_metadata assert isinstance(metadata, TRTLLMMetadata) self.kvcache.store_kv(k, v, batch.out_loc, layer_id) diff --git a/python/minisgl/layers/__init__.py b/python/minisgl/layers/__init__.py index 2415af73..dce5700f 100644 --- a/python/minisgl/layers/__init__.py +++ b/python/minisgl/layers/__init__.py @@ -1,4 +1,4 @@ -from .activation import gelu_and_mul, silu_and_mul +from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from .attention import AttentionLayer from .base import BaseOP, OPList, StateLessOP from .embedding import ParallelLMHead, VocabParallelEmbedding @@ -10,12 +10,13 @@ LinearRowParallel, ) from .moe import MoELayer -from .norm import RMSNorm, RMSNormFused +from .norm import Gemma3RMSNorm, RMSNorm, RMSNormFused from .rotary import get_rope, set_rope_device __all__ = [ "silu_and_mul", "gelu_and_mul", + "gelu_tanh_and_mul", "AttentionLayer", "BaseOP", "StateLessOP", @@ -26,6 +27,7 @@ "LinearRowParallel", "LinearOProj", "LinearQKVMerged", + "Gemma3RMSNorm", "RMSNorm", "RMSNormFused", "get_rope", diff --git a/python/minisgl/layers/activation.py b/python/minisgl/layers/activation.py index 3e44b10c..f5dcf056 100644 --- a/python/minisgl/layers/activation.py +++ b/python/minisgl/layers/activation.py @@ -18,4 +18,10 @@ def gelu_and_mul(x: torch.Tensor, out: torch.Tensor | None = None): return gelu_and_mul(x, out=out) -__all__ = ["silu_and_mul", "gelu_and_mul"] +def gelu_tanh_and_mul(x: torch.Tensor, out: torch.Tensor | None = None): + from flashinfer import gelu_tanh_and_mul + + return gelu_tanh_and_mul(x, out=out) + + +__all__ = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"] diff --git a/python/minisgl/layers/attention.py b/python/minisgl/layers/attention.py index 7cd3ff61..104b0243 100644 --- a/python/minisgl/layers/attention.py +++ b/python/minisgl/layers/attention.py @@ -25,6 +25,8 @@ def __init__( rotary_config: RotaryConfig, q_norm: RMSNorm | None = None, k_norm: RMSNorm | None = None, + sliding_window_size: int | None = None, + softmax_scale: float | None = None, ): assert num_qo_heads % num_kv_heads == 0 self.layer_id = layer_id @@ -43,6 +45,11 @@ def __init__( ) self.q_norm = q_norm self.k_norm = k_norm + # sliding_window_size: HF-convention (inclusive). Converted to FA (left, right) here. + self._window_size = ( + (sliding_window_size - 1, 0) if sliding_window_size is not None else (-1, -1) + ) + self._softmax_scale = softmax_scale def forward(self, qkv: torch.Tensor) -> torch.Tensor: ctx = get_global_ctx() @@ -53,5 +60,13 @@ def forward(self, qkv: torch.Tensor) -> torch.Tensor: self.k_norm.forward_inplace(k.view(-1, self.num_kv_heads, self.head_dim)) q, k = self.rotary.forward(ctx.batch.positions, q, k) q = q.view(-1, self.num_qo_heads, self.head_dim) - o = ctx.attn_backend.forward(q, k, v, self.layer_id, ctx.batch) + o = ctx.attn_backend.forward( + q, + k, + v, + self.layer_id, + ctx.batch, + window_size=self._window_size, + softmax_scale=self._softmax_scale, + ) return o.view(-1, self.qo_attn_dim) diff --git a/python/minisgl/layers/linear.py b/python/minisgl/layers/linear.py index 7642f648..8ef595f1 100644 --- a/python/minisgl/layers/linear.py +++ b/python/minisgl/layers/linear.py @@ -100,9 +100,11 @@ def __init__(self, input_size: int, output_size: int, has_bias: bool): super().__init__(full_isize, full_osize, local_isize, local_osize, has_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - y = F.linear(x, self.weight, self.bias) + y = F.linear(x, self.weight, None) if self._tp_size > 1: y = self._comm.all_reduce(y) + if self.bias is not None: + y = y + self.bias return y @@ -121,7 +123,9 @@ def __init__( super().__init__(input_size, output_size, local_input_size, local_output_size, has_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - y = F.linear(x, self.weight, self.bias) + y = F.linear(x, self.weight, None) if self._tp_size > 1: y = self._comm.all_reduce(y) + if self.bias is not None: + y = y + self.bias return y diff --git a/python/minisgl/layers/norm.py b/python/minisgl/layers/norm.py index 53e9a0dd..a8a1bc7c 100644 --- a/python/minisgl/layers/norm.py +++ b/python/minisgl/layers/norm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Tuple import torch @@ -36,3 +38,20 @@ def forward( return self.rmsnorm(x, self.weight, self.eps), x self.fused_add_rmsnorm(x, residual, self.weight, self.eps) return x, residual + + +class Gemma3RMSNorm(BaseOP): + + def __init__(self, size: int, eps: float) -> None: + from flashinfer import gemma_rmsnorm + + self.eps = eps + self.weight = torch.zeros(size) + self.gemma_rmsnorm = gemma_rmsnorm + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.gemma_rmsnorm(x, self.weight, self.eps) + + def forward_inplace(self, x: torch.Tensor) -> None: + shape = x.shape # [t, h, d] + x.copy_(self.gemma_rmsnorm(x.view(-1, shape[-1]), self.weight, self.eps).view(shape)) diff --git a/python/minisgl/layers/rotary.py b/python/minisgl/layers/rotary.py index 3d29dd97..dda2d24f 100644 --- a/python/minisgl/layers/rotary.py +++ b/python/minisgl/layers/rotary.py @@ -20,7 +20,6 @@ def __init__( ) -> None: super().__init__() self.head_size = head_size - assert rotary_dim == head_size inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) if post_process is not None: inv_freq = post_process(inv_freq) @@ -30,8 +29,8 @@ def __init__( sin = freqs.sin() # buffer, so don't load/save self._cos_sin_cache = torch.cat((cos, sin), dim=-1) - assert self.head_size in [64, 128, 256, 512] + assert self.head_size in [64, 128, 256, 512] from flashinfer import apply_rope_with_cos_sin_cache_inplace self.apply_rope_with_cos_sin_cache_inplace = apply_rope_with_cos_sin_cache_inplace @@ -97,7 +96,11 @@ def post_process(inv_freq: torch.Tensor) -> torch.Tensor: orig_max_pos: int = rope_scaling["original_max_position_embeddings"] def _find_correction_dim(num_rotations: float) -> float: - return rotary_dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) + return ( + rotary_dim + * math.log(orig_max_pos / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) low = max(math.floor(_find_correction_dim(beta_fast)), 0) high = min(math.ceil(_find_correction_dim(beta_slow)), rotary_dim // 2 - 1) @@ -105,7 +108,8 @@ def _find_correction_dim(num_rotations: float) -> float: def post_process(inv_freq: torch.Tensor) -> torch.Tensor: ramp = torch.clamp( (torch.arange(rotary_dim // 2, dtype=torch.float32) - low) / max(high - low, 1), - 0, 1, + 0, + 1, ) return (inv_freq / factor) * ramp + inv_freq * (1 - ramp) @@ -143,4 +147,4 @@ def get_rope( return _get_rope(head_dim, rotary_dim, max_position, base, rope_map) -__all__ = ["get_rope", "RotaryEmbedding", "set_rope_device"] \ No newline at end of file +__all__ = ["get_rope", "RotaryEmbedding", "set_rope_device"] diff --git a/python/minisgl/models/config.py b/python/minisgl/models/config.py index 3a9caf9b..4a5aec56 100644 --- a/python/minisgl/models/config.py +++ b/python/minisgl/models/config.py @@ -1,6 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass + +from dataclasses import dataclass, field from typing import Any, Dict + from transformers import PretrainedConfig @@ -33,6 +35,16 @@ class ModelConfig: model_type: str architectures: list[str] + # ============================== Gemma3 ============================== + layer_types: list[str] = field(default_factory=list) + partial_rotary_factor: float = 1.0 + global_rope_theta: float | None = None + local_rope_theta: float | None = None + query_pre_attn_scalar: float | None = None + attention_bias: bool = False + sliding_window: int | None = None # raw HF value (inclusive) + # ============================== Gemma3 ============================== + @property def is_moe(self) -> bool: return "moe" in self.model_type @@ -47,7 +59,9 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: setattr(config, attr, getattr(top, attr)) num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + head_dim = ( + getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + ) tie_word_embeddings = getattr(config, "tie_word_embeddings", False) model_type = getattr(config, "model_type", "llama") num_experts = getattr(config, "num_local_experts", getattr(config, "num_experts", 0)) @@ -56,9 +70,44 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: norm_topk_prob = getattr(config, "norm_topk_prob", False) architectures = getattr(config, "architectures", ["LlamaForCausalLM"]) - # Llama/Qwen: rope_theta is a direct attr; Mistral: it's inside rope_scaling dict + # Llama/Qwen: rope_theta is a direct attr; + # Mistral: inside rope_scaling dict; + # Gemma3: neither — falls back to 10000.0 (Gemma3 model uses global/local_rope_theta instead) rope_scaling = getattr(config, "rope_scaling", None) - rope_theta = getattr(config, "rope_theta", None) or rope_scaling["rope_theta"] + rope_theta = getattr(config, "rope_theta", None) + if rope_theta is None: + rope_theta = (rope_scaling or {}).get("rope_theta", 10000.0) + + # Gemma3 uses hidden_activation; + # All other models use hidden_act + hidden_act = getattr(config, "hidden_act", None) or getattr( + config, "hidden_activation", "silu" + ) + + # ============================== Gemma3 ============================== + layer_types = list(getattr(config, "layer_types", None) or []) + partial_rotary_factor = float(getattr(config, "partial_rotary_factor", 1.0)) + qpas = getattr(config, "query_pre_attn_scalar", None) + query_pre_attn_scalar = float(qpas) if qpas is not None else None + attention_bias = bool(getattr(config, "attention_bias", False)) + # HF uses either sliding_window_size or sliding_window; + sliding_window = getattr(config, "sliding_window_size", None) or getattr( + config, "sliding_window", None + ) + # Dual RoPE for Gemma3: nested v5 or flat v4 format + rope_params = getattr(config, "rope_parameters", None) or {} + if isinstance(rope_params, dict) and "full_attention" in rope_params: + global_rope_theta = float(rope_params["full_attention"].get("rope_theta", 1000000.0)) + local_rope_theta = float(rope_params["sliding_attention"].get("rope_theta", 10000.0)) + elif rope_params or getattr(config, "rope_local_base_freq", None): + global_rope_theta = float( + rope_params.get("rope_theta", rope_theta) if rope_params else rope_theta + ) + local_rope_theta = float(getattr(config, "rope_local_base_freq", 10000.0)) + else: + global_rope_theta = None + local_rope_theta = None + # ============================== Gemma3 ============================== return cls( num_layers=config.num_hidden_layers, @@ -68,7 +117,7 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: hidden_size=config.hidden_size, vocab_size=config.vocab_size, intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, + hidden_act=hidden_act, rms_norm_eps=config.rms_norm_eps, tie_word_embeddings=tie_word_embeddings, rotary_config=RotaryConfig( @@ -84,4 +133,11 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: norm_topk_prob=norm_topk_prob, model_type=model_type, architectures=architectures, - ) \ No newline at end of file + layer_types=layer_types, + partial_rotary_factor=partial_rotary_factor, + global_rope_theta=global_rope_theta, + local_rope_theta=local_rope_theta, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + sliding_window=sliding_window, + ) diff --git a/python/minisgl/models/gemma3.py b/python/minisgl/models/gemma3.py new file mode 100644 index 00000000..a5ae6d7d --- /dev/null +++ b/python/minisgl/models/gemma3.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import torch +from minisgl.core import get_global_ctx +from minisgl.layers import ( + AttentionLayer, + BaseOP, + Gemma3RMSNorm, + LinearOProj, + LinearQKVMerged, + OPList, + ParallelLMHead, + VocabParallelEmbedding, +) +from minisgl.utils import nvtx_annotate + +from .base import BaseLLMModel +from .config import RotaryConfig +from .utils import GatedMLP as Gemma3MLP + +if TYPE_CHECKING: + from .config import ModelConfig + + +class Gemma3Attn(BaseOP): + def __init__(self, config: ModelConfig, layer_id: int): + head_dim = config.head_dim + is_sliding = ( + config.layer_types[layer_id] == "sliding_attention" if config.layer_types else False + ) + rotary_dim = int(config.partial_rotary_factor * head_dim) + rope_theta = ( + config.local_rope_theta if is_sliding else config.global_rope_theta + ) or config.rotary_config.base + softmax_scale = ( + config.query_pre_attn_scalar**-0.5 if config.query_pre_attn_scalar is not None else None + ) + sliding_window_size = config.sliding_window if is_sliding else None + assert not is_sliding or sliding_window_size is not None + + self.qkv_proj = LinearQKVMerged( + hidden_size=config.hidden_size, + head_dim=head_dim, + num_qo_heads=config.num_qo_heads, + num_kv_heads=config.num_kv_heads, + has_bias=config.attention_bias, + ) + self.q_norm = Gemma3RMSNorm(head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(head_dim, eps=config.rms_norm_eps) + self.attn = AttentionLayer( + layer_id=layer_id, + head_dim=head_dim, + num_qo_heads=config.num_qo_heads, + num_kv_heads=config.num_kv_heads, + rotary_config=RotaryConfig( + head_dim=head_dim, + rotary_dim=rotary_dim, + max_position=config.rotary_config.max_position, + base=rope_theta, + scaling=config.rotary_config.scaling, + ), + q_norm=self.q_norm, + k_norm=self.k_norm, + sliding_window_size=sliding_window_size, + softmax_scale=softmax_scale, + ) + self.o_proj = LinearOProj( + head_dim * config.num_qo_heads, + config.hidden_size, + has_bias=config.attention_bias, + ) + + @nvtx_annotate("MHA") + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv = self.qkv_proj.forward(x) + del x + o = self.attn.forward(qkv) + return self.o_proj.forward(o) + + +class Gemma3DecoderLayer(BaseOP): + def __init__(self, config: ModelConfig, layer_id: int): + self.self_attn = Gemma3Attn(config, layer_id) + self.mlp = Gemma3MLP(config) + self.input_layernorm = Gemma3RMSNorm( + size=config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = Gemma3RMSNorm( + size=config.hidden_size, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = Gemma3RMSNorm( + size=config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = Gemma3RMSNorm( + size=config.hidden_size, + eps=config.rms_norm_eps, + ) + + self._layer_id = layer_id + + @nvtx_annotate("Layer_{}", layer_id_field="_layer_id") + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.input_layernorm.forward(x) + x = self.self_attn.forward(x) + x = self.post_attention_layernorm.forward(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm.forward(x) + x = self.mlp.forward(x) + x = self.post_feedforward_layernorm.forward(x) + return residual + x + + +class Gemma3Model(BaseOP): + def __init__(self, config: ModelConfig): + self.embed_tokens = VocabParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + ) + self.layers = OPList( + [Gemma3DecoderLayer(config, layer_id) for layer_id in range(config.num_layers)] + ) + self.norm = Gemma3RMSNorm( + size=config.hidden_size, + eps=config.rms_norm_eps, + ) + self.embed_scale = math.sqrt(config.hidden_size) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + x = self.embed_tokens.forward(input_ids) * self.embed_scale + for layer in self.layers.op_list: + x = layer.forward(x) + return self.norm.forward(x) + + +class Gemma3ForCausalLM(BaseLLMModel): + def __init__(self, config: ModelConfig): + self.model = Gemma3Model(config) + self.lm_head = ParallelLMHead( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + tie_word_embeddings=config.tie_word_embeddings, + tied_embedding=self.model.embed_tokens if config.tie_word_embeddings else None, + ) + super().__init__() + + def forward(self) -> torch.Tensor: + output = self.model.forward(get_global_ctx().batch.input_ids) + logits = self.lm_head.forward(output) + return logits + + +__all__ = ["Gemma3ForCausalLM"] diff --git a/python/minisgl/models/register.py b/python/minisgl/models/register.py index 6417a708..59b29fbb 100644 --- a/python/minisgl/models/register.py +++ b/python/minisgl/models/register.py @@ -9,6 +9,7 @@ "Qwen3MoeForCausalLM": (".qwen3_moe", "Qwen3MoeForCausalLM"), "MistralForCausalLM": (".mistral", "MistralForCausalLM"), "Mistral3ForConditionalGeneration": (".mistral", "MistralForCausalLM"), + "Gemma3ForCausalLM": (".gemma3", "Gemma3ForCausalLM"), } @@ -21,4 +22,4 @@ def get_model_class(model_architecture: str, model_config: ModelConfig): return model_cls(model_config) -__all__ = ["get_model_class"] \ No newline at end of file +__all__ = ["get_model_class"] diff --git a/python/minisgl/models/utils.py b/python/minisgl/models/utils.py index 1b4a1d96..8f3474d7 100644 --- a/python/minisgl/models/utils.py +++ b/python/minisgl/models/utils.py @@ -13,6 +13,7 @@ MoELayer, RMSNorm, gelu_and_mul, + gelu_tanh_and_mul, silu_and_mul, ) from minisgl.models import ModelConfig @@ -30,7 +31,11 @@ def __init__(self, config: ModelConfig): has_bias=False, ) - FN_MAP = {"silu": silu_and_mul, "gelu": gelu_and_mul} + FN_MAP = { + "silu": silu_and_mul, + "gelu": gelu_and_mul, + "gelu_pytorch_tanh": gelu_tanh_and_mul, + } act_fn = FN_MAP.get(config.hidden_act, None) if act_fn is None: raise ValueError(f"Unsupported activation function: {config.hidden_act}") diff --git a/python/minisgl/models/weight.py b/python/minisgl/models/weight.py index 825ec045..069c64fc 100644 --- a/python/minisgl/models/weight.py +++ b/python/minisgl/models/weight.py @@ -41,6 +41,8 @@ def _shard_tensor(key: str, value: torch.Tensor, r: int, n: int, num_kv_heads: i return value[head_idx * head_dim : (head_idx + 1) * head_dim].clone() return value.chunk(n, dim=0)[r].clone() elif any(key.count(sub) for sub in _SPLIT_DIM_1): + if value.dim() < 2: # 1D bias: not sharded for row-parallel projections + return value return value.chunk(n, dim=1)[r].clone() elif key.count("lm_head") or key.count("embed_tokens"): num_embeddings = value.shape[0] @@ -88,7 +90,12 @@ def load_weight(model_path: str, device: torch.device) -> Iterator[Tuple[str, to expert_buf: Dict[str, Dict[int, torch.Tensor]] = {} for file in tqdm(files, desc="Loading weights", disable=not tp_info.is_primary()): with safetensors.safe_open(file, framework="pt", device=str(device)) as f: - for name in f.keys(): + # TODO: remove debug print + keys = list(f.keys()) + print( + f"\n[DEBUG] Weight file: {file}\n Keys ({len(keys)}):\n " + "\n ".join(keys) + ) + for name in keys: # Strip multimodal wrapper prefix, skip vision/projector weights if name.startswith(("vision_tower.", "multi_modal_projector.")): continue From 2136a9072977e393e6ed748e7d6eab84df8391ab Mon Sep 17 00:00:00 2001 From: YzXiao101 Date: Wed, 6 May 2026 05:38:51 -0400 Subject: [PATCH 2/5] fix: reshape non-contiguous view --- python/minisgl/layers/norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/minisgl/layers/norm.py b/python/minisgl/layers/norm.py index a8a1bc7c..f704bde8 100644 --- a/python/minisgl/layers/norm.py +++ b/python/minisgl/layers/norm.py @@ -54,4 +54,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def forward_inplace(self, x: torch.Tensor) -> None: shape = x.shape # [t, h, d] - x.copy_(self.gemma_rmsnorm(x.view(-1, shape[-1]), self.weight, self.eps).view(shape)) + x.copy_(self.gemma_rmsnorm(x.reshape(-1, shape[-1]), self.weight, self.eps).reshape(shape)) From e8d8451c61930232506b4e189451071306f285c1 Mon Sep 17 00:00:00 2001 From: YzXiao101 Date: Mon, 25 May 2026 13:18:15 -0400 Subject: [PATCH 3/5] feat: FI + gemma3 --- python/minisgl/attention/fi.py | 113 ++++++++++++++++++++++++++------ python/minisgl/engine/engine.py | 14 ++++ python/minisgl/models/config.py | 112 +++++++++++++++++++------------ python/minisgl/models/weight.py | 2 - 4 files changed, 178 insertions(+), 63 deletions(-) diff --git a/python/minisgl/attention/fi.py b/python/minisgl/attention/fi.py index 0b846095..195f3ea1 100644 --- a/python/minisgl/attention/fi.py +++ b/python/minisgl/attention/fi.py @@ -58,7 +58,17 @@ class FIMetadata(BaseAttnMetadata): pos_encoding_mode: str seq_lens_cpu: torch.Tensor # on cpu dtype: torch.dtype - wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDecodeWithPagedKVCacheWrapper + wrapper: ( + BatchPrefillWithPagedKVCacheWrapper + | BatchDecodeWithPagedKVCacheWrapper + | CUDAGraphBatchDecodeWithPagedKVCacheWrapper + ) + sliding_wrapper: ( + BatchPrefillWithPagedKVCacheWrapper + | BatchDecodeWithPagedKVCacheWrapper + | CUDAGraphBatchDecodeWithPagedKVCacheWrapper + | None + ) = None initialized: bool = False # fmt: on @@ -101,10 +111,28 @@ def __init__(self, config: ModelConfig) -> None: kv_layout="NHD", backend="fa2", # flashinfer fa3 is slow, use fa2 instead ) + self.sliding_prefill_wrapper = None + self.sliding_decode_wrappers = None + if self.config.has_sliding_attention: + self.sliding_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self.float_workspace_buffer, + kv_layout="NHD", + backend="fa2", + ) + self.sliding_decode_wrappers = BatchDecodeWithPagedKVCacheWrapper( + self.float_workspace_buffer, + use_tensor_cores=self.use_tensor_cores, + kv_layout="NHD", + backend="fa2", + ) # NOTE: some hack to reuse the int_workspace_buffer self.int_workspace_buffer = self.prefill_wrapper._int_workspace_buffer self.decode_wrappers._int_workspace_buffer = self.int_workspace_buffer + if self.sliding_prefill_wrapper is not None: + self.sliding_prefill_wrapper._int_workspace_buffer = self.int_workspace_buffer + if self.sliding_decode_wrappers is not None: + self.sliding_decode_wrappers._int_workspace_buffer = self.int_workspace_buffer # initialize some data members tp_size = get_tp_info().size @@ -120,18 +148,27 @@ def __init__(self, config: ModelConfig) -> None: self.last_event = torch.cuda.Event() self.last_event.record() - def _initialize_metadata_once(self, metadata: FIMetadata) -> None: - if metadata.initialized: - return - - from flashinfer import BatchDecodeWithPagedKVCacheWrapper + def _plan_wrapper( + self, + metadata: FIMetadata, + wrapper: ( + BatchPrefillWithPagedKVCacheWrapper + | BatchDecodeWithPagedKVCacheWrapper + | CUDAGraphBatchDecodeWithPagedKVCacheWrapper + ), + softmax_scale: float | None, + window_left: int, + ) -> None: + from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + CUDAGraphBatchDecodeWithPagedKVCacheWrapper, + ) - metadata.initialized = True - # FlashInfer planning reuses a pinned host staging buffer and launches an - # async H2D copy. Wait here before the next plan mutates that host buffer. - self.last_event.synchronize() - if isinstance(metadata.wrapper, BatchDecodeWithPagedKVCacheWrapper): - metadata.wrapper.plan( + if isinstance( + wrapper, + (BatchDecodeWithPagedKVCacheWrapper, CUDAGraphBatchDecodeWithPagedKVCacheWrapper), + ): + wrapper.plan( indptr=metadata.cu_seqlens_k_cpu, indices=metadata.indices, last_page_len=metadata.last_page_len_cpu, @@ -144,10 +181,12 @@ def _initialize_metadata_once(self, metadata: FIMetadata) -> None: data_type=metadata.dtype, q_data_type=metadata.dtype, kv_data_type=metadata.dtype, + sm_scale=softmax_scale, + window_left=window_left, non_blocking=True, ) else: - metadata.wrapper.plan( + wrapper.plan( qo_indptr=metadata.cu_seqlens_q_cpu, paged_kv_indptr=metadata.cu_seqlens_k_cpu, paged_kv_indices=metadata.indices, @@ -160,9 +199,31 @@ def _initialize_metadata_once(self, metadata: FIMetadata) -> None: seq_lens=metadata.seq_lens_cpu, q_data_type=metadata.dtype, kv_data_type=metadata.dtype, + sm_scale=softmax_scale, + window_left=window_left, non_blocking=True, causal=True, ) + + def _initialize_metadata_once( + self, metadata: FIMetadata, softmax_scale: float | None = None + ) -> None: + if metadata.initialized: + return + + # FlashInfer planning launches async H2D copies from host metadata. Wait + # before the next batch can reuse those host-side staging tensors. + self.last_event.synchronize() + self._plan_wrapper(metadata, metadata.wrapper, softmax_scale, window_left=-1) + if metadata.sliding_wrapper is not None: + assert self.config.sliding_window is not None + self._plan_wrapper( + metadata, + metadata.sliding_wrapper, + softmax_scale, + window_left=self.config.sliding_window - 1, + ) + metadata.initialized = True self.last_event.record() def _get_ones_cpu(self, bs: int) -> torch.Tensor: @@ -187,16 +248,21 @@ def forward( def _flatten_cache(cache: torch.Tensor) -> torch.Tensor: # treat page = 1 return cache.view(-1, 1, cache.shape[2], cache.shape[3]) - if window_size != (-1, -1) or softmax_scale is not None: - raise NotImplementedError - metadata = batch.attn_metadata assert isinstance(metadata, FIMetadata) - self._initialize_metadata_once(metadata) + if window_size == (-1, -1): + wrapper = metadata.wrapper + else: + assert self.config.sliding_window is not None + assert window_size == (self.config.sliding_window - 1, 0) + assert metadata.sliding_wrapper is not None + wrapper = metadata.sliding_wrapper + + self._initialize_metadata_once(metadata, softmax_scale) self.kvcache.store_kv(k, v, batch.out_loc, layer_id) kv_cache = (self.kvcache.k_cache(layer_id), self.kvcache.v_cache(layer_id)) kv_cache = (_flatten_cache(kv_cache[0]), _flatten_cache(kv_cache[1])) - return metadata.wrapper.run(q=q, paged_kv_cache=kv_cache) + return wrapper.run(q=q, paged_kv_cache=kv_cache) def prepare_metadata(self, batch: Batch) -> None: reqs = batch.padded_reqs @@ -219,6 +285,14 @@ def prepare_metadata(self, batch: Batch) -> None: cu_seqlens_q_cpu = torch.tensor([0] + seqlens_q, **CPU_KWARGS).cumsum_(dim=0) page_table = get_global_ctx().page_table + wrapper = self.decode_wrappers if batch.is_decode else self.prefill_wrapper + sliding_wrapper = None + if self.config.has_sliding_attention: + sliding_wrapper = ( + self.sliding_decode_wrappers if batch.is_decode else self.sliding_prefill_wrapper + ) + assert sliding_wrapper is not None + batch.attn_metadata = FIMetadata( cu_seqlens_q_cpu=cu_seqlens_q_cpu, cu_seqlens_k_cpu=cu_seqlens_k_cpu, @@ -232,7 +306,8 @@ def prepare_metadata(self, batch: Batch) -> None: pos_encoding_mode="NONE", seq_lens_cpu=seq_len_cpu, dtype=self.kvcache.dtype, - wrapper=self.decode_wrappers if batch.is_decode else self.prefill_wrapper, + wrapper=wrapper, + sliding_wrapper=sliding_wrapper, ) def init_capture_graph(self, max_seq_len: int, bs_list: List[int]) -> None: diff --git a/python/minisgl/engine/engine.py b/python/minisgl/engine/engine.py index ea29a96b..ae43a2fd 100644 --- a/python/minisgl/engine/engine.py +++ b/python/minisgl/engine/engine.py @@ -228,6 +228,20 @@ def override(attr: str, value: Any): # this is dangerous, use with caution override("page_size", 64) logger.warning_rank0("Page size is overridden to 64 for TRTLLM backend") + def _decode_backend_is_fi(attention_backend: str) -> bool: + return attention_backend.split(",", 1)[-1] == "fi" + + # FIXME(yzxiao): unlock FI + cuda graph for gemma3 + if config.model_config.has_sliding_attention and _decode_backend_is_fi( + config.attention_backend + ): + if config.cuda_graph_bs != [] or config.cuda_graph_max_bs != 0: + logger.warning_rank0( + "CUDA graph is disabled for sliding-attention models with FI decode " + ) + override("cuda_graph_bs", []) + override("cuda_graph_max_bs", 0) + if config.model_config.is_moe and config.moe_backend == "auto": override("moe_backend", "fused") logger.info_rank0(f"Auto-selected MoE backend: {config.moe_backend}") diff --git a/python/minisgl/models/config.py b/python/minisgl/models/config.py index 4a5aec56..bbc6469c 100644 --- a/python/minisgl/models/config.py +++ b/python/minisgl/models/config.py @@ -49,6 +49,10 @@ class ModelConfig: def is_moe(self) -> bool: return "moe" in self.model_type + @property + def has_sliding_attention(self) -> bool: + return "sliding_attention" in self.layer_types + @classmethod def from_hf(cls, config: PretrainedConfig) -> ModelConfig: if hasattr(config, "text_config") and config.text_config is not None: @@ -58,6 +62,15 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: if not getattr(config, attr, None) and getattr(top, attr, None): setattr(config, attr, getattr(top, attr)) + architectures = list(getattr(config, "architectures", ["LlamaForCausalLM"])) + model_type = getattr(config, "model_type", "llama") + if model_type == "gemma3_text" or "Gemma3ForCausalLM" in architectures: + return cls._from_gemma3_hf(config, architectures) + + return cls._from_basic_hf(config, architectures) + + @classmethod + def _from_basic_hf(cls, config: PretrainedConfig, architectures: list[str]) -> ModelConfig: num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) head_dim = ( getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads @@ -68,46 +81,10 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: num_experts_per_tok = getattr(config, "num_experts_per_tok", 0) moe_intermediate_size = getattr(config, "moe_intermediate_size", 0) norm_topk_prob = getattr(config, "norm_topk_prob", False) - architectures = getattr(config, "architectures", ["LlamaForCausalLM"]) - # Llama/Qwen: rope_theta is a direct attr; - # Mistral: inside rope_scaling dict; - # Gemma3: neither — falls back to 10000.0 (Gemma3 model uses global/local_rope_theta instead) + # Llama/Qwen: rope_theta is a direct attr; Mistral: it's inside rope_scaling dict rope_scaling = getattr(config, "rope_scaling", None) - rope_theta = getattr(config, "rope_theta", None) - if rope_theta is None: - rope_theta = (rope_scaling or {}).get("rope_theta", 10000.0) - - # Gemma3 uses hidden_activation; - # All other models use hidden_act - hidden_act = getattr(config, "hidden_act", None) or getattr( - config, "hidden_activation", "silu" - ) - - # ============================== Gemma3 ============================== - layer_types = list(getattr(config, "layer_types", None) or []) - partial_rotary_factor = float(getattr(config, "partial_rotary_factor", 1.0)) - qpas = getattr(config, "query_pre_attn_scalar", None) - query_pre_attn_scalar = float(qpas) if qpas is not None else None - attention_bias = bool(getattr(config, "attention_bias", False)) - # HF uses either sliding_window_size or sliding_window; - sliding_window = getattr(config, "sliding_window_size", None) or getattr( - config, "sliding_window", None - ) - # Dual RoPE for Gemma3: nested v5 or flat v4 format - rope_params = getattr(config, "rope_parameters", None) or {} - if isinstance(rope_params, dict) and "full_attention" in rope_params: - global_rope_theta = float(rope_params["full_attention"].get("rope_theta", 1000000.0)) - local_rope_theta = float(rope_params["sliding_attention"].get("rope_theta", 10000.0)) - elif rope_params or getattr(config, "rope_local_base_freq", None): - global_rope_theta = float( - rope_params.get("rope_theta", rope_theta) if rope_params else rope_theta - ) - local_rope_theta = float(getattr(config, "rope_local_base_freq", 10000.0)) - else: - global_rope_theta = None - local_rope_theta = None - # ============================== Gemma3 ============================== + rope_theta = getattr(config, "rope_theta", None) or rope_scaling["rope_theta"] return cls( num_layers=config.num_hidden_layers, @@ -117,7 +94,7 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: hidden_size=config.hidden_size, vocab_size=config.vocab_size, intermediate_size=config.intermediate_size, - hidden_act=hidden_act, + hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, tie_word_embeddings=tie_word_embeddings, rotary_config=RotaryConfig( @@ -133,11 +110,62 @@ def from_hf(cls, config: PretrainedConfig) -> ModelConfig: norm_topk_prob=norm_topk_prob, model_type=model_type, architectures=architectures, + ) + + @classmethod + def _from_gemma3_hf(cls, config: PretrainedConfig, architectures: list[str]) -> ModelConfig: + if "Gemma3ForCausalLM" not in architectures: + raise ValueError("Only Gemma3ForCausalLM text models are supported") + + layer_types = list(getattr(config, "layer_types", None) or []) + if not layer_types: + sliding_window_pattern = getattr(config, "sliding_window_pattern", None) + layer_types = [ + "sliding_attention" if (i + 1) % sliding_window_pattern else "full_attention" + for i in range(config.num_hidden_layers) + ] + + sliding_window = getattr(config, "sliding_window_size", None) or getattr( + config, "sliding_window", None + ) + + num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = ( + getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + ) + global_rope_theta = float(getattr(config, "rope_theta", 1000000.0)) + local_rope_theta = float(getattr(config, "rope_local_base_freq", 10000.0)) + qpas = getattr(config, "query_pre_attn_scalar", None) + + return cls( + num_layers=config.num_hidden_layers, + num_qo_heads=config.num_attention_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + intermediate_size=config.intermediate_size, + hidden_act=getattr(config, "hidden_activation", "gelu_pytorch_tanh"), + rms_norm_eps=config.rms_norm_eps, + tie_word_embeddings=getattr(config, "tie_word_embeddings", False), + rotary_config=RotaryConfig( + head_dim=head_dim, + rotary_dim=head_dim, + max_position=config.max_position_embeddings, + base=global_rope_theta, + scaling=None, + ), + num_experts=0, + num_experts_per_tok=0, + moe_intermediate_size=0, + norm_topk_prob=False, + model_type=getattr(config, "model_type", "gemma3_text"), + architectures=architectures, layer_types=layer_types, - partial_rotary_factor=partial_rotary_factor, + partial_rotary_factor=float(getattr(config, "partial_rotary_factor", 1.0)), global_rope_theta=global_rope_theta, local_rope_theta=local_rope_theta, - query_pre_attn_scalar=query_pre_attn_scalar, - attention_bias=attention_bias, + query_pre_attn_scalar=float(qpas) if qpas is not None else None, + attention_bias=bool(getattr(config, "attention_bias", False)), sliding_window=sliding_window, ) diff --git a/python/minisgl/models/weight.py b/python/minisgl/models/weight.py index 069c64fc..bcc86cc2 100644 --- a/python/minisgl/models/weight.py +++ b/python/minisgl/models/weight.py @@ -41,8 +41,6 @@ def _shard_tensor(key: str, value: torch.Tensor, r: int, n: int, num_kv_heads: i return value[head_idx * head_dim : (head_idx + 1) * head_dim].clone() return value.chunk(n, dim=0)[r].clone() elif any(key.count(sub) for sub in _SPLIT_DIM_1): - if value.dim() < 2: # 1D bias: not sharded for row-parallel projections - return value return value.chunk(n, dim=1)[r].clone() elif key.count("lm_head") or key.count("embed_tokens"): num_embeddings = value.shape[0] From 31d9da8091a2b35824e35b25fcd6fb3081f6955c Mon Sep 17 00:00:00 2001 From: YzXiao101 Date: Tue, 26 May 2026 12:08:26 -0400 Subject: [PATCH 4/5] feat: fi + gemma3 --- python/minisgl/attention/fi.py | 7 ++++--- python/minisgl/models/config.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/minisgl/attention/fi.py b/python/minisgl/attention/fi.py index 195f3ea1..c81a4ed1 100644 --- a/python/minisgl/attention/fi.py +++ b/python/minisgl/attention/fi.py @@ -130,9 +130,10 @@ def __init__(self, config: ModelConfig) -> None: self.int_workspace_buffer = self.prefill_wrapper._int_workspace_buffer self.decode_wrappers._int_workspace_buffer = self.int_workspace_buffer if self.sliding_prefill_wrapper is not None: - self.sliding_prefill_wrapper._int_workspace_buffer = self.int_workspace_buffer - if self.sliding_decode_wrappers is not None: - self.sliding_decode_wrappers._int_workspace_buffer = self.int_workspace_buffer + assert self.sliding_decode_wrappers is not None + self.sliding_decode_wrappers._int_workspace_buffer = ( + self.sliding_prefill_wrapper._int_workspace_buffer + ) # initialize some data members tp_size = get_tp_info().size diff --git a/python/minisgl/models/config.py b/python/minisgl/models/config.py index bbc6469c..bece1083 100644 --- a/python/minisgl/models/config.py +++ b/python/minisgl/models/config.py @@ -51,7 +51,7 @@ def is_moe(self) -> bool: @property def has_sliding_attention(self) -> bool: - return "sliding_attention" in self.layer_types + return "sliding_attention" in self.layer_types and self.sliding_window is not None @classmethod def from_hf(cls, config: PretrainedConfig) -> ModelConfig: From ab2aed4ed81235948d4ad385fa30a4e4b0a62443 Mon Sep 17 00:00:00 2001 From: YzXiao101 Date: Thu, 28 May 2026 10:50:24 -0400 Subject: [PATCH 5/5] feat: fi + gemma3 --- python/minisgl/attention/fi.py | 59 +++++++++++++++++++++++++-------- python/minisgl/engine/engine.py | 14 -------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/python/minisgl/attention/fi.py b/python/minisgl/attention/fi.py index c81a4ed1..60c9b100 100644 --- a/python/minisgl/attention/fi.py +++ b/python/minisgl/attention/fi.py @@ -43,6 +43,12 @@ def indices(self) -> torch.Tensor: return self.page_table +@dataclass +class _FIGraphWrappers: + wrapper: CUDAGraphBatchDecodeWithPagedKVCacheWrapper + sliding_wrapper: CUDAGraphBatchDecodeWithPagedKVCacheWrapper | None = None + + @dataclass class FIMetadata(BaseAttnMetadata): # fmt: off @@ -144,7 +150,12 @@ def __init__(self, config: ModelConfig) -> None: # for cuda graph self.capture_bs: List[int] = [] self.max_graph_bs = 0 - self.graph_wrappers: Dict[int, CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = {} + self.graph_wrappers: Dict[int, _FIGraphWrappers] = {} + self.graph_softmax_scale = ( + self.config.query_pre_attn_scalar**-0.5 + if self.config.query_pre_attn_scalar is not None + else None + ) self.capture: FICaptureData | None = None self.last_event = torch.cuda.Event() self.last_event.record() @@ -334,25 +345,45 @@ def prepare_for_capture(self, batch: Batch) -> None: bs = batch.size assert bs in self.capture_bs and bs not in self.graph_wrappers and self.capture capture = self.capture - self.graph_wrappers[bs] = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self.float_workspace_buffer, - kv_layout="NHD", - use_tensor_cores=self.use_tensor_cores, - indptr_buffer=capture.cu_seqlens_k[: bs + 1], - indices_buffer=capture.indices, - last_page_len_buffer=capture.one_tensor[:bs], + graph_wrappers = _FIGraphWrappers( + wrapper=CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self.float_workspace_buffer, + kv_layout="NHD", + use_tensor_cores=self.use_tensor_cores, + indptr_buffer=capture.cu_seqlens_k[: bs + 1], + indices_buffer=capture.indices, + last_page_len_buffer=capture.one_tensor[:bs], + ) ) - self.graph_wrappers[bs]._backend = "fa2" - self.graph_wrappers[bs]._int_workspace_buffer = self.int_workspace_buffer + graph_wrappers.wrapper._backend = "fa2" + graph_wrappers.wrapper._int_workspace_buffer = self.int_workspace_buffer + if self.config.has_sliding_attention: + assert self.sliding_prefill_wrapper is not None + graph_wrappers.sliding_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self.float_workspace_buffer, + kv_layout="NHD", + use_tensor_cores=self.use_tensor_cores, + indptr_buffer=capture.cu_seqlens_k[: bs + 1], + indices_buffer=capture.indices, + last_page_len_buffer=capture.one_tensor[:bs], + ) + graph_wrappers.sliding_wrapper._backend = "fa2" + graph_wrappers.sliding_wrapper._int_workspace_buffer = ( + self.sliding_prefill_wrapper._int_workspace_buffer + ) + self.graph_wrappers[bs] = graph_wrappers self.prepare_metadata(batch) metadata = batch.attn_metadata assert isinstance(metadata, FIMetadata) - metadata.wrapper = self.graph_wrappers[bs] - self._initialize_metadata_once(metadata) + metadata.wrapper = graph_wrappers.wrapper + metadata.sliding_wrapper = graph_wrappers.sliding_wrapper + self._initialize_metadata_once(metadata, self.graph_softmax_scale) def prepare_for_replay(self, batch: Batch) -> None: metadata, bs = batch.attn_metadata, batch.padded_size assert isinstance(metadata, FIMetadata) and not metadata.initialized assert self.capture is not None and bs in self.capture_bs - metadata.wrapper = self.graph_wrappers[bs] - self._initialize_metadata_once(metadata) + graph_wrappers = self.graph_wrappers[bs] + metadata.wrapper = graph_wrappers.wrapper + metadata.sliding_wrapper = graph_wrappers.sliding_wrapper + self._initialize_metadata_once(metadata, self.graph_softmax_scale) diff --git a/python/minisgl/engine/engine.py b/python/minisgl/engine/engine.py index ae43a2fd..ea29a96b 100644 --- a/python/minisgl/engine/engine.py +++ b/python/minisgl/engine/engine.py @@ -228,20 +228,6 @@ def override(attr: str, value: Any): # this is dangerous, use with caution override("page_size", 64) logger.warning_rank0("Page size is overridden to 64 for TRTLLM backend") - def _decode_backend_is_fi(attention_backend: str) -> bool: - return attention_backend.split(",", 1)[-1] == "fi" - - # FIXME(yzxiao): unlock FI + cuda graph for gemma3 - if config.model_config.has_sliding_attention and _decode_backend_is_fi( - config.attention_backend - ): - if config.cuda_graph_bs != [] or config.cuda_graph_max_bs != 0: - logger.warning_rank0( - "CUDA graph is disabled for sliding-attention models with FI decode " - ) - override("cuda_graph_bs", []) - override("cuda_graph_max_bs", 0) - if config.model_config.is_moe and config.moe_backend == "auto": override("moe_backend", "fused") logger.info_rank0(f"Auto-selected MoE backend: {config.moe_backend}")