1212
1313from lightx2v_train .model_zoo import build_model
1414from lightx2v_train .runtime .checkpoint import prune_checkpoints
15- from lightx2v_train .runtime .ddp import apply_ddp , set_ddp_gradient_sync
1615from lightx2v_train .runtime .distributed import barrier , get_world_size , is_distributed , is_main_process , reduce_mean
17- from lightx2v_train .runtime .fsdp import apply_fsdp2
16+ from lightx2v_train .runtime .parallel import apply_parallel , set_parallel_gradient_sync
1817from lightx2v_train .schedulers import DMDFlowMatchingScheduler
1918from lightx2v_train .schedulers .flow_matching import CausalForcingFlowMatchScheduler
2019from lightx2v_train .utils .registry import TRAINER_REGISTER
@@ -71,10 +70,9 @@ def setup(self, resume_ckpt_path=None):
7170 self .fake_model = build_model (fake_model_config )
7271 self .fake_model .load_components (transformer_only = True , reference_model = self .model )
7372 self ._setup_trainable_model (self .fake_model )
74- apply_fsdp2 (self .fake_model , self .config )
73+ apply_parallel (self .fake_model , self .config )
7574 if self .gradient_checkpointing :
7675 self .fake_model .enable_gradient_checkpointing ()
77- apply_ddp (self .fake_model , self .config )
7876
7977 teacher_model_config = copy .deepcopy (self .config )
8078 teacher_model_config ["model" ] = copy .deepcopy (base_model_config )
@@ -83,7 +81,7 @@ def setup(self, resume_ckpt_path=None):
8381 self .teacher_model .load_components (transformer_only = True , reference_model = self .model )
8482 self .teacher_model .transformer .requires_grad_ (False )
8583 self .teacher_model .transformer .eval ()
86- apply_fsdp2 (self .teacher_model , self .config )
84+ apply_parallel (self .teacher_model , self .config )
8785 self .teacher_model .transformer .eval ()
8886
8987 self .fake_trainable_params = list (self .fake_model .trainable_parameters ())
@@ -445,12 +443,10 @@ def train(self):
445443 logger .info ("[train] finished iter={}/{}" , current_iter , max_train_iters )
446444
447445 def _set_student_gradient_sync (self , enabled ):
448- self .model .set_fsdp2_gradient_sync (enabled )
449- set_ddp_gradient_sync (self .model .denoiser_module (), enabled )
446+ set_parallel_gradient_sync (self .model , enabled )
450447
451448 def _set_fake_gradient_sync (self , enabled ):
452- self .fake_model .set_fsdp2_gradient_sync (enabled )
453- set_ddp_gradient_sync (self .fake_model .denoiser_module (), enabled )
449+ set_parallel_gradient_sync (self .fake_model , enabled )
454450
455451 def _set_gradient_sync (self , enabled ):
456452 self ._set_student_gradient_sync (enabled )
0 commit comments