Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ all = [
"liger-kernel",
"parametrize",
"mathruler",
"pylatexenc"
"pylatexenc",
"flash-linear-attention"
]

[tool.mypy]
Expand Down
223 changes: 223 additions & 0 deletions tests/model/test_qwen3_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import os
import unittest
import parametrize
import torch
from packaging.version import Version
from transformers import __version__ as transformers_version
from xtuner._testing import DeterministicDDPTestCase
from transformers import AutoTokenizer
import torch.distributed as dist
from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.model.moe.moe import SequenceContext
from xtuner.v1.utils.test_utils import init_data_mesh
from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
from xtuner.v1.config import FSDPConfig
from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh


# Root directory passed as ``media_root`` when tokenizing the video sample.
# Read at import time, so a missing VIDEO_ROOT variable raises KeyError as
# soon as this module is imported.
VIDEO_ROOT = os.environ["VIDEO_ROOT"]

@unittest.skipIf(
    Version(transformers_version) < Version("5.2.0"),
    f"transformers >= 5.2.0 is required, but got {transformers_version}"
)
class TestQwen3_5_VL(DeterministicDDPTestCase):
    """Loss-parity test for Qwen3.5-VL MoE: HF vs. XTuner vs. XTuner + FSDP.

    The checkpoint directory is read from the ``QWEN3_5_MOE_PATH`` environment
    variable; video frames are resolved against the module-level ``VIDEO_ROOT``.
    """

    @staticmethod
    def _image_sample():
        """Return a fresh two-image conversation sample (new dict on every call)."""
        return {
            "id": 3,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "tests/resource/mscoco_twocat_000000039769.jpg",
                                "image_wh": [640, 480],
                            },
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "tests/resource/mscoco_dog_000000319154.jpg",
                                "image_wh": [375, 500],
                            },
                        },
                        {
                            "type": "text",
                            "text": "<IMG_CONTEXT>\n<IMG_CONTEXT>\n请描述下第二幅图片中的狗是什么颜色?",
                        },
                    ],
                },
                {"role": "assistant", "content": "图片中的狗是棕色的。"},
            ],
        }

    @staticmethod
    def _video_sample():
        """Return a fresh two-video conversation sample (new dict on every call)."""
        return {
            "id": 9,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "video_url",
                            "video_url": {
                                "url": "tennis_frames_4fps/",
                                "image_wh": [1280, 720],
                                "origin_video_length": 182,
                                "origin_fps": 30.0,
                                "processed_video_length": 23,
                                "processed_fps": 4,
                            },
                        },
                        {
                            "type": "video_url",
                            "video_url": {
                                "url": "tennis_frames_2fps/",
                                "image_wh": [1280, 720],
                                "origin_video_length": 182,
                                "origin_fps": 30.0,
                                "processed_video_length": 13,
                                "processed_fps": 2,
                            },
                        },
                        {
                            "type": "text",
                            "text": "<VIDEO_CONTEXT><VIDEO_CONTEXT>两个视频中都在做什么?",
                        },
                    ],
                },
                {"role": "assistant", "content": "打网球"},
            ],
        }

    @staticmethod
    def _tokenized_to_cuda(tokenized_data):
        """Move the fields of a tokenized multimodal sample onto the GPU.

        Returns ``(input_ids, labels, pixel_values, grid_thw, position_ids)``;
        ``input_ids`` and ``labels`` get a leading batch dimension of 1.
        """
        input_ids = torch.tensor(tokenized_data['input_ids'])[None].cuda()
        labels = torch.tensor(tokenized_data['labels'])[None].cuda()
        pixel_values = tokenized_data['pixel_values'].cuda()
        grid_thw = tokenized_data['image_grid_thw'].cuda()
        position_ids = tokenized_data['position_ids'].cuda()
        return input_ids, labels, pixel_values, grid_thw, position_ids

    def _build_inputs(self, type, device, sp_size):
        """Build model inputs for one modality.

        Args:
            type: ``'image'``, ``'video'``; any other value selects the
                text-only prompt.
            device: Device the text-only ``input_ids`` are moved to.
            sp_size: Sequence-parallel size; when it is not 1 the text prompt
                embeds the caller's rank.

        Returns:
            ``(input_ids, labels, pixel_values, grid_thw, position_ids)``; the
            last three are ``None`` for the text-only case.
        """
        model_path = os.environ["QWEN3_5_MOE_PATH"]
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        if type == 'image':
            tokenize_fn = Qwen3VLTokenizeFnConfig(
                processor_path=model_path, add_vision_id=True
            ).build(tokenizer)
            return self._tokenized_to_cuda(tokenize_fn(self._image_sample()))

        if type == 'video':
            tokenize_fn = Qwen3VLTokenizeFnConfig(
                processor_path=model_path, rand_video_max_frames=14, add_vision_id=True
            ).build(tokenizer)
            return self._tokenized_to_cuda(tokenize_fn(self._video_sample(), media_root=VIDEO_ROOT))

        # Text-only prompt: with sp_size == 1 every rank uses "1+1"; otherwise
        # the rank number is substituted so each rank gets a different prompt.
        rank = dist.get_rank()
        operand = 1 if sp_size == 1 else rank
        input_ids = tokenizer(
            f"今天天气不错,是学习的好日子。请听题: 1+{operand} 等于多少?",
            return_tensors="pt",
        ).input_ids.to(device)
        return input_ids, input_ids.clone(), None, None, None

    def _forward(self, model, type, device, sp_size):
        """Run one no-grad forward pass and return the loss.

        Args:
            model: Either the HF ``Qwen3_5MoeForConditionalGeneration`` or the
                XTuner model built from ``Qwen3_5_VLMoE35BA3Config``.
            type: Modality selector, see ``_build_inputs``. (The name shadows
                the builtin but is kept because callers pass it by keyword.)
            device: Device for the text-only inputs.
            sp_size: Sequence-parallel size; values above 1 split the
                sequence context over the ``"sp"`` mesh.

        Returns:
            The loss produced by the model for the selected sample.
        """
        input_ids, labels, pixel_values, image_grid_thw, position_ids = self._build_inputs(
            type, device, sp_size
        )

        # Imported lazily: the class is only defined in recent transformers
        # releases (see the skipIf guard on this test case).
        from transformers import Qwen3_5MoeForConditionalGeneration

        if isinstance(model, Qwen3_5MoeForConditionalGeneration):
            # The HF model takes video inputs under different keyword names.
            if type == 'video':
                vision_kwargs = {"pixel_values_videos": pixel_values, "video_grid_thw": image_grid_thw}
            else:
                vision_kwargs = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
            with torch.no_grad():
                output = model(
                    input_ids=input_ids,
                    labels=labels,
                    position_ids=position_ids,
                    use_cache=False,
                    **vision_kwargs,
                )
            return output.loss

        loss_cfg = CELossConfig()

        # Next-token setup: drop the last input token and the first label.
        shift_input_ids = input_ids[:, :-1]
        shifted_labels = labels[:, 1:]
        if position_ids is not None:
            position_ids = position_ids[..., :-1]

        sp_mesh = None
        if sp_size > 1:
            sp_mesh = init_data_mesh(device, sp_size=sp_size)["sp"]

        seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),))
        seq_ctx.image_grid_thw = image_grid_thw
        seq_ctx.pixel_values = pixel_values
        if position_ids is not None:
            seq_ctx.position_ids = position_ids
        seq_ctx.to('cuda')
        if sp_size > 1:
            seq_ctx = seq_ctx.split(sp_mesh)

        loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
        loss_ctx = loss_cfg.loss_ctx_cls.build_batches([loss_ctx])[0]

        with torch.no_grad():
            output = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
        return output["loss"]

    @staticmethod
    def _build_xtuner_model(model_path, fsdp):
        """Build the XTuner model on the meta device, load HF weights, set eval mode.

        When ``fsdp`` is true the vision tower is given the world mesh and
        FSDP config, and the model is sharded before the weights are loaded.
        """
        with torch.device("meta"):
            model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
            model = model_cfg.build().to(torch.bfloat16)

        if fsdp:
            fsdp_config = FSDPConfig(cpu_offload=False)
            fsdp_mesh = init_world_mesh()
            model.vision_tower.fsdp_mesh = fsdp_mesh
            model.vision_tower.fsdp_config = fsdp_config
            model.fully_shard(fsdp_config=fsdp_config)

        model.from_hf(model_path)
        model.eval()
        return model

    def _assert_loss_close(self, actual, expected, tol, label):
        """Assert two losses agree within ``tol`` (used as both atol and rtol)."""
        expected = expected.to(actual.dtype)
        self.assertTrue(
            torch.allclose(actual, expected, atol=tol, rtol=tol),
            msg=f"{label} loss mismatch: {actual} vs {expected} (tol={tol})",
        )

    @parametrize.parametrize(
        "device,sp_size,tol",
        [
            ("cuda", 1, 1e-2),
        ],
    )
    def test_qwen3_5_vl_run(self, device, sp_size, tol):
        """Check HF vs. XTuner loss parity, then XTuner vs. XTuner + FSDP."""
        self.create_pg(device)

        from transformers import Qwen3_5MoeForConditionalGeneration

        model_path = os.environ["QWEN3_5_MOE_PATH"]

        hf_model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
            model_path,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="cuda",
            trust_remote_code=True,
        ).eval()
        # NOTE: the root cause is not understood, but without this barrier
        # multi-GPU runs fail with a CUDA illegal memory access error.
        torch.distributed.barrier()

        loss_hf_text = self._forward(hf_model, type='text', device=device, sp_size=sp_size)
        loss_hf_image = self._forward(hf_model, type='image', device=device, sp_size=sp_size)
        # NOTE: no HF reference is computed for video; the video loss is only
        # checked for consistency between the plain and FSDP XTuner models.

        # Free the HF model before building the XTuner one.
        del hf_model
        torch.cuda.empty_cache()

        xtuner_model = self._build_xtuner_model(model_path, fsdp=False)
        loss_xtuner_text = self._forward(xtuner_model, type='text', device=device, sp_size=sp_size)
        loss_xtuner_image = self._forward(xtuner_model, type='image', device=device, sp_size=sp_size)
        loss_xtuner_video = self._forward(xtuner_model, type='video', device=device, sp_size=sp_size)

        self._assert_loss_close(loss_xtuner_text, loss_hf_text, tol, "xtuner-vs-hf text")
        self._assert_loss_close(loss_xtuner_image, loss_hf_image, tol, "xtuner-vs-hf image")

        del xtuner_model
        torch.cuda.empty_cache()

        # Same checks with the model sharded by FSDP.
        fsdp_model = self._build_xtuner_model(model_path, fsdp=True)
        loss_xtuner_text_fsdp = self._forward(fsdp_model, type='text', device=device, sp_size=sp_size)
        loss_xtuner_image_fsdp = self._forward(fsdp_model, type='image', device=device, sp_size=sp_size)
        loss_xtuner_video_fsdp = self._forward(fsdp_model, type='video', device=device, sp_size=sp_size)

        self._assert_loss_close(loss_xtuner_text_fsdp, loss_xtuner_text, tol, "fsdp-vs-xtuner text")
        self._assert_loss_close(loss_xtuner_image_fsdp, loss_xtuner_image, tol, "fsdp-vs-xtuner image")
        self._assert_loss_close(loss_xtuner_video_fsdp, loss_xtuner_video, tol, "fsdp-vs-xtuner video")

    @property
    def world_size(self) -> int:
        """Number of processes for the test; ``XTUNER_TEST_WORLD_SIZE`` or 4."""
        return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "4"))
