61 changes: 49 additions & 12 deletions examples/models/llama/attention.py
@@ -7,7 +7,11 @@
import torch.nn.functional as F
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
from executorch.examples.models.llama.norm import (
RMSNorm,
RMSNormGated,
ScalelessRMSNorm,
)
from executorch.examples.models.llama.rope import Rope


@@ -375,6 +379,9 @@ def __init__(
self.qk_norm_before_rope = args.qk_norm_before_rope
self.use_q_gate = args.use_q_gate
self.enable_dynamic_shape = args.enable_dynamic_shape
self.scale_query_by = args.scale_query_by
self.use_attn_o_gate = args.use_attn_o_gate
self.use_attn_o_norm = args.use_attn_o_norm
q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

# YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +424,26 @@ def __init__(
def _init_norms(self, args: ModelArgs) -> None:
"""Initialize QK normalization layers."""
if self.use_qk_norm:
self.q_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
if args.qk_norm_affine:
self.q_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
self.head_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
else:
self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.has_kv_weights:
self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.use_attn_o_norm:
self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
if self.use_attn_o_gate:
self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)

def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
"""Initialize Q/K/V/O projection layers."""
@@ -478,13 +494,17 @@ def _prepare_qkv_shared(

if self.use_qk_norm and self.qk_norm_before_rope:
q = self.q_norm_fn(q)
if self.scale_query_by != 1.0:
q = q * self.scale_query_by

# Apply RoPE to Q only (K already has RoPE from donor layer)
q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
q = q.transpose(1, 2)

if self.use_qk_norm and not self.qk_norm_before_rope:
q = self.q_norm_fn(q)
if self.scale_query_by != 1.0:
q = q * self.scale_query_by

return q, k, v

@@ -508,6 +528,8 @@ def _prepare_qkv(

if self.use_qk_norm and self.qk_norm_before_rope:
q = self.q_norm_fn(q)
if self.scale_query_by != 1.0:
q = q * self.scale_query_by
k = self.k_norm_fn(k)

q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -518,6 +540,8 @@

if self.use_qk_norm and not self.qk_norm_before_rope:
q = self.q_norm_fn(q)
if self.scale_query_by != 1.0:
q = q * self.scale_query_by
k = self.k_norm_fn(k)

return q, k, v
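For reference, a minimal sketch of the query path these flags add (the shape and the 0.125 scale are illustrative assumptions, not values from this change): when qk_norm_affine is False the scaleless norm is used, and a non-unit scale_query_by multiplies the normalized query either before or after RoPE depending on qk_norm_before_rope.

import torch

# Illustrative only: a scaleless QK norm followed by query scaling.
head_dim, scale_query_by = 8, 0.125
q = torch.randn(1, 4, 3, head_dim)                          # illustrative query tensor
q = q * torch.rsqrt((q * q).mean(-1, keepdim=True) + 1e-6)  # scaleless RMS norm over head_dim
if scale_query_by != 1.0:
    q = q * scale_query_by                                  # query scaling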
@@ -582,8 +606,7 @@ def forward(
)

output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
if gate is not None:
output = output * torch.sigmoid(gate)
output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

if shared_kv is None and self.num_kv_shared_layers > 0:
update = {"kv_to_share": (k, v)}
@@ -602,13 +625,27 @@ def forward(
output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
if gate is not None:
output = output * torch.sigmoid(gate)
output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

output = self.wo(output)

return output, None

def _apply_output_transforms(
self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
) -> torch.Tensor:
if self.use_attn_o_norm or self.use_attn_o_gate:
output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
if self.use_attn_o_norm:
output_4d = self.o_norm(output_4d)
if self.use_attn_o_gate:
og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
output_4d = torch.sigmoid(og) * output_4d
output = output_4d.reshape(bsz, seqlen, -1)
if gate is not None:
output = output * torch.sigmoid(gate)
return output


def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
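A minimal sketch of the new output path in _apply_output_transforms with both use_attn_o_norm and use_attn_o_gate enabled (shapes and names below are illustrative assumptions; the actual code uses the layer's o_norm, og, and n_local_heads):

import torch

bsz, seqlen, n_heads, head_dim, dim = 1, 3, 4, 8, 32
sdpa_out = torch.randn(bsz, seqlen, n_heads * head_dim)   # flattened attention output
x = torch.randn(bsz, seqlen, dim)                         # block input that drives the gate
og_proj = torch.nn.Linear(dim, n_heads * head_dim, bias=False)

out = sdpa_out.view(bsz, seqlen, n_heads, head_dim)
out = out * torch.rsqrt((out * out).mean(-1, keepdim=True) + 1e-6)  # per-head scaleless RMS norm
gate = og_proj(x).view(bsz, seqlen, n_heads, head_dim)
out = torch.sigmoid(gate) * out                                     # sigmoid output gate
out = out.reshape(bsz, seqlen, -1)                                  # back to 3D before wo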
86 changes: 77 additions & 9 deletions examples/models/llama/llama_transformer.py
@@ -7,6 +7,7 @@

# Please refer to README.md in the same folder for more information.

import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
@@ -20,7 +21,11 @@
)
from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.norm import (
RMSNorm,
RMSNormWithInputScale,
ScalelessRMSNorm,
)
from executorch.examples.models.llama.rope import Rope
from torch import nn

@@ -51,6 +56,26 @@ def _is_kv_shared_layer(
return layer_idx >= first_shared and first_shared > 0


class NormPreservingResidualConnection(nn.Module):
def __init__(
self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3
):
super().__init__()
self.eps = eps
self.temperature = temperature
p = max(0.0 + eps, min(1.0 - eps, init_scale))
init_param = math.log(p / (1.0 - p)) * temperature
self.gate = nn.Parameter(torch.full((dim,), init_param))

def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
dtype = stream.dtype
w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
beta = torch.sigmoid(w / self.temperature)
alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
return (alpha * stream.float() + beta * branch.float()).to(dtype)
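At initialization this gate has a closed form: init_param / temperature equals log(p / (1 - p)), so beta == p == init_scale and alpha == sqrt((1 - p) * (1 + p)) == sqrt(1 - p**2), giving alpha**2 + beta**2 == 1. A quick numeric check with illustrative values:

import math

init_scale, temperature, eps = 0.5, 0.3, 1e-3
p = max(eps, min(1.0 - eps, init_scale))
w = math.log(p / (1.0 - p)) * temperature
beta = 1.0 / (1.0 + math.exp(-w / temperature))  # sigmoid(w / temperature) == p
alpha = math.sqrt((1.0 - beta) * (1.0 + beta))   # sqrt(1 - p**2), here ~0.866
print(beta, alpha, alpha**2 + beta**2)           # 0.5  0.866...  1.0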


class ConditionalFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -99,7 +124,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class TransformerBlock(nn.Module):
def __init__(
self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
self,
args: ModelArgs,
attention: Attention,
mlp_type: str = "default",
layer_id: int = 0,
):
"""
Transformer block with support for pre-norm and post-norm.
@@ -110,6 +139,7 @@ def __init__(
the attention type is registered in the ATTENTION_REGISTRY.
mlp_type (str): MLP type for this layer. "default" for standard
FFN, "skip" for no FFN block.
layer_id (int): layer index, used for residual gate initialization.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
@@ -118,6 +148,7 @@ def __init__(
self.head_dim = args.head_dim
self.attention = attention
self.mlp_type = mlp_type.lower()
self.use_residual_gate = args.use_residual_gate

assert (
args.hidden_dim is not None
@@ -150,6 +181,20 @@ def __init__(
add_unit_offset=args.rms_norm_add_unit_offset,
)

if args.use_residual_gate:
attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
ffn_init = 1.0 / (2 * layer_id + 2)
self.add_attn = NormPreservingResidualConnection(
dim=args.dim, init_scale=attn_init
)
self.add_ffn = NormPreservingResidualConnection(
dim=args.dim, init_scale=ffn_init
)
self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)

if args.use_ffn_learnable_scales and self.mlp_type != "skip":
self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)

@classmethod
def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
"""
@@ -169,21 +214,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
mlp_type = args.mlp_type[layer_id]
cls = ATTENTION_REGISTRY[args.attention_type]
attention = cls(args, layer_id, rope, **args.attention_kwargs)
return TransformerBlock(args, attention, mlp_type=mlp_type)
return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)

def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
h, attn_options_update = self.attention(
self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
)
if not isinstance(self.attention, AttentionSkip):
h = x + h
if self.use_residual_gate:
if hasattr(self, "post_attn_norm"):
h = self.post_attn_norm(h)
h = self.add_attn(stream=x, branch=h)
else:
h = x + h

if self.mlp_type == "skip":
out = h
elif hasattr(self, "block_sparse_moe"):
out = h + self.block_sparse_moe(self.ffn_norm(h))
ffn_out = self.block_sparse_moe(self.ffn_norm(h))
if hasattr(self, "post_ffn_norm"):
ffn_out = self.post_ffn_norm(ffn_out)
if self.use_residual_gate:
out = self.add_ffn(stream=h, branch=ffn_out)
else:
out = h + ffn_out
else:
out = h + self.feed_forward(self.ffn_norm(h))
ffn_out = self.feed_forward(self.ffn_norm(h))
if hasattr(self, "post_ffn_norm"):
ffn_out = self.post_ffn_norm(ffn_out)
if self.use_residual_gate:
out = self.add_ffn(stream=h, branch=ffn_out)
else:
out = h + ffn_out
return out, attn_options_update


@@ -371,7 +433,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
and model_args.layer_types[layer_id] == "skip_attention"
):
attention = AttentionSkip()
transformer_block = TransformerBlock(model_args, attention)
transformer_block = TransformerBlock(
model_args, attention, layer_id=layer_id
)
layers.append(transformer_block)
elif (
model_args.layer_types
@@ -386,13 +450,17 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
attention = linear_cls(
model_args, layer_id, rope, **model_args.attention_kwargs
)
transformer_block = TransformerBlock(model_args, attention)
transformer_block = TransformerBlock(
model_args, attention, layer_id=layer_id
)
layers.append(transformer_block)
else:
attention = cls(
model_args, layer_id, rope, **model_args.attention_kwargs
) # pyre-ignore[45]
transformer_block = TransformerBlock(model_args, attention)
transformer_block = TransformerBlock(
model_args, attention, layer_id=layer_id
)
layers.append(transformer_block)

return Transformer(model_args, layers, rope)
12 changes: 12 additions & 0 deletions examples/models/llama/model_args.py
@@ -102,6 +102,9 @@ class ModelArgs:
apply_output: bool = True # Use output layer (unembedding) inside the transformer
use_qk_norm: bool = False # apply normalization to q and k in the attention
qk_norm_before_rope: bool = False # when to apply qk norm
qk_norm_affine: bool = (
True # whether QK norm has learnable weight (False = scaleless)
)
residual_multiplier: Optional[float] = (
None # Scaling factor applied to the residual hidden states
)
@@ -162,6 +165,15 @@ class ModelArgs:
final_logit_softcapping: Optional[float] = None
attn_logit_softcapping: Optional[float] = None

# rlformers forward-pass features for on-device model parity
normalize_tok_embeddings: bool = False
scale_query_by: float = 1.0
use_attn_o_gate: bool = False
use_attn_o_norm: bool = False
use_residual_gate: bool = False
use_ffn_learnable_scales: bool = False
output_soft_cap_temp: Optional[float] = None

def __post_init__(self): # noqa: C901
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
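A hypothetical configuration sketch that exercises the new fields together (the field names come from the diff above; the specific values, and whatever other fields a full model requires, are assumptions):

from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs(
    use_qk_norm=True,
    qk_norm_before_rope=True,
    qk_norm_affine=False,          # Q/K norms become ScalelessRMSNorm
    scale_query_by=0.125,          # illustrative value
    use_attn_o_gate=True,
    use_attn_o_norm=True,
    use_residual_gate=True,
    use_ffn_learnable_scales=True,
)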
43 changes: 24 additions & 19 deletions examples/models/llama/norm.py
@@ -32,33 +32,38 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            return output * (1.0 + self.weight.float()).type_as(x)
        return output * self.weight.type_as(x)

    def _norm(self, x):
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            return output * (1.0 + self.weight.float()).type_as(x)
        return output * self.weight.type_as(x)


class ScalelessRMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        x_float = x.float()
        return (
            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
        ).type_as(x)


class RMSNormWithInputScale(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        scaled = self.weight * x
        return F.rms_norm(scaled, (self.dim,), None, self.eps)


class RMSNormGated(nn.Module):
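As a sanity sketch (an assumption about intent, not code from this change): ScalelessRMSNorm is RMS normalization without a learnable weight, so it should match F.rms_norm with weight=None, while RMSNormWithInputScale applies its learnable weight to the input before a weightless F.rms_norm.

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
eps = 1e-6
scaleless = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
print(torch.allclose(scaleless, F.rms_norm(x, (8,), None, eps), atol=1e-6))  # expected: True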