Skip to content

Commit 7bb0058

Browse files
committed
[Refactor] Rename CELossContext to LMHeadLossContext and refactor loss context base class
- Rename CELossContext to LMHeadLossContext for better semantic clarity
- Refactor BaseLossContext to be more abstract by removing LM-specific logic
- Move eager_mode and chunk_mode implementations from base class to LMHeadLossContext
- Make loss_ctx_cls and _loss_kwargs_cls abstract properties in BaseLossConfig
- Remove sp_split() and to() implementations from BaseLossKwargs base class
- Move sp_split() and to() to CELossKwargs subclass
- Update BaseRLLossKwargs to properly inherit and extend sp_split() and to() methods
- Add deprecation alias: CELossContext = LMHeadLossContext for backward compatibility
- Export LMHeadLossContext in __init__.py

ghstack-source-id: 1b3d648
Pull-Request: #1571
1 parent a725a14 commit 7bb0058

File tree

9 files changed: +114 additions, -123 deletions

xtuner/v1/loss/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs
2-
from .ce_loss import CELossConfig, CELossContext
2+
from .ce_loss import CELossConfig, CELossContext, LMHeadLossContext
33
from .chunk_loss import ChunkLoss
44
from .moe_loss import (
55
BalancingLoss,

xtuner/v1/loss/base_loss_ctx.py

Lines changed: 12 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,10 @@
33
from typing import Annotated, Any, Literal, TypeVar
44

55
import torch
6-
import torch.distributed as dist
76
import torch.nn as nn
87
from cyclopts import Parameter
98
from pydantic import BaseModel, ConfigDict
109
from torch.distributed.device_mesh import DeviceMesh
11-
from torch.distributed.nn.functional import all_reduce
12-
from typing_extensions import Self
13-
14-
from xtuner.v1.loss.utils import sp_split
15-
16-
from .chunk_loss import ChunkLoss
1710

1811

1912
# Do loss calibration among dp, sp and grad accumulation:
@@ -46,18 +39,13 @@
4639

4740

4841
class BaseLossKwargs(BaseModel):
49-
"""Everything needed to compute the loss."""
50-
51-
model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True)
52-
shifted_labels: torch.Tensor
42+
"""Everything needed to compute the loss.
5343
54-
def sp_split(self, sp_mesh: DeviceMesh) -> Self:
55-
self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100)
56-
return self
44+
Subclasses should implement sp_split() and to() methods if they contain tensors that need to be split across
45+
sequence parallel mesh or moved to device.
46+
"""
5747

58-
def to(self, device: torch.device | str) -> Self:
59-
self.shifted_labels = self.shifted_labels.to(device)
60-
return self
48+
model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True)
6149

6250
def chunk(self, chunk_size) -> list["BaseLossKwargs"]:
6351
tensor_fields: dict[str, tuple[torch.Tensor, ...]] = {}
@@ -114,10 +102,13 @@ class BaseLossConfig(BaseModel):
114102
chunk_size: Annotated[int | None, Parameter(help="chunk size when mode is chunk")] = 1024
115103

116104
@property
105+
@abstractmethod
117106
def loss_ctx_cls(self) -> type["BaseLossContext"]:
118107
raise NotImplementedError
119108

109+
# TODO: private property maybe not a good idea
120110
@property
111+
@abstractmethod
121112
def _loss_kwargs_cls(self) -> type["BaseLossKwargs"]:
122113
raise NotImplementedError
123114

@@ -160,72 +151,10 @@ def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs):
160151
self._batch_size = 1
161152

162153
@staticmethod
163-
@abstractmethod
164-
def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: ...
165-
166-
@abstractmethod
167-
def loss_fn(
168-
self,
169-
hidden_states: torch.Tensor,
170-
head_weight: torch.Tensor,
171-
head_bias: torch.Tensor | None,
172-
loss_kwargs: BaseLossKwargs,
173-
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
174-
"""Step 2.a and 2.b in the loss calculation."""
175-
...
176-
177-
def eager_mode(
178-
self,
179-
hidden_states: torch.Tensor,
180-
head_weight: torch.Tensor,
181-
head_bias: torch.Tensor | None,
182-
loss_kwargs: BaseLossKwargs,
183-
):
184-
return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs)
185-
186-
def chunk_mode(
187-
self,
188-
hidden_states: torch.Tensor,
189-
head_weight: torch.Tensor,
190-
head_bias: torch.Tensor | None,
191-
loss_kwargs: BaseLossKwargs,
192-
):
193-
assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode"
194-
195-
chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size)
196-
loss, extra_info = ChunkLoss.apply(
197-
hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size
198-
)
199-
return loss, (None, extra_info)
200-
201-
def forward(
202-
self,
203-
hidden_states: torch.Tensor,
204-
head_weight: torch.Tensor,
205-
head_bias: torch.Tensor | None = None,
206-
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
207-
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
208-
209-
assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
210-
if head_bias is not None:
211-
raise NotImplementedError("Loss does not support head_bias yet.")
212-
213-
if self.loss_cfg.mode == "eager":
214-
loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
215-
else:
216-
loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
217-
218-
# TODO: yanhuida, should be removed
219-
if not isinstance(extra_info, ModelForwardExtraLogInfo):
220-
extra_info = ModelForwardExtraLogInfo(extra_info)
221-
222-
extra_info["local_base_loss"] = loss.detach().clone()
223-
224-
# Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
225-
if dist.is_initialized():
226-
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
227-
228-
return loss, (logits, extra_info)
154+
def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]:
155+
for ctx in loss_ctx_list:
156+
ctx._batch_size = len(loss_ctx_list)
157+
return loss_ctx_list
229158

