88import torch
99import torch .distributed as dist
1010import torch .distributed .checkpoint as dcp
11- from pydantic import ConfigDict
1211from safetensors import safe_open
1312from torch .distributed .checkpoint .state_dict import (
1413 StateDictOptions ,
2120from torch .utils ._foreach_utils import (
2221 _device_has_foreach_support ,
2322)
24- from typing_extensions import NotRequired , TypedDict
2523
2624from xtuner .v1 .config import FSDPConfig , OptimConfig
2725from xtuner .v1 .data_proto .sequence_context import SequenceContext
28- from xtuner .v1 .model .base import BaseModel , ModelItem , XTunerBaseModelConfig
29- from xtuner .v1 .model .utils import ModelForwardExtraLogInfo
30- from xtuner .v1 .module .router import NoAuxRouterConfig
26+ from xtuner .v1 .model .base import (
27+ BaseModel ,
28+ BatchForwardInfo ,
29+ DataBatchInfo ,
30+ ModelItem ,
31+ ModelOutputs ,
32+ XTunerBaseModelConfig ,
33+ )
3134from xtuner .v1 .profiler .prober import ProberList
3235from xtuner .v1 .utils import get_device , get_logger , get_torch_device_module , profile_time_and_memory
3336from xtuner .v1 .utils .grad_norm import cal_grad_norm
3437
3538
39+ class TrainStepInfo (DataBatchInfo , BatchForwardInfo ):
40+ total_loss : float
41+
42+
3643logger = get_logger ()
3744DEVICE = get_device ()
3845DEVICE_MODULE = get_torch_device_module ()
3946
4047threading_lock = threading .Lock ()
4148
4249
43- class LossLog (TypedDict ):
44- __pydantic_config__ = ConfigDict (arbitrary_types_allowed = True ) # type: ignore[misc]
45- local_loss : float
46- reduced_llm_loss : float
47- reduced_balancing_loss : NotRequired [float ]
48- reduced_z_loss : NotRequired [float ]
49-
50-
51- class OtherLog (TypedDict ):
52- __pydantic_config__ = ConfigDict (arbitrary_types_allowed = True ) # type: ignore[misc]
53- maxvio : NotRequired [float ]
54- step_consumed_tokens : int
55- step_consumed_img_tokens : NotRequired [int ]
56- extra_info : ModelForwardExtraLogInfo
57- efficient_attn_ratio : float
58-
59-
6050class CPUThreadTaskCoordinator :
6151 def __init__ (self , futures , callback ):
6252 self .futures = futures
@@ -206,66 +196,36 @@ def grad_accumulation_steps(self, data_batches_len: int):
206196 intra_layer_micro_batch = self .intra_layer_micro_batch
207197 return data_batches_len // intra_layer_micro_batch
208198
209- def train_step (self , data_batches : list [ModelItem ]) -> tuple [ LossLog , OtherLog ] :
199+ def train_step (self , data_batches : list [ModelItem ]) -> TrainStepInfo :
210200 """Perform a training step with the given data batches and mesh.
211201
212202 Args:
213203 data_batches (List[Dict]): The input data batches for the training step.
214204 """
215205 self ._maybe_precompute_float8_dynamic_scale_for_fsdp ()
216206
217- loss_log : LossLog = {} # type: ignore[typeddict-item]
218- other_log : OtherLog = {} # type: ignore[typeddict-item]
219207 intra_layer_micro_batch = self .intra_layer_micro_batch
220208 assert len (data_batches ) % intra_layer_micro_batch == 0 , (
221209 f"data_batches length { len (data_batches )} is not divisible by intra_layer_micro_batch { intra_layer_micro_batch } "
222210 )
223211 iters_per_step = self .grad_accumulation_steps (len (data_batches ))
224212
225- moe_need_update_bias = (
226- isinstance (getattr (self .model_cfg , "router" , None ), NoAuxRouterConfig )
227- and self .model_cfg .router .router_bias_update_speed > 0
228- )
229- moe_need_log_maxvio = getattr (self .model_cfg , "router" , None ) is not None
230-
231- if moe_need_log_maxvio :
232- tokens_per_expert_global_for_bias = torch .zeros (
233- self .model_cfg .num_hidden_layers - self .model_cfg .first_k_dense_replace ,
234- self .model_cfg .n_routed_experts ,
235- dtype = torch .int64 ,
236- device = DEVICE ,
237- )
238-
239- step_loss = torch .tensor (0.0 , device = DEVICE )
240- step_llm_loss = torch .tensor (0.0 , device = DEVICE )
241- step_balancing_loss : torch .Tensor | None = None
242- step_z_loss : torch .Tensor | None = None
243- step_consumed_tokens = torch .tensor (0 , device = DEVICE )
244-
245213 if self ._count == 0 :
246214 logger .info (f"grad_accumulation_steps: { iters_per_step } " )
247215 self ._count += 1
248216
249- train_engine_extra_info = ModelForwardExtraLogInfo ()
250217 micro_batch_iter = 0
251- efficient_forward_tokens = torch .tensor (0 , device = DEVICE , dtype = torch .long )
252- total_forward_tokens = torch .tensor (0 , device = DEVICE , dtype = torch .long )
218+ micro_batch_results = []
219+
220+ data_batch_info = self .model .pre_micro_batch_forward (data_batches )
221+ total_loss = torch .tensor (0.0 , device = DEVICE )
222+
253223 for i in range (0 , len (data_batches ), intra_layer_micro_batch ):
254224 ProberList .set_micro_batch_iter (micro_batch_iter )
255225 micro_batch_iter += 1
256226 data_batch = data_batches [i : i + intra_layer_micro_batch ]
257- seq_ctx_list = []
258- loss_ctx_list = []
259- for data in data_batch :
260- seq_ctx = data ["seq_ctx" ]
261- loss_ctx = data ["loss_ctx" ]
262- seq_ctx_list .append (seq_ctx )
263- loss_ctx_list .append (loss_ctx )
264- step_consumed_tokens += seq_ctx .mask .sum ()
265-
266- num_tokens = seq_ctx .cu_seq_lens_k [1 :] - seq_ctx .cu_seq_lens_k [:- 1 ]
267- efficient_forward_tokens += (num_tokens .long () ** 2 ).sum ()
268- total_forward_tokens += (num_tokens .long ().sum ()) ** 2
227+ seq_ctx_list = [i ["seq_ctx" ] for i in data_batch ]
228+ loss_ctx_list = [i ["loss_ctx" ] for i in data_batch ]
269229
270230 if self .intra_layer_micro_batch == 1 :
271231 output = self .model (seq_ctx = seq_ctx_list [0 ], loss_ctx = loss_ctx_list [0 ])
@@ -278,78 +238,16 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
278238 )
279239 output .free_nongrad_feature ()
280240
281- # llm loss has been global averaged
282- llm_loss = output ["loss" ]
283- step_llm_loss += llm_loss .detach ().clone ()
284-
285- loss = llm_loss
286- if "extra_info" in output :
287- train_engine_extra_info .append (output ["extra_info" ])
288-
289- if "balancing_loss" in output :
290- balancing_loss = output ["balancing_loss" ] / iters_per_step
291- loss = loss + balancing_loss
292- if step_balancing_loss is None :
293- step_balancing_loss = balancing_loss
294- else :
295- step_balancing_loss += balancing_loss
296-
297- if "z_loss" in output :
298- z_loss = output ["z_loss" ] / iters_per_step
299- loss = loss + z_loss
241+ micro_batch_results .append (output )
300242
301- if step_z_loss is None :
302- step_z_loss = z_loss
303- else :
304- step_z_loss += z_loss
305-
306- if moe_need_log_maxvio :
307- assert "tokens_per_expert_global" in output , "tokens_per_expert_global is required for bias update."
308- tokens_per_expert_global_for_bias += output ["tokens_per_expert_global" ]
309-
310- del output
243+ loss = self ._get_total_loss (output )
311244 loss .backward ()
245+ total_loss += loss .detach ()
312246 # call dump_forward_records after backward to record the recomputed activations
313247 ProberList .after_micro_iter_forward ()
314- step_loss += loss .detach ().clone ()
315-
316- if moe_need_log_maxvio :
317- avg_count_load = tokens_per_expert_global_for_bias .float ().mean (1 )
318- max_load_i , _ = torch .max (tokens_per_expert_global_for_bias , dim = 1 )
319- maxvio_all_layers = (max_load_i - avg_count_load ) / avg_count_load
320- maxvio = maxvio_all_layers .mean ()
321- if moe_need_update_bias :
322- self .model .update_bias (tokens_per_expert_global_for_bias , avg_count_load ) # type: ignore
323- other_log ["maxvio" ] = maxvio .item ()
324-
325- reduced_llm_loss = step_llm_loss
326- dist .all_reduce (reduced_llm_loss .div_ (dist .get_world_size ()))
327-
328- loss_log ["local_loss" ] = step_loss .item ()
329- loss_log ["reduced_llm_loss" ] = reduced_llm_loss .item ()
330- if step_balancing_loss is not None :
331- reduced_balancing_loss = step_balancing_loss
332- dist .all_reduce (reduced_balancing_loss .div_ (dist .get_world_size ()))
333- loss_log ["reduced_balancing_loss" ] = reduced_balancing_loss .item ()
334- if step_z_loss is not None :
335- reduced_z_loss = step_z_loss
336- dist .all_reduce (reduced_z_loss .div_ (dist .get_world_size ()))
337- loss_log ["reduced_z_loss" ] = reduced_z_loss .item ()
338- other_log ["step_consumed_tokens" ] = int (step_consumed_tokens .item ())
339- other_log ["extra_info" ] = train_engine_extra_info
340- other_log ["efficient_attn_ratio" ] = (efficient_forward_tokens / total_forward_tokens ).item ()
341-
342- extra_info = other_log .get ("extra_info" , {}) # type: ignore
343-
344- # TODO: @duanyanhui `extra_info` should be redesigned.
345- if not isinstance (extra_info , ModelForwardExtraLogInfo ):
346- extra_info = ModelForwardExtraLogInfo (extra_info )
347- loss_log .update (extra_info .get ())
348-
349- if "maxvio" in other_log :
350- loss_log ["maxvio" ] = other_log ["maxvio" ] # type: ignore
351- loss_log ["efficient_attn_ratio" ] = other_log ["efficient_attn_ratio" ] # type: ignore
352- return loss_log , other_log
248+
249+ batch_forward_info = self .model .post_micro_batch_forward (micro_batch_results )
250+ return TrainStepInfo (total_loss = total_loss .item (), ** data_batch_info , ** batch_forward_info )
353251
354252 def from_hf (self , hf_path : str | Path , strict : bool = False ):
355253 self .model .from_hf (hf_path = hf_path , strict = strict )
@@ -529,3 +427,17 @@ def _maybe_precompute_float8_dynamic_scale_for_fsdp(self):
529427 for model in self .model .modules ():
530428 if isinstance (model , BaseModel ) and model .float8_handler is not None :
531429 model .float8_handler .precompute_float8_dynamic_scale_for_fsdp (model )
430+
431+ def _get_total_loss (self , model_outputs : ModelOutputs ) -> torch .Tensor :
432+ # TODO: This logic should be moved into the model layer. The model should be responsible
433+ # for aggregating all losses (CE loss, balancing loss, z loss, etc.) and returning a
434+ # single total_loss. The engine should only call model.forward() and use the returned
435+ # total_loss directly, rather than iterating through fields to sum losses here.
436+ # This would provide better separation of concerns and make the loss computation logic
437+ # more explicit and maintainable.
438+ loss = torch .tensor (0.0 , device = DEVICE )
439+ for key in model_outputs .model_fields :
440+ value = getattr (model_outputs , key )
441+ if "loss" in key and isinstance (value , torch .Tensor ):
442+ loss += value
443+ return loss
0 commit comments