diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index d6dff173072..d43533b5a70 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -7,7 +7,11 @@
 import torch.nn.functional as F

 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormGated,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope

@@ -375,6 +379,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +424,26 @@ def __init__(
     def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
-            self.q_norm_fn = RMSNorm(
-                self.head_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            if self.has_kv_weights:
-                self.k_norm_fn = RMSNorm(
+            if args.qk_norm_affine:
+                self.q_norm_fn = RMSNorm(
                     self.head_dim,
                     eps=args.norm_eps,
                     add_unit_offset=args.rms_norm_add_unit_offset,
                 )
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(
+                        self.head_dim,
+                        eps=args.norm_eps,
+                        add_unit_offset=args.rms_norm_add_unit_offset,
+                    )
+            else:
+                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                if self.has_kv_weights:
+                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_norm:
+            self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_gate:
+            self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)

     def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
         """Initialize Q/K/V/O projection layers."""
@@ -478,6 +494,8 @@ def _prepare_qkv_shared(
         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by

         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
@@ -485,6 +503,8 @@ def _prepare_qkv_shared(
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by

         return q, k, v
@@ -508,6 +528,8 @@ def _prepare_qkv(
         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)

         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -518,6 +540,8 @@ def _prepare_qkv(
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)

         return q, k, v
@@ -582,8 +606,7 @@ def forward(
         )
         output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         if shared_kv is None and self.num_kv_shared_layers > 0:
             update = {"kv_to_share": (k, v)}
@@ -602,13 +625,27 @@ def forward(
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)

-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         output = self.wo(output)
         return output, None

+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+

 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
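Reviewer note: `_apply_output_transforms` folds three elementwise steps into one helper so the SDPA and eager paths share them: an optional per-head scaleless RMSNorm (`use_attn_o_norm`), an optional sigmoid gate projected from the block input `x` (`use_attn_o_gate`), and the pre-existing per-token `gate`. Below is a minimal standalone sketch of that order of operations; the shapes and the `og_weight` stand-in are illustrative assumptions, not values from the diff:

```python
import torch

# Illustrative shapes only; og_weight stands in for the self.og projection.
bsz, seqlen, n_heads, head_dim, dim = 2, 4, 8, 16, 64

output = torch.randn(bsz, seqlen, n_heads * head_dim)  # attention output
x = torch.randn(bsz, seqlen, dim)                      # block input
og_weight = torch.randn(n_heads * head_dim, dim)

out4d = output.view(bsz, seqlen, n_heads, head_dim)
# (1) scaleless RMSNorm over each head's channel axis (use_attn_o_norm)
out4d = out4d * torch.rsqrt((out4d * out4d).mean(-1, keepdim=True) + 1e-6)
# (2) sigmoid gate projected from x, applied per head element (use_attn_o_gate)
og = (x @ og_weight.T).view(bsz, seqlen, n_heads, head_dim)
out4d = torch.sigmoid(og) * out4d
output = out4d.reshape(bsz, seqlen, -1)
print(output.shape)  # torch.Size([2, 4, 128])
```

Note that the helper reshapes to a per-head layout before gating, so both the norm and the sigmoid gate act on head-sized chunks rather than on the concatenated projection.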
""" super().__init__() self.use_kv_cache = args.use_kv_cache @@ -118,6 +148,7 @@ def __init__( self.head_dim = args.head_dim self.attention = attention self.mlp_type = mlp_type.lower() + self.use_residual_gate = args.use_residual_gate assert ( args.hidden_dim is not None @@ -150,6 +181,20 @@ def __init__( add_unit_offset=args.rms_norm_add_unit_offset, ) + if args.use_residual_gate: + attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5 + ffn_init = 1.0 / (2 * layer_id + 2) + self.add_attn = NormPreservingResidualConnection( + dim=args.dim, init_scale=attn_init + ) + self.add_ffn = NormPreservingResidualConnection( + dim=args.dim, init_scale=ffn_init + ) + self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps) + + if args.use_ffn_learnable_scales and self.mlp_type != "skip": + self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps) + @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": """ @@ -169,21 +214,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock": mlp_type = args.mlp_type[layer_id] cls = ATTENTION_REGISTRY[args.attention_type] attention = cls(args, layer_id, rope, **args.attention_kwargs) - return TransformerBlock(args, attention, mlp_type=mlp_type) + return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id) def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN h, attn_options_update = self.attention( self.attention_norm(x), freqs_cos, freqs_sin, **attn_options ) if not isinstance(self.attention, AttentionSkip): - h = x + h + if self.use_residual_gate: + if hasattr(self, "post_attn_norm"): + h = self.post_attn_norm(h) + h = self.add_attn(stream=x, branch=h) + else: + h = x + h if self.mlp_type == "skip": out = h elif hasattr(self, "block_sparse_moe"): - out = h + self.block_sparse_moe(self.ffn_norm(h)) + ffn_out = self.block_sparse_moe(self.ffn_norm(h)) + if hasattr(self, "post_ffn_norm"): + ffn_out = self.post_ffn_norm(ffn_out) + if self.use_residual_gate: + out = self.add_ffn(stream=h, branch=ffn_out) + else: + out = h + ffn_out else: - out = h + self.feed_forward(self.ffn_norm(h)) + ffn_out = self.feed_forward(self.ffn_norm(h)) + if hasattr(self, "post_ffn_norm"): + ffn_out = self.post_ffn_norm(ffn_out) + if self.use_residual_gate: + out = self.add_ffn(stream=h, branch=ffn_out) + else: + out = h + ffn_out return out, attn_options_update @@ -371,7 +433,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: and model_args.layer_types[layer_id] == "skip_attention" ): attention = AttentionSkip() - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock( + model_args, attention, layer_id=layer_id + ) layers.append(transformer_block) elif ( model_args.layer_types @@ -386,13 +450,17 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: attention = linear_cls( model_args, layer_id, rope, **model_args.attention_kwargs ) - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock( + model_args, attention, layer_id=layer_id + ) layers.append(transformer_block) else: attention = cls( model_args, layer_id, rope, **model_args.attention_kwargs ) # pyre-ignore[45] - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock( + model_args, attention, layer_id=layer_id + ) layers.append(transformer_block) return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model_args.py 
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 104e9fe2ddd..86818bb721a 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -102,6 +102,9 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = (
+        True  # whether QK norm has learnable weight (False = scaleless)
+    )
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +165,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None

+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
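Since the new switches are plain `ModelArgs` fields, enabling the rlformers behaviors is a config-only change. A hypothetical configuration exercising them together (the numeric values are illustrative and come from no released checkpoint; unlisted fields keep their defaults):

```python
from executorch.examples.models.llama.model_args import ModelArgs

# Hypothetical config for illustration only.
args = ModelArgs(
    dim=512,
    n_layers=4,
    n_heads=8,
    use_qk_norm=True,
    qk_norm_before_rope=True,
    qk_norm_affine=False,           # use ScalelessRMSNorm for Q/K
    scale_query_by=0.125,           # e.g. fold a fixed factor into Q after norm
    use_attn_o_gate=True,           # sigmoid gate on attention output, per head
    use_attn_o_norm=True,           # scaleless RMSNorm on attention output, per head
    use_residual_gate=True,         # NormPreservingResidualConnection residuals
    use_ffn_learnable_scales=True,  # RMSNormWithInputScale after the FFN
)
```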
diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py
index 0189c88b13b..8ad51d4594a 100644
--- a/examples/models/llama/norm.py
+++ b/examples/models/llama/norm.py
@@ -32,33 +32,38 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))

     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)

-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps

     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        x_float = x.float()
+        return (
+            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
+        ).type_as(x)

-        Args:
-            x (torch.Tensor): The input tensor.

-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))

-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = self.weight * x
+        return F.rms_norm(scaled, (self.dim,), None, self.eps)


 class RMSNormGated(nn.Module):
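Two behaviors of the new norm classes that reviewers may want to sanity-check: `ScalelessRMSNorm` is exactly RMS normalization with no learnable weight, and `RMSNormWithInputScale` applies its weight *before* normalizing, so a uniform weight cancels and only the relative per-channel scales survive. A small sketch (assumes a PyTorch recent enough to provide `F.rms_norm`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8)

# ScalelessRMSNorm: plain RMS normalization, no learnable weight.
scaleless = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(scaleless, F.rms_norm(x, (8,), None, 1e-6), atol=1e-5)

# RMSNormWithInputScale scales *before* normalizing, so a uniform weight
# cancels out (up to eps): only relative per-channel weights matter.
w = torch.full((8,), 3.0)
out = F.rms_norm(w * x, (8,), None, 1e-5)
assert torch.allclose(out, F.rms_norm(x, (8,), None, 1e-5), atol=1e-4)
```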
""" if self.use_qk_norm and before_rope == self.qk_norm_before_rope: - qs = [self.q_norm(q) for q in qs] + qs = [self.q_norm(q) * self.scale_query_by for q in qs] if ks is not None: ks = [self.k_norm(k) for k in ks] return qs, ks @@ -1306,12 +1329,12 @@ def _forward_mha( q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by q = self.rope(q, freqs_cos, freqs_sin) if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by k, v = shared_kv else: @@ -1320,14 +1343,14 @@ def _forward_mha( v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by k = self.k_norm(k) q = self.rope(q, freqs_cos, freqs_sin) k = self.rope(k, freqs_cos, freqs_sin) if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by k = self.k_norm(k) k, out_cache_state = self.k_caches[0].update( @@ -1430,14 +1453,17 @@ def load_weights_from_attention_mha( if other.use_qk_norm: self.use_qk_norm = True self.qk_norm_before_rope = other.qk_norm_before_rope - self.q_norm = rms_norm_class(other.q_norm_fn.dim, other.q_norm_fn.eps).to( - other.q_norm_fn.weight.dtype - ) - self.q_norm.load_state_dict(other.q_norm_fn.state_dict()) + self.scale_query_by = getattr(other, "scale_query_by", 1.0) + if hasattr(other.q_norm_fn, "weight"): + self.q_norm = rms_norm_class( + other.q_norm_fn.dim, other.q_norm_fn.eps + ).to(other.q_norm_fn.weight.dtype) + self.q_norm.load_state_dict(other.q_norm_fn.state_dict()) if ( not self.is_kv_shared_layer and hasattr(other, "k_norm_fn") and other.k_norm_fn is not None + and hasattr(other.k_norm_fn, "weight") ): self.k_norm = rms_norm_class( other.k_norm_fn.dim, other.k_norm_fn.eps