diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index 16b9d7c799..70f578fb76 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -55,6 +55,13 @@ class FopeParameters: inv_freq: torch.Tensor = None +@dataclass +class MropeParameters: + """MRoPE parameters.""" + mrope_section: list[int] + mrope_interleaved: bool = False + + class RotaryEmbeddingImpl(ABC): """Rotary embedding implementation api.""" diff --git a/lmdeploy/pytorch/models/glm4_1v.py b/lmdeploy/pytorch/models/glm4_1v.py index 7ab240208b..ac802fe6a2 100644 --- a/lmdeploy/pytorch/models/glm4_1v.py +++ b/lmdeploy/pytorch/models/glm4_1v.py @@ -2,7 +2,7 @@ # adapted from: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Iterable, Sequence from typing import Any import numpy as np @@ -23,35 +23,12 @@ from .utils.model import DeployModelMixin, vlm_model -def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: list[int], - position_ids: torch.Tensor, rotary_emb_func: Callable): - _mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device) - _mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids - cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids) - _cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device) - _sin = torch.zeros_like(_cos) - mrope_section = mrope_section * 2 - - def _apply_split(src, dst): - start = 0 - for i, m in enumerate(src.split(mrope_section, dim=-1)): - dst[:, start:start + mrope_section[i]] = m[i % 3] - start += mrope_section[i] - - _apply_split(cos, _cos) - _apply_split(sin, _sin) - - return _cos, _sin - - class Glm4vTextModel(nn.Module): def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.mrope_section = config.rope_scaling['mrope_section'] - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, @@ -92,8 +69,7 @@ def forward( cos, sin = self.rotary_emb(hidden_states, position_ids) cos, sin = cos[0], sin[0] else: - cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section, position_ids, - self.rotary_emb) + cos, sin = self.rotary_emb(hidden_states, mrope_position_ids) rotary_pos_emb = (cos, sin) # decoding diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index f01f82acdb..ff2bdb2e7d 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Iterable, Sequence from typing import Any import numpy as np @@ -32,27 +32,6 @@ from .utils.model import DeployModelMixinV1, build_embedding, vlm_model -def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: list[int], - position_ids: torch.Tensor, rotary_emb_func: Callable): - _mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device) - _mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids - cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids) - _cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device) - _sin = torch.zeros_like(_cos) - mrope_section = mrope_section * 2 - - def _apply_split(src, dst): - start = 0 - for i, m in enumerate(src.split(mrope_section, dim=-1)): - dst[:, start:start + mrope_section[i]] = m[i % 3] - start += mrope_section[i] - - _apply_split(cos, _cos) - _apply_split(sin, _sin) - - return _cos, _sin - - class Qwen2Attention(nn.Module): """Rewrite module of Qwen2Attention.""" @@ -246,8 +225,6 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.mrope_section = config.rope_scaling['mrope_section'] - self.embed_tokens = build_embedding( config.vocab_size, config.hidden_size, @@ -290,8 +267,7 @@ def forward( cos, sin = self.rotary_emb(hidden_states, position_ids) cos, sin = cos[0], sin[0] else: - cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section, position_ids, - self.rotary_emb) + cos, sin = self.rotary_emb(hidden_states, mrope_position_ids) rotary_pos_emb = (cos, sin) # decoding diff --git a/lmdeploy/pytorch/models/qwen3_5.py b/lmdeploy/pytorch/models/qwen3_5.py index 3abf30a77b..3e4958832b 100644 --- a/lmdeploy/pytorch/models/qwen3_5.py +++ b/lmdeploy/pytorch/models/qwen3_5.py @@ -9,13 +9,19 @@ import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update import lmdeploy.pytorch.nn.gated_delta as gated_delta_util from lmdeploy.pytorch.distributed import get_tp_world_rank from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, SiluAndMul +from lmdeploy.pytorch.nn import ( + ApplyRotaryEmb, + Attention, + LayerNorm, + RMSNorm, + SiluAndMul, + build_rotary_embedding_from_config, +) from lmdeploy.pytorch.nn.gated_delta import CausalConv1d, GatedDelta, GatedDeltaMeta, build_rmsnorm_gated from lmdeploy.pytorch.nn.linear import ( build_colwise_linear, @@ -24,7 +30,6 @@ build_qkv_proj, build_rowwise_linear, ) -from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight from lmdeploy.vl.constants import Modality @@ -805,99 +810,6 @@ def forward( return outputs -class Qwen3_5TextRotaryEmbedding(nn.Module): - inv_freq: torch.Tensor # fix linting for `register_buffer` - - def __init__(self, config: PretrainedConfig, device=None): - super().__init__() - rope_scaling = get_rope_parameters(config) - assert rope_scaling is not None, 'RoPE scaling parameters must be provided in the config for Qwen3.5 models.' - self.rope_type = rope_scaling.get('rope_type', 'default') - - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - if self.rope_type != 'default': - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - else: - self.rope_init_fn = self.compute_default_rope_parameters - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer('inv_freq', inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - self.mrope_section = rope_scaling.get('mrope_section', [11, 11, 10]) - - @staticmethod - def compute_default_rope_parameters( - config: PretrainedConfig | None = None, - device: torch.device | None = None, - seq_len: int | None = None, - ) -> tuple['torch.Tensor', float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PreTrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - rope_parameters = get_rope_parameters(config) - base = rope_parameters['rope_theta'] - partial_rotary_factor = rope_parameters.get('partial_rotary_factor', 1.0) - head_dim = getattr(config, 'head_dim', None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(dtype=torch.float) / dim)) - inv_freq = inv_freq.to(device=device) - return inv_freq, attention_factor - - def apply_interleaved_mrope(self, freqs, mrope_section): - """Apply interleaved MRoPE to 3D rotary embeddings. - - Reorganizes frequency layout from chunked [TTT...HHH...WWW] to - interleaved [THTHWHTHW...TT], preserving frequency continuity. - args: - x: (3, bs, seq_len, head_dim // 2) - mrope_section: (3,) - returns: - x_t: (bs, seq_len, head_dim // 2) - """ - freqs_t = freqs[0] # just overwrite the first dimension T - for dim, offset in enumerate((1, 2), start=1): # H, W - length = mrope_section[dim] * 3 - idx = slice(offset, length, 3) - freqs_t[..., idx] = freqs[dim, ..., idx] - return freqs_t - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - # In contrast to other models, Qwen3VL has different position ids for the grids - # So we expand the inv_freq to shape (3, ...) - if position_ids.ndim == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class Qwen3_5TextModel(nn.Module): """qwen3.5 text model.""" @@ -932,7 +844,7 @@ def __init__(self, self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) # build rotary embedding - self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device) + self.rotary_emb = build_rotary_embedding_from_config(config, device=device) def forward( self, diff --git a/lmdeploy/pytorch/models/qwen3_5_moe.py b/lmdeploy/pytorch/models/qwen3_5_moe.py index 92bc7f75f7..616a4d01d0 100644 --- a/lmdeploy/pytorch/models/qwen3_5_moe.py +++ b/lmdeploy/pytorch/models/qwen3_5_moe.py @@ -9,7 +9,7 @@ from lmdeploy.pytorch.distributed import get_dist_manager from lmdeploy.pytorch.model_inputs import StepContextManager -from lmdeploy.pytorch.nn import RMSNorm +from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.moe import build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -23,7 +23,6 @@ Qwen3_5MLP, Qwen3_5Model, Qwen3_5TextModel, - Qwen3_5TextRotaryEmbedding, ) from .qwen3_5 import Qwen3_5VisionModel as Qwen3_5MoeVisionModel from .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5MoeInputProcessor @@ -212,7 +211,7 @@ def __init__(self, self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) # build rotary embedding - self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device) + self.rotary_emb = build_rotary_embedding_from_config(config, device=device) class Qwen3_5MoeModel(Qwen3_5Model): diff --git a/lmdeploy/pytorch/models/qwen3_5_mtp.py b/lmdeploy/pytorch/models/qwen3_5_mtp.py index 02a4cae582..c192099292 100644 --- a/lmdeploy/pytorch/models/qwen3_5_mtp.py +++ b/lmdeploy/pytorch/models/qwen3_5_mtp.py @@ -8,12 +8,12 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import RMSNorm +from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import build_colwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import add_prefix, get_build_model_context -from .qwen3_5 import Qwen3_5Attention, Qwen3_5DecoderLayer, Qwen3_5MLP, Qwen3_5TextRotaryEmbedding +from .qwen3_5 import Qwen3_5Attention, Qwen3_5DecoderLayer, Qwen3_5MLP from .qwen3_5_moe import Qwen3_5MoeSparseMoeBlock from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin @@ -112,7 +112,7 @@ def __init__( ) # build rotary embedding - self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device) + self.rotary_emb = build_rotary_embedding_from_config(config, device=device) def get_input_embeddings(self): """Get input embeddings.""" diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 9a0185e13b..db3f69af55 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -8,14 +8,12 @@ import torch from torch import nn from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.multimodal.data_type import MultiModalData -from lmdeploy.pytorch.nn import LayerNorm +from lmdeploy.pytorch.nn import LayerNorm, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear -from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from lmdeploy.vl.constants import Modality @@ -27,77 +25,6 @@ from .utils.model import DeployModelMixinV1, vlm_model -class Qwen3VLTextRotaryEmbedding(nn.Module): - inv_freq: torch.Tensor # fix linting for `register_buffer` - - def __init__(self, config: PretrainedConfig, device=None): - super().__init__() - if hasattr(config, 'rope_scaling') and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get('rope_type', 'default') - else: - self.rope_type = 'default' - - self._pack_for_trans5(config) - - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer('inv_freq', inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - self.mrope_section = config.rope_scaling.get('mrope_section', [24, 20, 20]) - - def _pack_for_trans5(self, config): - if self.rope_type == 'default' and 'default' not in ROPE_INIT_FUNCTIONS: - # transformers 5 has removed default in ROPE_INIT_FUNCTIONS - self.rope_type = 'linear' - rope_parameters = get_rope_parameters(config) - if 'factor' not in rope_parameters: - rope_parameters['factor'] = 1.0 - - def apply_interleaved_mrope(self, freqs, mrope_section): - """Apply interleaved MRoPE to 3D rotary embeddings. - - Reorganizes frequency layout from chunked [TTT...HHH...WWW] to - interleaved [THTHWHTHW...TT], preserving frequency continuity. - args: - x: (3, bs, seq_len, head_dim // 2) - mrope_section: (3,) - returns: - x_t: (bs, seq_len, head_dim // 2) - """ - freqs_t = freqs[0] # just overwrite the first dimension T - for dim, offset in enumerate((1, 2), start=1): # H, W - length = mrope_section[dim] * 3 - idx = slice(offset, length, 3) - freqs_t[..., idx] = freqs[dim, ..., idx] - return freqs_t - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - # In contrast to other models, Qwen3VL has different position ids for the grids - # So we expand the inv_freq to shape (3, ...) - if position_ids.ndim == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != 'mps' else 'cpu' - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class Qwen3VLTextModel(Qwen3model): """Text part of Qwen3VL. @@ -112,8 +39,7 @@ def __init__(self, super().__init__(config=config, dtype=dtype, device=device, prefix=prefix) # build rotary embedding - # TODO: zhouxinyu, add triton kernel for interleaved mrope - self.rotary_emb = Qwen3VLTextRotaryEmbedding(config, device=device) + self.rotary_emb = build_rotary_embedding_from_config(config, device=device) def forward( self, diff --git a/lmdeploy/pytorch/models/qwen3_vl_moe.py b/lmdeploy/pytorch/models/qwen3_vl_moe.py index 9dd8263c4a..6c88125156 100644 --- a/lmdeploy/pytorch/models/qwen3_vl_moe.py +++ b/lmdeploy/pytorch/models/qwen3_vl_moe.py @@ -8,12 +8,12 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContextManager +from lmdeploy.pytorch.nn import build_rotary_embedding_from_config from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import add_prefix, get_build_model_context from .qwen3_moe import Qwen3MoeModel from .qwen3_vl import Qwen3VLForConditionalGeneration -from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding class Qwen3VLMoeTextModel(Qwen3MoeModel): @@ -30,8 +30,7 @@ def __init__(self, super().__init__(config=config, dtype=dtype, device=device, prefix=prefix) # build rotary embedding - # TODO: zhouxinyu, add triton kernel for interleaved mrope - self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config, device=device) + self.rotary_emb = build_rotary_embedding_from_config(config, device=device) def forward( self, diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 3fa8cfda81..98fb8d8ce4 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -11,6 +11,7 @@ FopeParameters, Llama3Parameters, LongRoPEScalingParameters, + MropeParameters, RopeType, YarnParameters, ) @@ -127,6 +128,19 @@ def _get_fope_parameters(config: PretrainedConfig): return dict(fope_params=params) +def _get_mrope_parameters(config: PretrainedConfig): + """Get mrope parameters.""" + rope_scaling = get_rope_parameters(config=config) + if rope_scaling is None or 'mrope_section' not in rope_scaling: + return dict() + + params = MropeParameters( + mrope_section=rope_scaling['mrope_section'], + mrope_interleaved=rope_scaling.get('mrope_interleaved', False), + ) + return dict(mrope_params=params) + + def build_rotary_params(config: PretrainedConfig): """Get scaling_factor rotary params, and emb_type.""" params = dict(emb_type=RopeType.Default) @@ -135,6 +149,8 @@ def build_rotary_params(config: PretrainedConfig): if rope_scaling is not None: # BC: "rope_type" was originally "type" rope_type_str = rope_scaling.get('rope_type', rope_scaling.get('type', 'default')) + if rope_type_str == 'mrope': + rope_type_str = 'default' if rope_type_str == 'fope': rope_type_str = 'default' build_funcs = dict(default=_get_default_rope_parameters, @@ -146,9 +162,12 @@ def build_rotary_params(config: PretrainedConfig): llama3=_get_llama3_parameters) params.update(build_funcs[rope_type_str](config)) params.update(_get_fope_parameters(config)) + params.update(_get_mrope_parameters(config)) # update partial_rotary_factor - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else None + partial_rotary_factor = getattr(config, 'partial_rotary_factor', None) + if partial_rotary_factor is None and rope_scaling is not None: + partial_rotary_factor = rope_scaling.get('partial_rotary_factor', None) if partial_rotary_factor is not None: params['partial_rotary_factor'] = partial_rotary_factor @@ -163,6 +182,7 @@ def build_rotary_embedding(dim: int, longrope_params: LongRoPEScalingParameters = None, llama3_params: Llama3Parameters = None, fope_params: FopeParameters = None, + mrope_params: MropeParameters = None, emb_type: RopeType = RopeType.Default, partial_rotary_factor: float = None, device: torch.device = None) -> nn.Module: @@ -186,8 +206,9 @@ def build_rotary_embedding(dim: int, if fope_params is not None: inv_freq = impl.inv_freq fope_params.inv_freq = inv_freq - fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params, device) - return fope + impl = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params, device) + elif mrope_params is not None: + impl = MRotaryEmbedding(impl, mrope_params) return impl @@ -252,6 +273,107 @@ def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: return query, key +class MRotaryEmbedding(nn.Module): + """Rotary embedding wrapper with multimodal axis selection.""" + + def __init__(self, impl: nn.Module, params: MropeParameters): + super().__init__() + self.impl = impl + self.mrope_section = list(params.mrope_section) + self.mrope_interleaved = params.mrope_interleaved + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + """forward.""" + if position_ids.size(0) != 3: + cos, sin = self.impl(x, position_ids) + return cos, sin + + if self._uses_static_inv_freq_rope(): + return self.build_mrope_tables_from_selected_freqs(x, position_ids) + + leading_shape = position_ids.shape[:-1] + flat_position_ids = position_ids.flatten(0, -2) + cos, sin = self.impl(x, flat_position_ids) + cos = cos.reshape(*leading_shape, *cos.shape[1:]) + sin = sin.reshape(*leading_shape, *sin.shape[1:]) + return self.apply_mrope(cos), self.apply_mrope(sin) + + def apply_mrope(self, freqs: torch.Tensor): + """Select temporal, height, and width rotary bands.""" + if self.mrope_interleaved: + return self.apply_interleaved_mrope(freqs) + return self.apply_chunked_mrope(freqs) + + def apply_chunked_mrope(self, freqs: torch.Tensor): + """Apply Qwen2-VL style chunked MRoPE.""" + # Layout is contiguous bands: T..., H..., W..., then repeated for the + # duplicated RoPE half if freqs already contains cos/sin table width. + mrope_section = self.mrope_section + if freqs.size(-1) == sum(self.mrope_section) * 2: + mrope_section = mrope_section * 2 + selected_chunks = [] + for index, chunk in enumerate(freqs.split(mrope_section, dim=-1)): + axis = index % 3 + selected_chunks.append(chunk[axis]) + return torch.cat(selected_chunks, dim=-1) + + def apply_interleaved_mrope(self, freqs: torch.Tensor): + """Apply Qwen3-VL style interleaved MRoPE.""" + # Layout is lane-interleaved: T, H, W, T, H, W...; start from T and + # overwrite the H/W lanes from their corresponding axes. + half_dim = sum(self.mrope_section) + has_duplicated_half = freqs.size(-1) == half_dim * 2 + freqs_t = freqs[0].clone() + for dim, offset in enumerate((1, 2), start=1): + length = min(self.mrope_section[dim] * 3, half_dim) + freqs_t[..., offset:length:3] = freqs[dim, ..., offset:length:3] + if has_duplicated_half: + freqs_t[..., half_dim + offset:half_dim + length:3] = \ + freqs[dim, ..., half_dim + offset:half_dim + length:3] + return freqs_t + + def _uses_static_inv_freq_rope(self): + """Check whether RoPE is equivalent to position_ids * inv_freq.""" + if not hasattr(self.impl, 'inv_freq'): + return False + backend_only_attrs = ('_ntk_inv_freq', 'short_factor', 'long_factor', 'mscale_all_dim') + return not any(hasattr(self.impl, attr) for attr in backend_only_attrs) + + def build_mrope_tables_from_selected_freqs(self, x: torch.Tensor, position_ids: torch.Tensor): + """Build MRoPE cos/sin tables from selected axis frequencies.""" + inv_freq = self.impl.inv_freq + if inv_freq.device != x.device: + self.impl.inv_freq = inv_freq.to(x.device) + inv_freq = self.impl.inv_freq + + scaling_factor = getattr(self.impl, 'scaling_factor', 1.0) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + position_ids = position_ids.float() + if scaling_factor != 1.0: + position_ids = position_ids / scaling_factor + + inv_freq = inv_freq.float() + freqs = position_ids.unsqueeze(-1) * inv_freq + freqs = self.apply_mrope(freqs) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + mscale = getattr(self.impl, 'mscale', None) + if mscale is not None: + cos = cos * mscale + sin = sin * mscale + + attention_scaling = getattr(self.impl, 'attention_scaling', None) + if attention_scaling is not None: + cos = cos * attention_scaling + sin = sin * attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class FopeRotaryEmbedding(nn.Module): """Fope rotary embedding.""" diff --git a/tests/pytorch/nn/test_rotary_embedding.py b/tests/pytorch/nn/test_rotary_embedding.py new file mode 100644 index 0000000000..f4bc35380a --- /dev/null +++ b/tests/pytorch/nn/test_rotary_embedding.py @@ -0,0 +1,182 @@ +import torch +from transformers import PretrainedConfig + +from lmdeploy.pytorch.nn import build_rotary_embedding_from_config + + +def _make_config(*, mrope_interleaved: bool = False): + return PretrainedConfig( + hidden_size=16, + num_attention_heads=1, + head_dim=16, + max_position_embeddings=64, + rope_theta=10000, + rope_scaling=dict( + rope_type='default', + mrope_section=[2, 3, 3], + mrope_interleaved=mrope_interleaved, + ), + ) + + +def test_mrope_uses_rope_parameters_partial_rotary_factor(): + config = PretrainedConfig( + hidden_size=16, + num_attention_heads=1, + head_dim=16, + max_position_embeddings=64, + rope_parameters=dict( + rope_type='default', + rope_theta=10000, + partial_rotary_factor=0.5, + mrope_section=[1, 1, 2], + mrope_interleaved=True, + ), + ) + rotary_emb = build_rotary_embedding_from_config(config) + hidden_states = torch.empty(5, 16) + position_ids = torch.stack([ + torch.arange(5), + torch.arange(10, 15), + torch.arange(20, 25), + ]) + + cos, sin = rotary_emb(hidden_states, position_ids) + + assert cos.shape == (5, 8) + assert sin.shape == (5, 8) + + +def test_chunked_mrope_matches_legacy_selection(): + rotary_emb = build_rotary_embedding_from_config(_make_config()) + hidden_states = torch.empty(5, 16) + position_ids = torch.stack([ + torch.arange(5), + torch.arange(10, 15), + torch.arange(20, 25), + ]) + + cos, sin = rotary_emb(hidden_states, position_ids) + base_cos, base_sin = rotary_emb.impl(hidden_states, position_ids) + mrope_section = [2, 3, 3] * 2 + expected_cos = torch.cat([m[i % 3] for i, m in enumerate(base_cos.split(mrope_section, dim=-1))], dim=-1) + expected_sin = torch.cat([m[i % 3] for i, m in enumerate(base_sin.split(mrope_section, dim=-1))], dim=-1) + + assert rotary_emb._uses_static_inv_freq_rope() + torch.testing.assert_close(cos, expected_cos, rtol=0, atol=1e-7) + torch.testing.assert_close(sin, expected_sin, rtol=0, atol=1e-7) + + +def test_interleaved_mrope_matches_qwen3_selection(): + rotary_emb = build_rotary_embedding_from_config(_make_config(mrope_interleaved=True)) + hidden_states = torch.empty(5, 16) + position_ids = torch.stack([ + torch.arange(5), + torch.arange(10, 15), + torch.arange(20, 25), + ]) + + cos, sin = rotary_emb(hidden_states, position_ids) + base_cos, base_sin = rotary_emb.impl(hidden_states, position_ids) + + def apply_interleaved(freqs): + half_dim = freqs.size(-1) // 2 + out = freqs[0].clone() + for dim, offset in enumerate((1, 2), start=1): + half_dim = freqs.size(-1) // 2 + length = min([2, 3, 3][dim] * 3, half_dim) + out[..., offset:length:3] = freqs[dim, ..., offset:length:3] + out[..., half_dim + offset:half_dim + length:3] = \ + freqs[dim, ..., half_dim + offset:half_dim + length:3] + return out + + torch.testing.assert_close(cos, apply_interleaved(base_cos)) + torch.testing.assert_close(sin, apply_interleaved(base_sin)) + + +def test_interleaved_mrope_matches_legacy_qwen3_formula_tightly(): + rotary_emb = build_rotary_embedding_from_config(_make_config(mrope_interleaved=True)) + hidden_states = torch.empty(1, 17, 16) + position_ids = torch.stack([ + torch.arange(17), + torch.arange(101, 118), + torch.arange(1001, 1018), + ]).unsqueeze(1) + + cos, sin = rotary_emb(hidden_states, position_ids) + inv_freq = rotary_emb.impl.inv_freq + inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs_t = freqs[0].clone() + for dim, offset in enumerate((1, 2), start=1): + length = [2, 3, 3][dim] * 3 + freqs_t[..., offset:length:3] = freqs[dim, ..., offset:length:3] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + expected_cos = emb.cos() + expected_sin = emb.sin() + + assert (cos - expected_cos).abs().max().item() <= 2e-7 + assert (sin - expected_sin).abs().max().item() <= 2e-7 + + +def test_interleaved_mrope_yarn_long_context_override_uses_backend_path(): + config = PretrainedConfig( + hidden_size=2048, + num_attention_heads=8, + head_dim=256, + max_position_embeddings=512000, + rope_theta=10000000, + rope_parameters=dict( + mrope_interleaved=True, + mrope_section=[11, 11, 10], + rope_type='yarn', + rope_theta=10000000, + partial_rotary_factor=0.25, + factor=4.0, + original_max_position_embeddings=262144, + ), + ) + rotary_emb = build_rotary_embedding_from_config(config) + position_ids = torch.tensor([ + [0, 1, 1024, 262143, 262144, 400000, 511998, 511999], + [3, 5, 2048, 262140, 262150, 399900, 510000, 511999], + [7, 11, 4096, 262130, 262160, 399800, 509000, 511999], + ]).unsqueeze(1) + hidden_states = torch.empty(1, position_ids.shape[-1], config.head_dim) + + cos, sin = rotary_emb(hidden_states, position_ids) + leading_shape = position_ids.shape[:-1] + base_cos, base_sin = rotary_emb.impl(hidden_states, position_ids.flatten(0, -2)) + base_cos = base_cos.reshape(*leading_shape, *base_cos.shape[1:]) + base_sin = base_sin.reshape(*leading_shape, *base_sin.shape[1:]) + + def apply_interleaved_reference(freqs): + half_dim = freqs.size(-1) // 2 + out = freqs[0].clone() + for dim, offset in enumerate((1, 2), start=1): + length = min(config.rope_parameters['mrope_section'][dim] * 3, half_dim) + out[..., offset:length:3] = freqs[dim, ..., offset:length:3] + out[..., half_dim + offset:half_dim + length:3] = \ + freqs[dim, ..., half_dim + offset:half_dim + length:3] + return out + + assert not rotary_emb._uses_static_inv_freq_rope() + assert cos.shape == (1, position_ids.shape[-1], 64) + assert sin.shape == (1, position_ids.shape[-1], 64) + torch.testing.assert_close(cos, apply_interleaved_reference(base_cos), rtol=0, atol=0) + torch.testing.assert_close(sin, apply_interleaved_reference(base_sin), rtol=0, atol=0) + + +def test_mrope_config_keeps_text_positions_as_regular_rope(): + rotary_emb = build_rotary_embedding_from_config(_make_config(mrope_interleaved=True)) + hidden_states = torch.empty(5, 16) + position_ids = torch.arange(5).unsqueeze(0) + + cos, sin = rotary_emb(hidden_states, position_ids) + expected_cos, expected_sin = rotary_emb.impl(hidden_states, position_ids) + + torch.testing.assert_close(cos, expected_cos) + torch.testing.assert_close(sin, expected_sin)