16 changes: 14 additions & 2 deletions xtuner/_testing/patch_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@


def patch_hf_rms_norm(module: nn.Module) -> None:
    """Replace every ``*RMSNorm*`` submodule of ``module`` with ``RMSNorm``.

    A submodule is selected when its class name contains ``"RMSNorm"``. For
    each one a new ``RMSNorm`` is created from the old weight's shape and
    ``variance_epsilon``, moved to the old weight's device, loaded with the
    old state dict, and then assigned over the old attribute on its parent.

    Replacements are collected first and applied afterwards so the module
    tree is not mutated while ``named_modules()`` is iterating over it.

    Args:
        module: Root module whose descendants are patched in place. The root
            itself is never replaced: it has no parent attribute to rebind,
            so a matching root is skipped.
    """
    replacements = []
    for name, submodule in module.named_modules():
        if "RMSNorm" not in submodule.__class__.__name__:
            continue
        if not name:
            # The root has the empty name; setattr(module, "", ...) would
            # register a bogus child instead of replacing anything.
            continue

        weight = submodule.weight
        new_submodule = RMSNorm(hidden_size=weight.shape, eps=submodule.variance_epsilon).to(weight.device)
        new_submodule.load_state_dict(submodule.state_dict())

        # "a.b.norm" -> parent path "a.b", attribute "norm"; an empty parent
        # path resolves to ``module`` itself.
        parent_path, _, attr_name = name.rpartition(".")
        replacements.append((module.get_submodule(parent_path), attr_name, new_submodule))

    for parent, attr_name, new_submodule in replacements:
        setattr(parent, attr_name, new_submodule)
