From d22e97617f4e0b8f6a1eeb892cc7106e8f07ec51 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 24 Mar 2026 12:05:20 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tests/model/test_qwen3_5.py | 102 +++++++++++++++++++---- xtuner/v1/data_proto/sequence_context.py | 90 +++++++++++++++++++- xtuner/v1/loss/mtp_loss.py | 4 +- xtuner/v1/module/mtp/utils.py | 37 ++++---- 4 files changed, 199 insertions(+), 34 deletions(-) diff --git a/tests/model/test_qwen3_5.py b/tests/model/test_qwen3_5.py index bc94ddd28..761d301b8 100644 --- a/tests/model/test_qwen3_5.py +++ b/tests/model/test_qwen3_5.py @@ -14,6 +14,8 @@ from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig from xtuner.v1.config import FSDPConfig from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh +from xtuner.v1.data_proto.utils import pad_to_multiple_of + import tempfile from pathlib import Path @@ -82,13 +84,28 @@ def _forward(self, model, type, device, sp_size): position_ids = tokenized_data['position_ids'].cuda() else: tokenizer = AutoTokenizer.from_pretrained(QWEN3_VL_MOE_PATH) - input_ids = tokenizer(f"今天天气不错,是学习的好日子。请听题: 1+1 等于多少?", - return_tensors="pt").input_ids.to(device) - labels = input_ids.clone() + tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_MOE_PATH, rand_video_max_frames=14, + add_vision_id=True).build(tokenizer) + raw_data = { + "id": 3, "messages": [ + { + "role": "user", "content": [ + { + "type": "text", + "text": "Translate this into chinese: Where my eyes gaze, only memories remain; where my heart strays, only yesterday's pain; where my sight stays, only regret's refrain." + } + ] + }, + {"role": "assistant", "content": "目之所及,唯余旧忆;心之所向,唯余昨痛;眸之所留,唯余悔声。"} + ] + } + tokenized_data = tokenize_fn(raw_data) + input_ids = torch.tensor(tokenized_data['input_ids'])[None].cuda() + labels = torch.tensor(tokenized_data['labels'])[None].cuda() pixel_values = None image_grid_thw = None position_ids = None - + from transformers import Qwen3_5MoeForConditionalGeneration is_hf_model = isinstance(model, Qwen3_5MoeForConditionalGeneration) @@ -115,10 +132,8 @@ def _forward(self, model, type, device, sp_size): dist.all_reduce(output.loss.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) return output.loss else: - loss_cfg = CELossConfig() - - shift_input_ids = input_ids[:, :-1] - shifted_labels = labels[:, 1:] + shift_input_ids = pad_to_multiple_of(input_ids[:, :-1], padding_value=0, multiple_of=sp_size) + shifted_labels = pad_to_multiple_of(labels[:, 1:], padding_value=-100, multiple_of=sp_size) if position_ids is not None: position_ids = position_ids[..., :-1] @@ -127,22 +142,18 @@ def _forward(self, model, type, device, sp_size): data_mesh = init_data_mesh(device, sp_size=sp_size) sp_mesh = data_mesh["sp"] - seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) + seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to("cuda"),)) seq_ctx.image_grid_thw = image_grid_thw seq_ctx.pixel_values = pixel_values if position_ids is not None: seq_ctx.position_ids = position_ids - seq_ctx.to('cuda') + seq_ctx.to("cuda") if sp_size > 1: seq_ctx = seq_ctx.split(sp_mesh) - seq_ctx_list = [seq_ctx] - LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) - loss_ctx_list = [loss_ctx] - loss_ctx_list = LossContext.build_batches(loss_ctx_list) - loss_ctx = loss_ctx_list[0] - seq_ctx = seq_ctx_list[0] + data_batch = [{"seq_ctx": seq_ctx, "shifted_labels": shifted_labels}] + loss_ctx_batch = model.build_loss_ctx_batch(data_batch, sp_mesh=sp_mesh) + loss_ctx = loss_ctx_batch[0] with torch.no_grad(): output = model( @@ -221,6 +232,63 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol): self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol)) self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol)) + @parametrize.parametrize( + "device,sp_size,tol", + [ + ("cuda", 1, 1e-2), + ("cuda", 4, 1e-2), + ], + ) + def test_qwen3_5_vl_run_mtp(self, device, sp_size, tol): + self.create_pg(device) + loss_reference = { + "text": 1.5416, + "image": 3.6920, + "video": 8.2165, + } + + QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"] + + torch.cuda.empty_cache() + + with torch.device("meta"): + model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False) + model_cfg.text_config.mtp_config = MTPConfig(num_layers=1, loss_scaling_factor=1) + qwen3vl_model = model_cfg.build().to(torch.bfloat16) + + qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH) + qwen3vl_model.eval() + + losses = {} + + loss_xtuner_text = self._forward(qwen3vl_model, type="text", device=device, sp_size=sp_size) + self.assertFalse(torch.isnan(loss_xtuner_text), "MTP text loss should not be NaN") + + loss_xtuner_image = self._forward(qwen3vl_model, type="image", device=device, sp_size=sp_size) + self.assertFalse(torch.isnan(loss_xtuner_image), "MTP image loss should not be NaN") + + loss_xtuner_video = self._forward(qwen3vl_model, type="video", device=device, sp_size=sp_size) + self.assertFalse(torch.isnan(loss_xtuner_video), "MTP video loss should not be NaN") + + losses["text"] = loss_xtuner_text + losses["image"] = loss_xtuner_image + losses["video"] = loss_xtuner_video + + for key, loss in losses.items(): + self.assertTrue( + torch.allclose( + loss, torch.tensor( + loss_reference[key], + device=loss_xtuner_text.device, + dtype=loss_xtuner_text.dtype + ), + atol=tol, + rtol=tol + ), + f"Expected text loss around {key}, but got {loss.item()}" + ) + + @parametrize.parametrize( "device,sp_size", [ diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py index 69da8d045..5a6722b8d 100644 --- a/xtuner/v1/data_proto/sequence_context.py +++ b/xtuner/v1/data_proto/sequence_context.py @@ -5,7 +5,7 @@ from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Self -from .utils import pad_to_multiple_of, split_for_sequence_parallel +from .utils import gather_for_sequence_parallel, pad_to_multiple_of, split_for_sequence_parallel # Avoid using dataclass decorator here to get rid of extra ops called in pytorch 2.8 and above @@ -50,6 +50,12 @@ class SequenceContext: # moe routed_experts rollout_routed_experts: torch.Tensor | None + # Private backing attributes for SP shard reconstruction + _raw_input_ids: torch.LongTensor | None + _raw_inputs_embeds: torch.FloatTensor | None + _shard_start: int + _shard_size: int + def __init__( self, input_ids: torch.LongTensor | None, # shape (1, seq_len) @@ -71,6 +77,11 @@ def __init__( inputs_embeds: torch.FloatTensor | None = None, num_img_tokens: list[list[int]] | None = None, rollout_routed_experts: torch.Tensor | None = None, + # SP shard metadata: private, accessed via properties below + raw_input_ids: torch.LongTensor | None = None, + raw_inputs_embeds: torch.FloatTensor | None = None, + shard_start: int = 0, + shard_size: int = 0, ): # Only to distinguish parameters accepted by the constructor from attributes. For example, for `max_length_q`, # the argument can be an int, but as an attribute it can only be a tensor @@ -99,6 +110,10 @@ def __init__( self.inputs_embeds = inputs_embeds self.num_img_tokens = num_img_tokens self.rollout_routed_experts = rollout_routed_experts + self._raw_input_ids = raw_input_ids + self._raw_inputs_embeds = raw_inputs_embeds + self._shard_start = shard_start + self._shard_size = shard_size self.seq_idx = None seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] @@ -169,6 +184,7 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self: start = sp_input_ids.shape[1] * sequence_parallel_mesh.get_local_rank() end = start + sp_input_ids.shape[1] sp_num_padding = max(0, min(sp_input_ids.shape[1], end - num_non_padding)) + shard_size = sp_input_ids.shape[1] if self.position_ids is not None: pad_position_ids = pad_to_multiple_of(self.position_ids, 0, multiple_of, -1) @@ -205,6 +221,9 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self: inputs_embeds=self.inputs_embeds, num_img_tokens=self.num_img_tokens, rollout_routed_experts=self.rollout_routed_experts, + raw_input_ids=cast(torch.LongTensor, pad_input_ids), + shard_start=start, + shard_size=shard_size, ) return sp_seq_ctx else: @@ -308,6 +327,71 @@ def seq_lens_q(self) -> torch.LongTensor: def seq_lens_k(self) -> torch.LongTensor: return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] # type: ignore + @property + def raw_input_ids(self) -> torch.LongTensor | None: + """Full (un-split) input_ids across all SP ranks. + + In non-SP mode, returns ``input_ids`` directly. In SP mode, returns the + pre-stored full tensor if available; otherwise triggers an allgather and + caches the result for subsequent calls. + + Returns: + torch.LongTensor | None: The full input_ids tensor, or ``None`` if + ``input_ids`` is ``None``. + """ + if self._raw_input_ids is not None: + return self._raw_input_ids + if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1: + return self.input_ids + assert self.input_ids is not None + gathered = gather_for_sequence_parallel( + self.input_ids, dim=1, sp_group=self.sequence_parallel_mesh.get_group() + ) + self._raw_input_ids = cast(torch.LongTensor, gathered) + return self._raw_input_ids + + @property + def raw_inputs_embeds(self) -> torch.FloatTensor | None: + """Full (un-split) inputs_embeds across all SP ranks. + + In non-SP mode, returns ``inputs_embeds`` directly. In SP mode, triggers + a single allgather on first access and caches the result for subsequent + calls, so the communication cost is paid at most once. + + Returns: + torch.FloatTensor | None: The full inputs_embeds tensor, or ``None`` if + ``inputs_embeds`` is ``None``. + """ + if self._raw_inputs_embeds is not None: + return self._raw_inputs_embeds + if self.inputs_embeds is None: + return None + if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1: + return self.inputs_embeds + gathered = gather_for_sequence_parallel( + self.inputs_embeds, dim=1, sp_group=self.sequence_parallel_mesh.get_group() + ) + self._raw_inputs_embeds = cast(torch.FloatTensor, gathered) + return self._raw_inputs_embeds + + @property + def raw_position_ids(self) -> torch.LongTensor | None: + """Full (un-split) position_ids across all SP ranks. + + Returns: + torch.LongTensor | None: The full position_ids tensor. + """ + raise NotImplementedError("raw_position_ids is not yet implemented") + + @property + def raw_rollout_routed_experts(self) -> torch.Tensor | None: + """Full (un-split) rollout_routed_experts across all SP ranks. + + Returns: + torch.Tensor | None: The full rollout_routed_experts tensor. + """ + raise NotImplementedError("raw_rollout_routed_experts is not yet implemented") + # TODO: 暂时没有用到,可能要删掉 def chunk(self, num_chunks: int) -> list[Self]: n = self.seq_lens_q.numel() @@ -374,6 +458,10 @@ def copy(self, **overrides) -> Self: inputs_embeds=overrides.get("inputs_embeds", self.inputs_embeds), num_img_tokens=overrides.get("num_img_tokens", self.num_img_tokens), rollout_routed_experts=overrides.get("rollout_routed_experts", self.rollout_routed_experts), + raw_input_ids=overrides.get("raw_input_ids", self._raw_input_ids), + raw_inputs_embeds=overrides.get("raw_inputs_embeds", self._raw_inputs_embeds), + shard_start=overrides.get("shard_start", self._shard_start), + shard_size=overrides.get("shard_size", self._shard_size), ) def to(self, device: torch.device | str): diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py index cce8db7b8..9b78a9b12 100644 --- a/xtuner/v1/loss/mtp_loss.py +++ b/xtuner/v1/loss/mtp_loss.py @@ -2,7 +2,6 @@ from torch.distributed.device_mesh import DeviceMesh from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext -from xtuner.v1.module.mtp.utils import roll_packed_tensor from xtuner.v1.utils.device import get_device @@ -63,6 +62,9 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContex MTPLossContext | None: Built loss context, or ``None`` if ``shifted_labels`` is not present in ``data``. """ + # TODO: Should move the common utils function to public package to avoid from circular import. + from xtuner.v1.module.mtp.utils import roll_packed_tensor + if "shifted_labels" not in data: return None diff --git a/xtuner/v1/module/mtp/utils.py b/xtuner/v1/module/mtp/utils.py index 9d3b4d91f..2c35623f0 100644 --- a/xtuner/v1/module/mtp/utils.py +++ b/xtuner/v1/module/mtp/utils.py @@ -100,24 +100,31 @@ def roll_sequence_context( Original input_ids: [1, 2, 3, 4, 5, 6] Rolled input_ids: [2, 3, 0, 5, 6, 0] """ - assert seq_ctx.sequence_parallel_mesh is None, "Sequence parallel is not yet supported" + sp_mesh = seq_ctx.sequence_parallel_mesh + is_sp = sp_mesh is not None and sp_mesh.size() > 1 overrides: dict = {} - if seq_ctx.input_ids is not None: - overrides["input_ids"] = roll_packed_tensor( - tensor=seq_ctx.input_ids, - cu_seq_lens=seq_ctx.cu_seq_lens_q, - shifts=shifts, - dim=-1, - ) - - if seq_ctx.inputs_embeds is not None: - overrides["inputs_embeds"] = roll_packed_tensor( - tensor=seq_ctx.inputs_embeds, - cu_seq_lens=seq_ctx.cu_seq_lens_q, - shifts=shifts, - dim=-2, # Embedding dimension is typically the second to last + raw_input_ids = seq_ctx.raw_input_ids + if raw_input_ids is not None: + rolled = roll_packed_tensor(tensor=raw_input_ids, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-1) + overrides["raw_input_ids"] = rolled + if is_sp: + s = seq_ctx._shard_start + overrides["input_ids"] = rolled[:, s : s + seq_ctx._shard_size] + else: + overrides["input_ids"] = rolled + + raw_inputs_embeds = seq_ctx.raw_inputs_embeds + if raw_inputs_embeds is not None: + rolled_e = roll_packed_tensor( + tensor=raw_inputs_embeds, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-2 ) + overrides["raw_inputs_embeds"] = rolled_e + if is_sp: + s = seq_ctx._shard_start + overrides["inputs_embeds"] = rolled_e[:, s : s + seq_ctx._shard_size] + else: + overrides["inputs_embeds"] = rolled_e return seq_ctx.copy(**overrides)