102 changes: 85 additions & 17 deletions tests/model/test_qwen3_5.py
@@ -14,6 +14,8 @@
from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
from xtuner.v1.config import FSDPConfig
from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh
from xtuner.v1.data_proto.utils import pad_to_multiple_of


import tempfile
from pathlib import Path
@@ -82,13 +84,28 @@ def _forward(self, model, type, device, sp_size):
position_ids = tokenized_data['position_ids'].cuda()
else:
tokenizer = AutoTokenizer.from_pretrained(QWEN3_VL_MOE_PATH)
input_ids = tokenizer(f"今天天气不错,是学习的好日子。请听题: 1+1 等于多少?",
return_tensors="pt").input_ids.to(device)
labels = input_ids.clone()
tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_MOE_PATH, rand_video_max_frames=14,
add_vision_id=True).build(tokenizer)
raw_data = {
"id": 3, "messages": [
{
"role": "user", "content": [
{
"type": "text",
"text": "Translate this into chinese: Where my eyes gaze, only memories remain; where my heart strays, only yesterday's pain; where my sight stays, only regret's refrain."
}
]
},
{"role": "assistant", "content": "目之所及,唯余旧忆;心之所向,唯余昨痛;眸之所留,唯余悔声。"}
]
}
tokenized_data = tokenize_fn(raw_data)
input_ids = torch.tensor(tokenized_data['input_ids'])[None].cuda()
labels = torch.tensor(tokenized_data['labels'])[None].cuda()
pixel_values = None
image_grid_thw = None
position_ids = None

from transformers import Qwen3_5MoeForConditionalGeneration
is_hf_model = isinstance(model, Qwen3_5MoeForConditionalGeneration)

@@ -115,10 +132,8 @@ def _forward(self, model, type, device, sp_size):
dist.all_reduce(output.loss.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return output.loss
else:
loss_cfg = CELossConfig()

shift_input_ids = input_ids[:, :-1]
shifted_labels = labels[:, 1:]
shift_input_ids = pad_to_multiple_of(input_ids[:, :-1], padding_value=0, multiple_of=sp_size)
shifted_labels = pad_to_multiple_of(labels[:, 1:], padding_value=-100, multiple_of=sp_size)
if position_ids is not None:
position_ids = position_ids[..., :-1]

@@ -127,22 +142,18 @@ def _forward(self, model, type, device, sp_size):
data_mesh = init_data_mesh(device, sp_size=sp_size)
sp_mesh = data_mesh["sp"]

seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),))
seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to("cuda"),))
seq_ctx.image_grid_thw = image_grid_thw
seq_ctx.pixel_values = pixel_values
if position_ids is not None:
seq_ctx.position_ids = position_ids
seq_ctx.to('cuda')
seq_ctx.to("cuda")
if sp_size > 1:
seq_ctx = seq_ctx.split(sp_mesh)

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
data_batch = [{"seq_ctx": seq_ctx, "shifted_labels": shifted_labels}]
loss_ctx_batch = model.build_loss_ctx_batch(data_batch, sp_mesh=sp_mesh)
loss_ctx = loss_ctx_batch[0]

with torch.no_grad():
output = model(
@@ -221,6 +232,63 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol):
self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol))
self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol))

@parametrize.parametrize(
"device,sp_size,tol",
[
("cuda", 1, 1e-2),
("cuda", 4, 1e-2),
],
)
def test_qwen3_5_vl_run_mtp(self, device, sp_size, tol):
self.create_pg(device)
loss_reference = {
"text": 1.5416,
"image": 3.6920,
"video": 8.2165,
}

QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"]

torch.cuda.empty_cache()

with torch.device("meta"):
model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
model_cfg.text_config.mtp_config = MTPConfig(num_layers=1, loss_scaling_factor=1)
qwen3vl_model = model_cfg.build().to(torch.bfloat16)

qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
qwen3vl_model.eval()

losses = {}

loss_xtuner_text = self._forward(qwen3vl_model, type="text", device=device, sp_size=sp_size)
self.assertFalse(torch.isnan(loss_xtuner_text), "MTP text loss should not be NaN")

loss_xtuner_image = self._forward(qwen3vl_model, type="image", device=device, sp_size=sp_size)
self.assertFalse(torch.isnan(loss_xtuner_image), "MTP image loss should not be NaN")

loss_xtuner_video = self._forward(qwen3vl_model, type="video", device=device, sp_size=sp_size)
self.assertFalse(torch.isnan(loss_xtuner_video), "MTP video loss should not be NaN")

losses["text"] = loss_xtuner_text
losses["image"] = loss_xtuner_image
losses["video"] = loss_xtuner_video

for key, loss in losses.items():
self.assertTrue(
torch.allclose(
loss, torch.tensor(
loss_reference[key],
device=loss_xtuner_text.device,
dtype=loss_xtuner_text.dtype
),
atol=tol,
rtol=tol
),
f"Expected text loss around {key}, but got {loss.item()}"
)


@parametrize.parametrize(
"device,sp_size",
[
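Two things change in `_forward` above: the shifted inputs and labels are now padded to a multiple of `sp_size` with `pad_to_multiple_of` (so the sequence-parallel split yields equal shards, with label padding set to `-100` so it is ignored by the loss), and the loss context is built through `model.build_loss_ctx_batch(data_batch, sp_mesh=...)` instead of assembling a `CELossConfig` context by hand. Below is a minimal, self-contained sketch of the padding semantics the test relies on; `pad_to_multiple_of_demo` is a hypothetical stand-in written for illustration only, assuming the real helper pads along the last dimension.

```python
import torch


def pad_to_multiple_of_demo(x: torch.Tensor, padding_value: int, multiple_of: int, dim: int = -1) -> torch.Tensor:
    """Pad `x` along `dim` so its length becomes a multiple of `multiple_of`."""
    remainder = x.shape[dim] % multiple_of
    if remainder == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = multiple_of - remainder
    pad = torch.full(pad_shape, padding_value, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)


# Shifted labels of length 5 padded for sp_size=4: the three extra positions are
# -100, so they contribute nothing to the cross-entropy loss.
labels = torch.tensor([[5, 6, 7, 8, 9]])
padded = pad_to_multiple_of_demo(labels, padding_value=-100, multiple_of=4)
assert padded.tolist() == [[5, 6, 7, 8, 9, -100, -100, -100]]
```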
90 changes: 89 additions & 1 deletion xtuner/v1/data_proto/sequence_context.py
@@ -5,7 +5,7 @@
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import Self

from .utils import pad_to_multiple_of, split_for_sequence_parallel
from .utils import gather_for_sequence_parallel, pad_to_multiple_of, split_for_sequence_parallel


# Avoid using dataclass decorator here to get rid of extra ops called in pytorch 2.8 and above
@@ -50,6 +50,12 @@ class SequenceContext:
# moe routed_experts
rollout_routed_experts: torch.Tensor | None

# Private backing attributes for SP shard reconstruction
_raw_input_ids: torch.LongTensor | None
_raw_inputs_embeds: torch.FloatTensor | None
_shard_start: int
_shard_size: int

def __init__(
self,
input_ids: torch.LongTensor | None, # shape (1, seq_len)
@@ -71,6 +77,11 @@ def __init__(
inputs_embeds: torch.FloatTensor | None = None,
num_img_tokens: list[list[int]] | None = None,
rollout_routed_experts: torch.Tensor | None = None,
# SP shard metadata: private, accessed via properties below
raw_input_ids: torch.LongTensor | None = None,
raw_inputs_embeds: torch.FloatTensor | None = None,
shard_start: int = 0,
shard_size: int = 0,
):
# Only to distinguish parameters accepted by the constructor from attributes. For example, for `max_length_q`,
# the argument can be an int, but as an attribute it can only be a tensor
@@ -99,6 +110,10 @@ def __init__(
self.inputs_embeds = inputs_embeds
self.num_img_tokens = num_img_tokens
self.rollout_routed_experts = rollout_routed_experts
self._raw_input_ids = raw_input_ids
self._raw_inputs_embeds = raw_inputs_embeds
self._shard_start = shard_start
self._shard_size = shard_size
self.seq_idx = None

seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]
@@ -169,6 +184,7 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
start = sp_input_ids.shape[1] * sequence_parallel_mesh.get_local_rank()
end = start + sp_input_ids.shape[1]
sp_num_padding = max(0, min(sp_input_ids.shape[1], end - num_non_padding))
shard_size = sp_input_ids.shape[1]

if self.position_ids is not None:
pad_position_ids = pad_to_multiple_of(self.position_ids, 0, multiple_of, -1)
@@ -205,6 +221,9 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
inputs_embeds=self.inputs_embeds,
num_img_tokens=self.num_img_tokens,
rollout_routed_experts=self.rollout_routed_experts,
raw_input_ids=cast(torch.LongTensor, pad_input_ids),
shard_start=start,
shard_size=shard_size,
)
return sp_seq_ctx
else:
@@ -308,6 +327,71 @@ def seq_lens_q(self) -> torch.LongTensor:
def seq_lens_k(self) -> torch.LongTensor:
return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] # type: ignore

@property
def raw_input_ids(self) -> torch.LongTensor | None:
"""Full (un-split) input_ids across all SP ranks.

In non-SP mode, returns ``input_ids`` directly. In SP mode, returns the
pre-stored full tensor if available; otherwise triggers an allgather and
caches the result for subsequent calls.

Returns:
torch.LongTensor | None: The full input_ids tensor, or ``None`` if
``input_ids`` is ``None``.
"""
if self._raw_input_ids is not None:
return self._raw_input_ids
if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1:
return self.input_ids
assert self.input_ids is not None
gathered = gather_for_sequence_parallel(
self.input_ids, dim=1, sp_group=self.sequence_parallel_mesh.get_group()
)
self._raw_input_ids = cast(torch.LongTensor, gathered)
return self._raw_input_ids

@property
def raw_inputs_embeds(self) -> torch.FloatTensor | None:
"""Full (un-split) inputs_embeds across all SP ranks.

In non-SP mode, returns ``inputs_embeds`` directly. In SP mode, triggers
a single allgather on first access and caches the result for subsequent
calls, so the communication cost is paid at most once.

Returns:
torch.FloatTensor | None: The full inputs_embeds tensor, or ``None`` if
``inputs_embeds`` is ``None``.
"""
if self._raw_inputs_embeds is not None:
return self._raw_inputs_embeds
if self.inputs_embeds is None:
return None
if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1:
return self.inputs_embeds
gathered = gather_for_sequence_parallel(
self.inputs_embeds, dim=1, sp_group=self.sequence_parallel_mesh.get_group()
)
self._raw_inputs_embeds = cast(torch.FloatTensor, gathered)
return self._raw_inputs_embeds

@property
def raw_position_ids(self) -> torch.LongTensor | None:
"""Full (un-split) position_ids across all SP ranks.

Returns:
torch.LongTensor | None: The full position_ids tensor.
"""
raise NotImplementedError("raw_position_ids is not yet implemented")

@property
def raw_rollout_routed_experts(self) -> torch.Tensor | None:
"""Full (un-split) rollout_routed_experts across all SP ranks.

Returns:
torch.Tensor | None: The full rollout_routed_experts tensor.
"""
raise NotImplementedError("raw_rollout_routed_experts is not yet implemented")

    # TODO: Currently unused; may need to be removed
def chunk(self, num_chunks: int) -> list[Self]:
n = self.seq_lens_q.numel()
@@ -374,6 +458,10 @@ def copy(self, **overrides) -> Self:
inputs_embeds=overrides.get("inputs_embeds", self.inputs_embeds),
num_img_tokens=overrides.get("num_img_tokens", self.num_img_tokens),
rollout_routed_experts=overrides.get("rollout_routed_experts", self.rollout_routed_experts),
raw_input_ids=overrides.get("raw_input_ids", self._raw_input_ids),
raw_inputs_embeds=overrides.get("raw_inputs_embeds", self._raw_inputs_embeds),
shard_start=overrides.get("shard_start", self._shard_start),
shard_size=overrides.get("shard_size", self._shard_size),
)

def to(self, device: torch.device | str):
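The new `_shard_start`/`_shard_size` fields record which slice of the padded full sequence the local rank holds after `split()`, and the `raw_input_ids`/`raw_inputs_embeds` properties return the full tensor, either from the copy stored at split time or via a single allgather that is cached on first access. A plain-tensor sketch of the invariant these fields encode follows; it uses illustrative values and no SP mesh or collectives, so it is not the library code path itself.

```python
import torch

# Stand-in for the padded full sequence that `raw_input_ids` would return:
# 6 real tokens padded with 0 to a multiple of sp_size=4 (length 8).
raw = torch.tensor([[1, 2, 3, 4, 5, 6, 0, 0]])

sp_size = 4
shard_size = raw.shape[1] // sp_size  # 2 tokens per rank

for rank in range(sp_size):
    shard_start = shard_size * rank
    local_input_ids = raw[:, shard_start : shard_start + shard_size]
    # rank 0 -> [[1, 2]], rank 1 -> [[3, 4]], rank 2 -> [[5, 6]], rank 3 -> [[0, 0]]
    assert local_input_ids.shape[1] == shard_size

# The last rank holds only padding, which is what `sp_num_padding` tracks in split().
```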
20 changes: 19 additions & 1 deletion xtuner/v1/loss/mtp_loss.py
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.distributed.device_mesh import DeviceMesh

from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext
from xtuner.v1.module.mtp.utils import roll_packed_tensor
from xtuner.v1.utils.device import get_device


@@ -63,12 +63,30 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex
MTPLossContext | None: Built loss context, or ``None`` if
``shifted_labels`` is not present in ``data``.
"""
        # TODO: Move this shared util into a public package to avoid the circular import.
from xtuner.v1.module.mtp.utils import roll_packed_tensor

if "shifted_labels" not in data:
return None

shifted_labels = data["shifted_labels"]
cu_seq_lens = data["seq_ctx"].cu_seq_lens_k

# cu_seq_lens[-1] may be larger than shifted_labels.shape[-1] when seq_ctx
# was split for sequence parallelism (padding is added to make the sequence
# length a multiple of sp_size). Pad with -100 so roll_packed_tensor does
# not go out of bounds.
padded_len = int(cu_seq_lens[-1].item())
seq_len = shifted_labels.shape[-1]
if padded_len > seq_len:
pad = torch.full(
(*shifted_labels.shape[:-1], padded_len - seq_len),
fill_value=-100,
dtype=shifted_labels.dtype,
device=shifted_labels.device,
)
shifted_labels = torch.cat([shifted_labels, pad], dim=-1)

rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100)

loss_kwargs = MTPLossKwargs(shifted_labels=rolled).to(DEVICE)
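The change to `build` above guards against the sequence-parallel case where `cu_seq_lens_k[-1]` (the padded packed length) exceeds the length of `shifted_labels`: labels are first right-padded with `-100`, then each packed sub-sequence is rolled left by `mtp_depth` so every position is supervised by the token `mtp_depth` steps ahead. The sketch below walks through that label transform on toy data; `roll_packed_demo` is an illustrative emulation of what `roll_packed_tensor` is expected to do, not the library function itself.

```python
import torch


def roll_packed_demo(labels: torch.Tensor, cu_seq_lens: list[int], shift_left: int, fill: int = -100) -> torch.Tensor:
    """Roll each packed sub-sequence left by `shift_left`, filling vacated slots with `fill`."""
    out = torch.full_like(labels, fill)
    for start, end in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        if shift_left < end - start:
            out[..., start : end - shift_left] = labels[..., start + shift_left : end]
    return out


# Two packed sequences of length 4; padded packed length is 8, but only 6 labels on this rank.
labels = torch.tensor([[1, 2, 3, 4, 5, 6]])
cu_seq_lens = [0, 4, 8]
pad = torch.full((1, cu_seq_lens[-1] - labels.shape[-1]), -100, dtype=labels.dtype)
labels = torch.cat([labels, pad], dim=-1)            # [[1, 2, 3, 4, 5, 6, -100, -100]]

rolled = roll_packed_demo(labels, cu_seq_lens, shift_left=1)
assert rolled.tolist() == [[2, 3, 4, -100, 6, -100, -100, -100]]
```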
37 changes: 22 additions & 15 deletions xtuner/v1/module/mtp/utils.py
@@ -100,24 +100,31 @@ def roll_sequence_context(
Original input_ids: [1, 2, 3, 4, 5, 6]
Rolled input_ids: [2, 3, 0, 5, 6, 0]
"""
assert seq_ctx.sequence_parallel_mesh is None, "Sequence parallel is not yet supported"
sp_mesh = seq_ctx.sequence_parallel_mesh
is_sp = sp_mesh is not None and sp_mesh.size() > 1

overrides: dict = {}

if seq_ctx.input_ids is not None:
overrides["input_ids"] = roll_packed_tensor(
tensor=seq_ctx.input_ids,
cu_seq_lens=seq_ctx.cu_seq_lens_q,
shifts=shifts,
dim=-1,
)

if seq_ctx.inputs_embeds is not None:
overrides["inputs_embeds"] = roll_packed_tensor(
tensor=seq_ctx.inputs_embeds,
cu_seq_lens=seq_ctx.cu_seq_lens_q,
shifts=shifts,
dim=-2, # Embedding dimension is typically the second to last
raw_input_ids = seq_ctx.raw_input_ids
if raw_input_ids is not None:
rolled = roll_packed_tensor(tensor=raw_input_ids, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-1)
overrides["raw_input_ids"] = rolled
if is_sp:
s = seq_ctx._shard_start
overrides["input_ids"] = rolled[:, s : s + seq_ctx._shard_size]
else:
overrides["input_ids"] = rolled

raw_inputs_embeds = seq_ctx.raw_inputs_embeds
if raw_inputs_embeds is not None:
rolled_e = roll_packed_tensor(
tensor=raw_inputs_embeds, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-2
)
overrides["raw_inputs_embeds"] = rolled_e
if is_sp:
s = seq_ctx._shard_start
overrides["inputs_embeds"] = rolled_e[:, s : s + seq_ctx._shard_size]
else:
overrides["inputs_embeds"] = rolled_e

return seq_ctx.copy(**overrides)
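
With sequence parallelism enabled, `roll_sequence_context` now rolls the full gathered tensor (`raw_input_ids`/`raw_inputs_embeds`) and then re-slices the local shard using `_shard_start`/`_shard_size`, because rolling each shard in isolation would lose the token that should cross the shard boundary. A plain-tensor illustration of the difference is below; `roll_left` is a toy single-sequence helper written only for this example, not the packed-aware `roll_packed_tensor`.

```python
import torch


def roll_left(x: torch.Tensor, fill: int = 0) -> torch.Tensor:
    """Shift the last dimension left by one, filling the vacated slot."""
    out = torch.full_like(x, fill)
    out[..., :-1] = x[..., 1:]
    return out


raw = torch.arange(1, 9).unsqueeze(0)        # one packed sequence: [[1, 2, 3, 4, 5, 6, 7, 8]]
shard_start, shard_size = 0, 4               # rank 0 holds the first half

# Roll the full sequence, then slice the local shard (what the new code does).
rank0_from_raw = roll_left(raw)[:, shard_start : shard_start + shard_size]    # [[2, 3, 4, 5]]

# Rolling only the local shard would drop token 5 at the shard boundary.
rank0_local_only = roll_left(raw[:, shard_start : shard_start + shard_size])  # [[2, 3, 4, 0]]

assert not torch.equal(rank0_from_raw, rank0_local_only)
```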