3 changes: 1 addition & 2 deletions xtuner/v1/data_proto/messages/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from pydantic import BaseModel, ConfigDict

from transformers import PreTrainedTokenizer
from xtuner.utils import IGNORE_INDEX
from xtuner.v1.data_proto.messages.base import BaseMessages
from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate
from xtuner.v1.utils import get_logger
from xtuner.v1.utils import IGNORE_INDEX, get_logger


logger = get_logger()
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/data_proto/sequence_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SequenceContext:
block_table: torch.Tensor | None
device: str | torch.device # TODO: 这个地方有点乱,到处是 device
position_ids: torch.LongTensor | None
seq_idx: torch.IntTensor | None

# Qwen3VL
image_grid_thw: torch.Tensor | None
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
self.inputs_embeds = inputs_embeds
self.num_img_tokens = num_img_tokens
self.rollout_routed_experts = rollout_routed_experts
self.seq_idx = None

seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]
seq_lens_q = self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1]
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/datasets/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def qwen3_vl_sft_collator(
if len(position_ids_list) > 0:
position_ids = torch.cat(position_ids_list, dim=-1)
position_ids = position_ids[:, :, :-1]
if pack_to_max_length and pack_max_length - position_ids.shape[-1] > 0:
position_ids = pad_to_max_length(position_ids, 0, max_length=pack_max_length, dim=-1)
Comment thread
hhaAndroid marked this conversation as resolved.

num_img_tokens: list[int] = []
for data in instance:
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
InternVL3P5MoE30BA3Config,
InternVLBaseConfig,
)
from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config
from .compose.qwen3_vl import (
Qwen3VLDense4BConfig,
Qwen3VLDense8BConfig,
Expand Down Expand Up @@ -98,4 +99,5 @@ def get_model_config_from_hf(model_path: Path):
"TorchCompileOption",
"DEFAULT_FLOAT8_CFG",
"XTunerBaseModelConfig",
"Qwen3_5_VLMoE35BA3Config",
]
6 changes: 4 additions & 2 deletions xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
WeightWithDynamicTilewiseFloat8CastTensor,
)
from xtuner.v1.loss import BaseLossContext
from xtuner.v1.module.attention import MHAConfig, MLAConfig
from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig, MLAConfig
from xtuner.v1.module.rope import RopeScalingConfig
from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather
from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory
Expand Down Expand Up @@ -147,9 +147,11 @@ class TransformerConfig(XTunerBaseModelConfig):
hidden_size: Annotated[int, Parameter(group="model")]
intermediate_size: Annotated[int, Parameter(group="model")]
rms_norm_eps: Annotated[float, Parameter(group="model")]
rms_norm_type: Annotated[Literal["default", "zero_centered"], Parameter(group="model")] = "default"
rope_theta: Annotated[float, Parameter(group="model")] # required by transformers's build rope
hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS`
attention: MLAConfig | MHAConfig
linear_attention: Annotated[GatedDeltaNetConfig | None, Parameter(group="model")] = None
mlp_bias: Annotated[bool, Parameter(group="model")] = False
tie_word_embeddings: Annotated[bool, Parameter(group="model")] = False
model_type: Annotated[str | None, Parameter(group="model")] = None # TODO: yehaochen maybe should be removed
Expand Down Expand Up @@ -182,7 +184,7 @@ def head_dim(self) -> int:
return self.attention.head_dim

@computed_field
def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
def layers_type(self) -> list[Literal["full_attention", "sliding_attention", "linear_attention"]]:
if not self.use_sliding_window:
return ["full_attention"] * self.num_hidden_layers
else:
Expand Down
6 changes: 6 additions & 0 deletions xtuner/v1/model/compose/qwen3_5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Package entry point: re-exports the Qwen3.5 VL MoE config from
# ``qwen3_5_config`` and declares it as the package's only public name.
from .qwen3_5_config import Qwen3_5_VLMoE35BA3Config


__all__ = [
"Qwen3_5_VLMoE35BA3Config",
]
33 changes: 33 additions & 0 deletions xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from xtuner.v1.model.base import TransformerConfig
from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
from xtuner.v1.utils import get_logger

from ..qwen3_vl.qwen3_vl_config import Qwen3VLBaseConfig, Qwen3VLProjectorConfig, Qwen3VLVisionConfig


logger = get_logger()


# Vision config for Qwen3.5: identical to the Qwen3-VL vision config except
# that ``deepstack_visual_indexes`` defaults to an empty list.
class Qwen3_5_VisionConfig(Qwen3VLVisionConfig):
    deepstack_visual_indexes: list[int] = []


# Projector config for Qwen3.5: identical to the Qwen3-VL projector config
# except that ``deepstack_visual_indexes`` defaults to an empty list.
class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
    deepstack_visual_indexes: list[int] = []


# Base composite config for Qwen3.5 VL models. It narrows the vision and
# projector sub-configs to their Qwen3.5 variants and overrides the default
# special-token ids declared below.
class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
    # Sub-configs (required here; concrete subclasses supply defaults).
    vision_config: Qwen3_5_VisionConfig
    projector_config: Qwen3_5_ProjectorConfig
    text_config: TransformerConfig

    # Default special-token ids for image/video placeholders and the
    # vision span delimiters.
    image_token_id: int = 248056
    video_token_id: int = 248057
    vision_start_token_id: int = 248053
    vision_end_token_id: int = 248054


# Concrete Qwen3.5 VL MoE config: every sub-config defaults to an instance
# built with that sub-config's own defaults; the text side uses
# ``Qwen3_5_VLTextMoE35BA3BConfig``.
class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
    vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
    projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
    text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
Loading
Loading