Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/backends/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ class FopeParameters:
inv_freq: torch.Tensor = None


@dataclass
class MropeParameters:
    """Configuration container for multimodal RoPE (MRoPE)."""

    # Per-axis section sizes; required, no default.
    # NOTE(review): presumably one entry per position axis -- confirm with callers.
    mrope_section: list[int]
    # Flag selecting the interleaved layout; off unless explicitly requested.
    mrope_interleaved: bool = False


class RotaryEmbeddingImpl(ABC):
"""Rotary embedding implementation api."""

Expand Down
28 changes: 2 additions & 26 deletions lmdeploy/pytorch/models/glm4_1v.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py

from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from typing import Any

import numpy as np
Expand All @@ -23,35 +23,12 @@
from .utils.model import DeployModelMixin, vlm_model


def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: list[int],
position_ids: torch.Tensor, rotary_emb_func: Callable):
_mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device)
_mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids
cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids)
_cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device)
_sin = torch.zeros_like(_cos)
mrope_section = mrope_section * 2

def _apply_split(src, dst):
start = 0
for i, m in enumerate(src.split(mrope_section, dim=-1)):
dst[:, start:start + mrope_section[i]] = m[i % 3]
start += mrope_section[i]

_apply_split(cos, _cos)
_apply_split(sin, _sin)

return _cos, _sin


class Glm4vTextModel(nn.Module):

def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.mrope_section = config.rope_scaling['mrope_section']

self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
self.padding_idx,
Expand Down Expand Up @@ -92,8 +69,7 @@ def forward(
cos, sin = self.rotary_emb(hidden_states, position_ids)
cos, sin = cos[0], sin[0]
else:
cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section, position_ids,
self.rotary_emb)
cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)
rotary_pos_emb = (cos, sin)

# decoding
Expand Down
28 changes: 2 additions & 26 deletions lmdeploy/pytorch/models/qwen2_vl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from typing import Any

import numpy as np
Expand Down Expand Up @@ -32,27 +32,6 @@
from .utils.model import DeployModelMixinV1, build_embedding, vlm_model


def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: list[int],
position_ids: torch.Tensor, rotary_emb_func: Callable):
_mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device)
_mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids
cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids)
_cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device)
_sin = torch.zeros_like(_cos)
mrope_section = mrope_section * 2

def _apply_split(src, dst):
start = 0
for i, m in enumerate(src.split(mrope_section, dim=-1)):
dst[:, start:start + mrope_section[i]] = m[i % 3]
start += mrope_section[i]

_apply_split(cos, _cos)
_apply_split(sin, _sin)

return _cos, _sin


class Qwen2Attention(nn.Module):
"""Rewrite module of Qwen2Attention."""

Expand Down Expand Up @@ -246,8 +225,6 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.mrope_section = config.rope_scaling['mrope_section']

self.embed_tokens = build_embedding(
config.vocab_size,
config.hidden_size,
Expand Down Expand Up @@ -290,8 +267,7 @@ def forward(
cos, sin = self.rotary_emb(hidden_states, position_ids)
cos, sin = cos[0], sin[0]
else:
cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section, position_ids,
self.rotary_emb)
cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)
rotary_pos_emb = (cos, sin)

# decoding
Expand Down
106 changes: 9 additions & 97 deletions lmdeploy/pytorch/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

import lmdeploy.pytorch.nn.gated_delta as gated_delta_util
from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, SiluAndMul
from lmdeploy.pytorch.nn import (
ApplyRotaryEmb,
Attention,
LayerNorm,
RMSNorm,
SiluAndMul,
build_rotary_embedding_from_config,
)
from lmdeploy.pytorch.nn.gated_delta import CausalConv1d, GatedDelta, GatedDeltaMeta, build_rmsnorm_gated
from lmdeploy.pytorch.nn.linear import (
build_colwise_linear,
Expand All @@ -24,7 +30,6 @@
build_qkv_proj,
build_rowwise_linear,
)
from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight
from lmdeploy.vl.constants import Modality