230159
@classmethod
231160
def cat(cls: type[_BaseLossContextT], chunks: list[_BaseLossContextT]) -> _BaseLossContextT:

xtuner/v1/loss/ce_loss.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import torch.nn.functional as F
77
from cyclopts import Parameter
88
from torch.distributed.device_mesh import DeviceMesh
9+
from torch.distributed.nn.functional import all_reduce
910

1011
from xtuner.v1.loss import BaseLossConfig, BaseLossContext, BaseLossKwargs
12+
from xtuner.v1.loss.chunk_loss import ChunkLoss
1113
from xtuner.v1.utils.device import get_device
1214

1315
# from xtuner.v1.profiler.prober import ProberList
@@ -37,7 +39,11 @@ class CELossConfig(BaseLossConfig):
3739
def loss_ctx_cls(self) -> type["CELossContext"]:
3840
return CELossContext
3941

40-
def model_post_init(self, __context: Any) -> None:
42+
@property
43+
def _loss_kwargs_cls(self) -> type["CELossKwargs"]:
44+
return CELossKwargs
45+
46+
def model_post_init(self, _context: Any) -> None:
4147
if self.mode == "liger":
4248
assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction"
4349

@@ -80,8 +86,16 @@ class CELossKwargs(BaseLossKwargs):
8086
shifted_labels: torch.Tensor
8187
loss_weight: torch.Tensor | None = None
8288

89+
def sp_split(self, sp_mesh: DeviceMesh) -> "CELossKwargs":
90+
self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100)
91+
return self
92+
93+
def to(self, device: torch.device | str) -> "CELossKwargs":
94+
self.shifted_labels = self.shifted_labels.to(device)
95+
return self
96+
8397

