Commit 3daf20e

[Feature] Add SP support for roll_sequence_context in MTP
- Add raw_input_ids, raw_inputs_embeds, raw_position_ids, raw_rollout_routed_experts properties to SequenceContext for reconstructing full tensors from SP shards
- Store raw_input_ids (the full padded tensor), shard_start, and shard_size in SequenceContext.split() so input_ids can be rolled with zero communication
- raw_inputs_embeds triggers a single allgather on first access and caches the result, amortising the communication across MTP layers
- roll_sequence_context: remove the SP assert; always operate on full tensors via the raw_* properties, slicing back to the local shard only when running under SP

ghstack-source-id: 79251cf
Pull-Request: #1629
1 parent bdb0b00 commit 3daf20e
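
The key idea, in miniature: each packed sub-sequence must be rolled within its own boundaries, and under SP a sub-sequence can straddle shard boundaries, so the roll has to happen on the full tensor before slicing back to the local shard. Below is a minimal standalone sketch of that contract; `roll_packed` is a simplified stand-in for xtuner's `roll_packed_tensor`, and the shard layout is illustrative, not taken from the repo.

import torch

def roll_packed(tensor, cu_seq_lens, shifts=-1, fill_value=0):
    # Roll each packed sub-sequence independently; vacated slots get fill_value.
    out = torch.full_like(tensor, fill_value)
    for start, end in zip(cu_seq_lens[:-1].tolist(), cu_seq_lens[1:].tolist()):
        keep = (end - start) + shifts  # shifts < 0: tokens surviving the left-roll
        if keep > 0:
            out[..., start:start + keep] = tensor[..., start - shifts:end]
    return out

# Two packed sequences, [1, 2, 3] and [4..8], sharded across sp_size=2 ranks.
full = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
cu_seq_lens = torch.tensor([0, 3, 8])
rolled = roll_packed(full, cu_seq_lens)   # [[2, 3, 0, 5, 6, 7, 8, 0]]

# The second sequence spans both shards, so each rank must slice the *globally*
# rolled tensor using the shard_start/shard_size recorded by split():
rank0 = rolled[:, 0:4]                    # [[2, 3, 0, 5]]
rank1 = rolled[:, 4:8]                    # [[6, 7, 8, 0]]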

File tree

4 files changed (+215, −34):

- tests/model/test_qwen3_5.py
- xtuner/v1/data_proto/sequence_context.py
- xtuner/v1/loss/mtp_loss.py
- xtuner/v1/module/mtp/utils.py

tests/model/test_qwen3_5.py

Lines changed: 85 additions & 17 deletions
@@ -14,6 +14,8 @@
 from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
 from xtuner.v1.config import FSDPConfig
 from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh
+from xtuner.v1.data_proto.utils import pad_to_multiple_of
+
 
 import tempfile
 from pathlib import Path
@@ -82,13 +84,28 @@ def _forward(self, model, type, device, sp_size):
             position_ids = tokenized_data['position_ids'].cuda()
         else:
             tokenizer = AutoTokenizer.from_pretrained(QWEN3_VL_MOE_PATH)
-            input_ids = tokenizer(f"今天天气不错,是学习的好日子。请听题: 1+1 等于多少?",
-                                  return_tensors="pt").input_ids.to(device)
-            labels = input_ids.clone()
+            tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_MOE_PATH, rand_video_max_frames=14,
+                                                  add_vision_id=True).build(tokenizer)
+            raw_data = {
+                "id": 3, "messages": [
+                    {
+                        "role": "user", "content": [
+                            {
+                                "type": "text",
+                                "text": "Translate this into chinese: Where my eyes gaze, only memories remain; where my heart strays, only yesterday's pain; where my sight stays, only regret's refrain."
+                            }
+                        ]
+                    },
+                    {"role": "assistant", "content": "目之所及,唯余旧忆;心之所向,唯余昨痛;眸之所留,唯余悔声。"}
+                ]
+            }
+            tokenized_data = tokenize_fn(raw_data)
+            input_ids = torch.tensor(tokenized_data['input_ids'])[None].cuda()
+            labels = torch.tensor(tokenized_data['labels'])[None].cuda()
         pixel_values = None
         image_grid_thw = None
         position_ids = None
-
+
         from transformers import Qwen3_5MoeForConditionalGeneration
         is_hf_model = isinstance(model, Qwen3_5MoeForConditionalGeneration)
 
@@ -115,10 +132,8 @@ def _forward(self, model, type, device, sp_size):
             dist.all_reduce(output.loss.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
             return output.loss
         else:
-            loss_cfg = CELossConfig()
-
-            shift_input_ids = input_ids[:, :-1]
-            shifted_labels = labels[:, 1:]
+            shift_input_ids = pad_to_multiple_of(input_ids[:, :-1], padding_value=0, multiple_of=sp_size)
+            shifted_labels = pad_to_multiple_of(labels[:, 1:], padding_value=-100, multiple_of=sp_size)
             if position_ids is not None:
                 position_ids = position_ids[..., :-1]
 
@@ -127,22 +142,18 @@ def _forward(self, model, type, device, sp_size):
             data_mesh = init_data_mesh(device, sp_size=sp_size)
             sp_mesh = data_mesh["sp"]
 
-            seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),))
+            seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to("cuda"),))
             seq_ctx.image_grid_thw = image_grid_thw
             seq_ctx.pixel_values = pixel_values
             if position_ids is not None:
                 seq_ctx.position_ids = position_ids
-            seq_ctx.to('cuda')
+            seq_ctx.to("cuda")
             if sp_size > 1:
                 seq_ctx = seq_ctx.split(sp_mesh)
 
-            seq_ctx_list = [seq_ctx]
-            LossContext = loss_cfg.loss_ctx_cls
-            loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
-            loss_ctx_list = [loss_ctx]
-            loss_ctx_list = LossContext.build_batches(loss_ctx_list)
-            loss_ctx = loss_ctx_list[0]
-            seq_ctx = seq_ctx_list[0]
+            data_batch = [{"seq_ctx": seq_ctx, "shifted_labels": shifted_labels}]
+            loss_ctx_batch = model.build_loss_ctx_batch(data_batch, sp_mesh=sp_mesh)
+            loss_ctx = loss_ctx_batch[0]
 
             with torch.no_grad():
                 output = model(
@@ -221,6 +232,63 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol):
         self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol))
         self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol))
 
+    @parametrize.parametrize(
+        "device,sp_size,tol",
+        [
+            ("cuda", 1, 1e-2),
+            ("cuda", 4, 1e-2),
+        ],
+    )
+    def test_qwen3_5_vl_run_mtp(self, device, sp_size, tol):
+        self.create_pg(device)
+        loss_reference = {
+            "text": 1.5416,
+            "image": 3.6920,
+            "video": 8.2165,
+        }
+
+        QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"]
+
+        torch.cuda.empty_cache()
+
+        with torch.device("meta"):
+            model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
+            model_cfg.text_config.mtp_config = MTPConfig(num_layers=1, loss_scaling_factor=1)
+            qwen3vl_model = model_cfg.build().to(torch.bfloat16)
+
+        qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
+        qwen3vl_model.eval()
+
+        losses = {}
+
+        loss_xtuner_text = self._forward(qwen3vl_model, type="text", device=device, sp_size=sp_size)
+        self.assertFalse(torch.isnan(loss_xtuner_text), "MTP text loss should not be NaN")
+
+        loss_xtuner_image = self._forward(qwen3vl_model, type="image", device=device, sp_size=sp_size)
+        self.assertFalse(torch.isnan(loss_xtuner_image), "MTP image loss should not be NaN")
+
+        loss_xtuner_video = self._forward(qwen3vl_model, type="video", device=device, sp_size=sp_size)
+        self.assertFalse(torch.isnan(loss_xtuner_video), "MTP video loss should not be NaN")
+
+        losses["text"] = loss_xtuner_text
+        losses["image"] = loss_xtuner_image
+        losses["video"] = loss_xtuner_video
+
+        for key, loss in losses.items():
+            self.assertTrue(
+                torch.allclose(
+                    loss, torch.tensor(
+                        loss_reference[key],
+                        device=loss_xtuner_text.device,
+                        dtype=loss_xtuner_text.dtype
+                    ),
+                    atol=tol,
+                    rtol=tol
+                ),
+                f"Expected {key} loss around {loss_reference[key]}, but got {loss.item()}"
+            )
+
     @parametrize.parametrize(
         "device,sp_size",
         [
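
Note that the test now right-pads both tensors to a multiple of sp_size before building the SequenceContext. A hedged sketch of what `pad_to_multiple_of` is assumed to do, with the signature inferred from the call sites in this diff:

import torch

def pad_to_multiple_of(x: torch.Tensor, padding_value: int, multiple_of: int, dim: int = -1) -> torch.Tensor:
    # Right-pad `dim` so its length is divisible by `multiple_of` (no-op if it already is).
    remainder = x.shape[dim] % multiple_of
    if remainder == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = multiple_of - remainder
    pad = torch.full(pad_shape, padding_value, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)

# pad_to_multiple_of(torch.tensor([[1, 2, 3, 4, 5]]), padding_value=-100, multiple_of=4)
# -> tensor([[1, 2, 3, 4, 5, -100, -100, -100]])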

xtuner/v1/data_proto/sequence_context.py

Lines changed: 89 additions & 1 deletion
@@ -5,7 +5,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from typing_extensions import Self
 
-from .utils import pad_to_multiple_of, split_for_sequence_parallel
+from .utils import gather_for_sequence_parallel, pad_to_multiple_of, split_for_sequence_parallel
 
 
 # Avoid using dataclass decorator here to get rid of extra ops called in pytorch 2.8 and above
@@ -50,6 +50,12 @@ class SequenceContext:
     # moe routed_experts
     rollout_routed_experts: torch.Tensor | None
 
+    # Private backing attributes for SP shard reconstruction
+    _raw_input_ids: torch.LongTensor | None
+    _raw_inputs_embeds: torch.FloatTensor | None
+    _shard_start: int
+    _shard_size: int
+
     def __init__(
         self,
         input_ids: torch.LongTensor | None,  # shape (1, seq_len)
@@ -71,6 +77,11 @@ def __init__(
         inputs_embeds: torch.FloatTensor | None = None,
         num_img_tokens: list[list[int]] | None = None,
         rollout_routed_experts: torch.Tensor | None = None,
+        # SP shard metadata: private, accessed via properties below
+        raw_input_ids: torch.LongTensor | None = None,
+        raw_inputs_embeds: torch.FloatTensor | None = None,
+        shard_start: int = 0,
+        shard_size: int = 0,
     ):
         # Only to distinguish parameters accepted by the constructor from attributes. For example, for `max_length_q`,
         # the argument can be an int, but as an attribute it can only be a tensor
@@ -99,6 +110,10 @@ def __init__(
         self.inputs_embeds = inputs_embeds
         self.num_img_tokens = num_img_tokens
         self.rollout_routed_experts = rollout_routed_experts
+        self._raw_input_ids = raw_input_ids
+        self._raw_inputs_embeds = raw_inputs_embeds
+        self._shard_start = shard_start
+        self._shard_size = shard_size
         self.seq_idx = None
 
         seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]
@@ -169,6 +184,7 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
             start = sp_input_ids.shape[1] * sequence_parallel_mesh.get_local_rank()
             end = start + sp_input_ids.shape[1]
             sp_num_padding = max(0, min(sp_input_ids.shape[1], end - num_non_padding))
+            shard_size = sp_input_ids.shape[1]
 
             if self.position_ids is not None:
                 pad_position_ids = pad_to_multiple_of(self.position_ids, 0, multiple_of, -1)
@@ -205,6 +221,9 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
                 inputs_embeds=self.inputs_embeds,
                 num_img_tokens=self.num_img_tokens,
                 rollout_routed_experts=self.rollout_routed_experts,
+                raw_input_ids=cast(torch.LongTensor, pad_input_ids),
+                shard_start=start,
+                shard_size=shard_size,
             )
             return sp_seq_ctx
         else:
@@ -308,6 +327,71 @@ def seq_lens_q(self) -> torch.LongTensor:
     def seq_lens_k(self) -> torch.LongTensor:
         return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]  # type: ignore
 
+    @property
+    def raw_input_ids(self) -> torch.LongTensor | None:
+        """Full (un-split) input_ids across all SP ranks.
+
+        In non-SP mode, returns ``input_ids`` directly. In SP mode, returns the
+        pre-stored full tensor if available; otherwise triggers an allgather and
+        caches the result for subsequent calls.
+
+        Returns:
+            torch.LongTensor | None: The full input_ids tensor, or ``None`` if
+                ``input_ids`` is ``None``.
+        """
+        if self._raw_input_ids is not None:
+            return self._raw_input_ids
+        if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1:
+            return self.input_ids
+        assert self.input_ids is not None
+        gathered = gather_for_sequence_parallel(
+            self.input_ids, dim=1, sp_group=self.sequence_parallel_mesh.get_group()
+        )
+        self._raw_input_ids = cast(torch.LongTensor, gathered)
+        return self._raw_input_ids
+
+    @property
+    def raw_inputs_embeds(self) -> torch.FloatTensor | None:
+        """Full (un-split) inputs_embeds across all SP ranks.
+
+        In non-SP mode, returns ``inputs_embeds`` directly. In SP mode, triggers
+        a single allgather on first access and caches the result for subsequent
+        calls, so the communication cost is paid at most once.
+
+        Returns:
+            torch.FloatTensor | None: The full inputs_embeds tensor, or ``None`` if
+                ``inputs_embeds`` is ``None``.
+        """
+        if self._raw_inputs_embeds is not None:
+            return self._raw_inputs_embeds
+        if self.inputs_embeds is None:
+            return None
+        if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1:
+            return self.inputs_embeds
+        gathered = gather_for_sequence_parallel(
+            self.inputs_embeds, dim=1, sp_group=self.sequence_parallel_mesh.get_group()
+        )
+        self._raw_inputs_embeds = cast(torch.FloatTensor, gathered)
+        return self._raw_inputs_embeds
+
+    @property
+    def raw_position_ids(self) -> torch.LongTensor | None:
+        """Full (un-split) position_ids across all SP ranks.
+
+        Returns:
+            torch.LongTensor | None: The full position_ids tensor.
+        """
+        raise NotImplementedError("raw_position_ids is not yet implemented")
+
+    @property
+    def raw_rollout_routed_experts(self) -> torch.Tensor | None:
+        """Full (un-split) rollout_routed_experts across all SP ranks.
+
+        Returns:
+            torch.Tensor | None: The full rollout_routed_experts tensor.
+        """
+        raise NotImplementedError("raw_rollout_routed_experts is not yet implemented")
+
     # TODO: 暂时没有用到,可能要删掉
     def chunk(self, num_chunks: int) -> list[Self]:
         n = self.seq_lens_q.numel()
@@ -374,6 +458,10 @@ def copy(self, **overrides) -> Self:
             inputs_embeds=overrides.get("inputs_embeds", self.inputs_embeds),
             num_img_tokens=overrides.get("num_img_tokens", self.num_img_tokens),
             rollout_routed_experts=overrides.get("rollout_routed_experts", self.rollout_routed_experts),
+            raw_input_ids=overrides.get("raw_input_ids", self._raw_input_ids),
+            raw_inputs_embeds=overrides.get("raw_inputs_embeds", self._raw_inputs_embeds),
+            shard_start=overrides.get("shard_start", self._shard_start),
+            shard_size=overrides.get("shard_size", self._shard_size),
         )
 
     def to(self, device: torch.device | str):
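
The two implemented properties above follow a lazy gather-and-cache pattern: the first access pays one allgather, every later access hits the cache. A standalone sketch of that pattern using raw torch.distributed, where the class, attribute names, and `sp_group` handling are illustrative (the real code delegates to xtuner's gather_for_sequence_parallel):

import torch
import torch.distributed as dist

class ShardedEmbeds:
    def __init__(self, inputs_embeds: torch.Tensor, sp_group=None):
        self.inputs_embeds = inputs_embeds  # local shard, shape (1, seq_len // sp, hidden)
        self.sp_group = sp_group
        self._raw_inputs_embeds = None      # cache for the gathered full tensor

    @property
    def raw_inputs_embeds(self) -> torch.Tensor:
        if self._raw_inputs_embeds is not None:  # later accesses are free
            return self._raw_inputs_embeds
        if self.sp_group is None or dist.get_world_size(self.sp_group) == 1:
            return self.inputs_embeds            # non-SP: nothing to gather
        shards = [torch.empty_like(self.inputs_embeds)
                  for _ in range(dist.get_world_size(self.sp_group))]
        dist.all_gather(shards, self.inputs_embeds.contiguous(), group=self.sp_group)
        self._raw_inputs_embeds = torch.cat(shards, dim=1)  # reassemble along seq dim
        return self._raw_inputs_embeds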

xtuner/v1/loss/mtp_loss.py

Lines changed: 19 additions & 1 deletion
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 from torch.distributed.device_mesh import DeviceMesh
 
 from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext
-from xtuner.v1.module.mtp.utils import roll_packed_tensor
 from xtuner.v1.utils.device import get_device
 
 
@@ -63,12 +63,30 @@ def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext":
             MTPLossContext | None: Built loss context, or ``None`` if
                 ``shifted_labels`` is not present in ``data``.
         """
+        # TODO: Move this shared util into a public package to avoid a circular import.
+        from xtuner.v1.module.mtp.utils import roll_packed_tensor
+
         if "shifted_labels" not in data:
             return None
 
         shifted_labels = data["shifted_labels"]
         cu_seq_lens = data["seq_ctx"].cu_seq_lens_k
 
+        # cu_seq_lens[-1] may be larger than shifted_labels.shape[-1] when seq_ctx
+        # was split for sequence parallelism (padding is added to make the sequence
+        # length a multiple of sp_size). Pad with -100 so roll_packed_tensor does
+        # not go out of bounds.
+        padded_len = int(cu_seq_lens[-1].item())
+        seq_len = shifted_labels.shape[-1]
+        if padded_len > seq_len:
+            pad = torch.full(
+                (*shifted_labels.shape[:-1], padded_len - seq_len),
+                fill_value=-100,
+                dtype=shifted_labels.dtype,
+                device=shifted_labels.device,
+            )
+            shifted_labels = torch.cat([shifted_labels, pad], dim=-1)
+
         rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100)
 
         loss_kwargs = MTPLossKwargs(shifted_labels=rolled).to(DEVICE)
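
A toy walk-through of the padding guard above, under the assumption that SP padding left cu_seq_lens two positions longer than the labels (the numbers are made up; the logic mirrors the diff):

import torch

shifted_labels = torch.tensor([[7, 8, 9, 10, 11, 12]])  # 6 real label positions
cu_seq_lens = torch.tensor([0, 8])                      # built on the padded length 8

padded_len = int(cu_seq_lens[-1].item())  # 8
seq_len = shifted_labels.shape[-1]        # 6
if padded_len > seq_len:
    pad = torch.full((*shifted_labels.shape[:-1], padded_len - seq_len), -100,
                     dtype=shifted_labels.dtype)
    shifted_labels = torch.cat([shifted_labels, pad], dim=-1)

# shifted_labels == [[7, 8, 9, 10, 11, 12, -100, -100]]: the ignore-index tail
# keeps roll_packed_tensor in bounds and contributes nothing to the loss.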

xtuner/v1/module/mtp/utils.py

Lines changed: 22 additions & 15 deletions
@@ -100,24 +100,31 @@ def roll_sequence_context(
         Original input_ids: [1, 2, 3, 4, 5, 6]
         Rolled input_ids: [2, 3, 0, 5, 6, 0]
     """
-    assert seq_ctx.sequence_parallel_mesh is None, "Sequence parallel is not yet supported"
+    sp_mesh = seq_ctx.sequence_parallel_mesh
+    is_sp = sp_mesh is not None and sp_mesh.size() > 1
 
     overrides: dict = {}
 
-    if seq_ctx.input_ids is not None:
-        overrides["input_ids"] = roll_packed_tensor(
-            tensor=seq_ctx.input_ids,
-            cu_seq_lens=seq_ctx.cu_seq_lens_q,
-            shifts=shifts,
-            dim=-1,
-        )
-
-    if seq_ctx.inputs_embeds is not None:
-        overrides["inputs_embeds"] = roll_packed_tensor(
-            tensor=seq_ctx.inputs_embeds,
-            cu_seq_lens=seq_ctx.cu_seq_lens_q,
-            shifts=shifts,
-            dim=-2,  # Embedding dimension is typically the second to last
+    raw_input_ids = seq_ctx.raw_input_ids
+    if raw_input_ids is not None:
+        rolled = roll_packed_tensor(tensor=raw_input_ids, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-1)
+        overrides["raw_input_ids"] = rolled
+        if is_sp:
+            s = seq_ctx._shard_start
+            overrides["input_ids"] = rolled[:, s : s + seq_ctx._shard_size]
+        else:
+            overrides["input_ids"] = rolled
+
+    raw_inputs_embeds = seq_ctx.raw_inputs_embeds
+    if raw_inputs_embeds is not None:
+        rolled_e = roll_packed_tensor(
+            tensor=raw_inputs_embeds, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-2
         )
+        overrides["raw_inputs_embeds"] = rolled_e
+        if is_sp:
+            s = seq_ctx._shard_start
+            overrides["inputs_embeds"] = rolled_e[:, s : s + seq_ctx._shard_size]
+        else:
+            overrides["inputs_embeds"] = rolled_e
 
     return seq_ctx.copy(**overrides)
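
Taken together with the SequenceContext changes, the SP call path is now roughly the following. This is a hedged outline built from names in this diff, not verbatim repo code, and the exact call signatures are assumed:

from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.module.mtp.utils import roll_sequence_context

seq_ctx = SequenceContext.from_input_ids(input_ids=(padded_ids,))  # padded to a multiple of sp_size
seq_ctx = seq_ctx.split(sp_mesh)   # records raw_input_ids, shard_start, shard_size

mtp_ctx = roll_sequence_context(seq_ctx, shifts=-1)
# mtp_ctx.input_ids: this rank's shard of the globally rolled ids, with zero
# communication, because split() stashed the full padded tensor in raw_input_ids.
# mtp_ctx.inputs_embeds: sliced from raw_inputs_embeds, which allgathers once on
# first access and caches, so repeated rolls across MTP layers reuse the result.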