Expand Down Expand Up @@ -805,99 +810,6 @@ def forward(
return outputs


class Qwen3_5TextRotaryEmbedding(nn.Module):
    """Rotary embedding for the Qwen3.5 text model using interleaved MRoPE."""

    inv_freq: torch.Tensor  # declared so linters accept the `register_buffer` attribute

    def __init__(self, config: PretrainedConfig, device=None):
        """Read the RoPE parameters from ``config`` and register ``inv_freq``.

        Args:
            config: Model configuration; must yield RoPE parameters through
                ``get_rope_parameters``.
            device: Device passed to the RoPE init function.
        """
        super().__init__()
        rope_scaling = get_rope_parameters(config)
        assert rope_scaling is not None, 'RoPE scaling parameters must be provided in the config for Qwen3.5 models.'

        self.config = config
        self.rope_type = rope_scaling.get('rope_type', 'default')
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.mrope_section = rope_scaling.get('mrope_section', [11, 11, 10])

        # The 'default' type uses the local initializer; every other type is
        # looked up in the transformers registry.
        if self.rope_type == 'default':
            self.rope_init_fn = self.compute_default_rope_parameters
        else:
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is not written to the state dict.
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: PretrainedConfig | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple['torch.Tensor', float]:
        """Compute the inverse frequencies of the original RoPE formulation.

        Args:
            config: The model configuration.
            device: Device on which the inverse frequencies are placed.
            seq_len: Current sequence length; unused for this RoPE type.

        Returns:
            Tuple of the inverse-frequency tensor and the cos/sin scaling
            factor, which is always ``1.0`` here.
        """
        rope_parameters = get_rope_parameters(config)
        base = rope_parameters['rope_theta']
        partial_rotary_factor = rope_parameters.get('partial_rotary_factor', 1.0)

        # Fall back to hidden_size / num_heads when head_dim is absent or falsy.
        head_dim = getattr(config, 'head_dim', None)
        if not head_dim:
            head_dim = config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)
        return inv_freq.to(device=device), 1.0

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.

        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THTHWHTHW...TT], preserving frequency continuity.

        Args:
            freqs: (3, bs, seq_len, head_dim // 2); row 0 is overwritten
                in place.
            mrope_section: (3,)

        Returns:
            (bs, seq_len, head_dim // 2) tensor, a view of ``freqs[0]``.
        """
        merged = freqs[0]  # a view: writes below also modify `freqs`
        for axis in (1, 2):  # H, W
            # Every third column starting at `axis`, up to 3 * section size.
            cols = slice(axis, mrope_section[axis] * 3, 3)
            merged[..., cols] = freqs[axis, ..., cols]
        return merged

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids`` in the dtype of ``x``.

        Args:
            x: Tensor whose dtype is used for the returned cos/sin.
            position_ids: Position ids. A 2-D tensor is broadcast to a
                leading axis of size 3; otherwise it is used as given.
        """
        # Position ids differ per grid axis, so everything carries a leading
        # axis of size 3.
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        batch = position_ids.shape[1]
        inv_freq = self.inv_freq[None, None, :, None].float().expand(3, batch, -1, 1)
        positions = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        freqs = torch.matmul(inv_freq, positions).transpose(2, 3)
        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
        emb = torch.cat([freqs, freqs], dim=-1)

        scale = self.attention_scaling
        cos = emb.cos() * scale
        sin = emb.sin() * scale
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Qwen3_5TextModel(nn.Module):
"""qwen3.5 text model."""

Expand Down Expand Up @@ -932,7 +844,7 @@ def __init__(self,
self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

# build rotary embedding
self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)
self.rotary_emb = build_rotary_embedding_from_config(config, device=device)

def forward(
self,
Expand Down
5 changes: 2 additions & 3 deletions lmdeploy/pytorch/models/qwen3_5_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.model_inputs import StepContextManager
from lmdeploy.pytorch.nn import RMSNorm
from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

Expand All @@ -23,7 +23,6 @@
Qwen3_5MLP,
Qwen3_5Model,
Qwen3_5TextModel,
Qwen3_5TextRotaryEmbedding,
)
from .qwen3_5 import Qwen3_5VisionModel as Qwen3_5MoeVisionModel
from .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5MoeInputProcessor
Expand Down Expand Up @@ -212,7 +211,7 @@ def __init__(self,
self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

# build rotary embedding
self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)
self.rotary_emb = build_rotary_embedding_from_config(config, device=device)


class Qwen3_5MoeModel(Qwen3_5Model):
Expand Down
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/models/qwen3_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import RMSNorm
from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_colwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .patch import add_prefix, get_build_model_context
from .qwen3_5 import Qwen3_5Attention, Qwen3_5DecoderLayer, Qwen3_5MLP, Qwen3_5TextRotaryEmbedding
from .qwen3_5 import Qwen3_5Attention, Qwen3_5DecoderLayer, Qwen3_5MLP
from .qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin

Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(
)

# build rotary embedding
self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)
self.rotary_emb = build_rotary_embedding_from_config(config, device=device)

def get_input_embeddings(self):
"""Get input embeddings."""
Expand Down
Loading
Loading