84-
class CELossContext(BaseLossContext):
98+
class LMHeadLossContext(BaseLossContext):
8599
"""Cross-entropy loss context for language models.
86100
87101
Args:
@@ -163,6 +177,7 @@ def build_batches( # type: ignore[override]
163177

164178
for loss_ctx in loss_ctx_list:
165179
loss_ctx._batch_size = len(loss_ctx_list)
180+
assert loss_ctx.loss_kwargs.loss_weight is not None
166181
loss_ctx.loss_kwargs.loss_weight /= global_denominator + 1e-12
167182
return loss_ctx_list
168183

@@ -195,15 +210,30 @@ def loss_fn(
195210

196211
return loss, (logits, {})
197212

213+
def eager_mode(
214+
self,
215+
hidden_states: torch.Tensor,
216+
head_weight: torch.Tensor,
217+
head_bias: torch.Tensor | None,
218+
loss_kwargs: CELossKwargs,
219+
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
220+
return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs)
221+
198222
def chunk_mode(
199223
self,
200224
hidden_states: torch.Tensor,
201225
head_weight: torch.Tensor,
202226
head_bias: torch.Tensor | None,
203227
loss_kwargs: CELossKwargs,
204-
):
228+
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
205229
if self.loss_cfg.mode == "chunk":
206-
return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs)
230+
assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode"
231+
232+
chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size)
233+
loss, extra_info = ChunkLoss.apply(
234+
hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size
235+
)
236+
return loss, (None, extra_info)
207237
else:
208238
assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode"
209239
shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len)
@@ -225,3 +255,36 @@ def chunk_mode(
225255
@property
226256
def batch_size(self) -> int:
227257
return self._batch_size
258+
259+
def forward(
260+
self,
261+
hidden_states: torch.Tensor,
262+
head_weight: torch.Tensor,
263+
head_bias: torch.Tensor | None = None,
264+
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
265+
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
266+
267+
assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
268+
if head_bias is not None:
269+
raise NotImplementedError("Loss does not support head_bias yet.")
270+
271+
if self.loss_cfg.mode == "eager":
272+
loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
273+
else:
274+
loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
275+
276+
# TODO: yanhuida, should be removed
277+
if not isinstance(extra_info, ModelForwardExtraLogInfo):
278+
extra_info = ModelForwardExtraLogInfo(extra_info)
279+
280+
extra_info["local_base_loss"] = loss.detach().clone()
281+
282+
# Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
283+
if dist.is_initialized():
284+
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
285+
286+
return loss, (logits, extra_info)
287+
288+
289+
# Deprecated: Use LMHeadLossContext instead. Will be removed in version 1.1.0
290+
CELossContext = LMHeadLossContext

xtuner/v1/loss/moe_loss.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from typing import Annotated, Any, Literal
1+
from typing import Annotated, Literal
22

33
import torch
44
import torch.nn as nn
55
from cyclopts import Parameter
66
from pydantic import BaseModel, ConfigDict
77
from torch import distributed as dist
88
from torch.distributed._functional_collectives import all_reduce
9-
from torch.distributed.device_mesh import DeviceMesh
109

1110
from xtuner.v1.utils.device import get_device
1211

@@ -223,7 +222,7 @@ def forward(
223222

224223
tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne)
225224
if self.loss_cfg.balancing_loss_global_average and dist.is_initialized():
226-
tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD)
225+
tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) # type: ignore
227226
tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, )
228227
seqlen_global = tokens_global // num_experts_per_tok
229228
routing_weights_sum_global = all_reduce_autograd(router_weights.sum(dim=1), "sum", dist.group.WORLD)
@@ -327,7 +326,9 @@ def forward(self, router_logits: torch.Tensor) -> torch.Tensor:
327326
if self.loss_cfg.z_loss_global_average and dist.is_initialized():
328327
unmasked_num = router_logits.shape[1]
329328
unmasked_num_rank = torch.tensor(unmasked_num, device=router_logits.device, dtype=torch.int64)
330-
unmasked_num_global = all_reduce(unmasked_num_rank, "sum", dist.group.WORLD)
329+
group = dist.group.WORLD
330+
assert group is not None
331+
unmasked_num_global = all_reduce(unmasked_num_rank, "sum", group)
331332
world_size = dist.get_world_size()
332333
loss = loss * unmasked_num * world_size / unmasked_num_global
333334

xtuner/v1/model/base.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from itertools import chain
99
from pathlib import Path
1010
from shutil import copy, copytree
11-
from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, TypedDict, cast
11+
from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, cast
1212

1313
import torch
1414
import torch.distributed as dist
@@ -699,7 +699,8 @@ def build_loss_ctx_batch(
699699
if lm_loss_ctx_list is not None:
700700
loss_ctx_cls = lm_loss_ctx_list[0].__class__
701701
lm_loss_ctx_list = loss_ctx_cls.build_batches(
702-
lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh)
702+
lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh
703+
)
703704

704705
if lm_loss_ctx_list is not None:
705706
for i, lm_loss_ctx in enumerate(lm_loss_ctx_list):
@@ -1864,10 +1865,7 @@ def _collect_full_state_dict(self, module: nn.Module):
18641865
return ret
18651866

18661867
def _build_loss_ctx(
1867-
self,
1868-
loss_ctx_cfg: BaseLossConfig | None,
1869-
data_batch: list[dict],
1870-
sp_mesh: DeviceMesh | None
1868+
self, loss_ctx_cfg: BaseLossConfig | None, data_batch: list[dict], sp_mesh: DeviceMesh | None
18711869
) -> list[BaseLossContext] | None:
18721870
if loss_ctx_cfg is None:
18731871
return None
@@ -1878,9 +1876,9 @@ def _build_loss_ctx(
18781876
if first_loss_ctx is None:
18791877
return None
18801878
else:
1881-
ret = [first_loss_ctx] + [
1882-
loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]]
1883-
return ret
1879+
ret = [first_loss_ctx] + [loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]]
1880+
return ret # type: ignore[return-value]
1881+
18841882
# NOTE: Add this overload for inferring the return type for easier type checking and using
18851883
@overload # type: ignore
18861884
def __call__( # type: ignore

xtuner/v1/model/dense/dense.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from pathlib import Path
3-
from typing import Self, cast, Literal
3+
from typing import Self, cast
44

55
import torch
66
import torch.distributed as dist
@@ -19,7 +19,7 @@
1919
from xtuner.v1.config import FSDPConfig
2020
from xtuner.v1.data_proto import SequenceContext
2121
from xtuner.v1.float8.float8_handler import Float8Handler
22-
from xtuner.v1.loss import CELossContext, BaseLossContext
22+
from xtuner.v1.loss import BaseLossContext, CELossContext
2323
from xtuner.v1.model.base import (
2424
DEFAULT_FLOAT8_CFG,
2525
BaseModel,
@@ -78,7 +78,7 @@ def __init__(self, config: TransformerConfig):
7878
def forward(
7979
self,
8080
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
81-
loss_ctx: dict[Literal["lm"], BaseLossContext] | None = None,
81+
loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
8282
) -> ModelOutputs:
8383
input_ids = seq_ctx.input_ids
8484
position_ids = seq_ctx.position_ids
@@ -116,7 +116,7 @@ def forward(
116116
output["logits"] = logits
117117
else:
118118
# Training mode
119-
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"])
119+
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload]
120120
output["loss"] = loss
121121
output["logits"] = logits
122122
output["extra_info"] = extra_info

0 commit comments

Comments (0)