Skip to content

Commit 5c1a4cb

Browse files
committed
unify parallel entrypoint
1 parent 5a8e530 commit 5c1a4cb

6 files changed

Lines changed: 41 additions & 25 deletions

File tree

lightx2v_train/lightx2v_train/runtime/ddp.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,15 @@ def apply_ddp(model, config):
100100
logger.info("DP(DDP) skipped for {} because the denoiser has no trainable parameters.", model.__class__.__name__)
101101
return model
102102

103-
wrapped = LightX2VDistributedDataParallel(denoiser, **_ddp_kwargs(config))
103+
ddp_kwargs = _ddp_kwargs(config)
104+
wrapped = LightX2VDistributedDataParallel(denoiser, **ddp_kwargs)
104105
if getattr(model, "transformer", None) is not denoiser:
105106
raise RuntimeError(f"{model.__class__.__name__} must store its trainable denoiser in self.transformer to use DP(DDP).")
106107
model.transformer = wrapped
107108
logger.info(
108109
"DP(DDP) transformer wrapped: broadcast_buffers={} find_unused_parameters={} static_graph={}",
109-
wrapped.broadcast_buffers,
110-
wrapped.find_unused_parameters,
111-
wrapped.static_graph,
110+
ddp_kwargs["broadcast_buffers"],
111+
ddp_kwargs["find_unused_parameters"],
112+
ddp_kwargs["static_graph"],
112113
)
113114
return model

lightx2v_train/lightx2v_train/runtime/fsdp.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
from loguru import logger
33
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
44

5-
from lightx2v_train.runtime.ddp import ddp_enabled
65
from lightx2v_train.runtime.distributed import get_device_mesh, is_distributed
76
from lightx2v_train.utils.utils import get_running_dtype
87

98

109
def fsdp2_enabled(config):
11-
if ddp_enabled(config):
12-
return False
1310
fsdp_config = config.get("distributed", {}).get("fsdp2", {})
1411
return is_distributed() and fsdp_config.get("enabled", True)
1512

lightx2v_train/lightx2v_train/runtime/parallel.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from loguru import logger
2+
3+
from lightx2v_train.runtime.ddp import apply_ddp, ddp_enabled, set_ddp_gradient_sync
4+
from lightx2v_train.runtime.distributed import is_distributed
5+
from lightx2v_train.runtime.fsdp import apply_fsdp2, fsdp2_enabled
6+
7+
8+
def apply_parallel(model, config):
9+
"""Apply the configured distributed parallel strategy exactly once."""
10+
11+
if not is_distributed():
12+
return model
13+
14+
if ddp_enabled(config):
15+
return apply_ddp(model, config)
16+
17+
if fsdp2_enabled(config):
18+
return apply_fsdp2(model, config)
19+
20+
logger.warning("Distributed training is initialized, but neither DP(DDP) nor FSDP2 is enabled. The model will run without distributed wrapping.")
21+
return model
22+
23+
24+
def set_parallel_gradient_sync(model, enabled):
25+
model.set_fsdp2_gradient_sync(enabled)
26+
set_ddp_gradient_sync(model.denoiser_module(), enabled)

lightx2v_train/lightx2v_train/trainers/base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99

1010
from lightx2v_train.infer import build_inferencer
1111
from lightx2v_train.runtime.checkpoint import find_latest_checkpoint, parse_checkpoint_iteration, prune_checkpoints
12-
from lightx2v_train.runtime.ddp import apply_ddp, set_ddp_gradient_sync
1312
from lightx2v_train.runtime.distributed import barrier, get_world_size, is_main_process
14-
from lightx2v_train.runtime.fsdp import apply_fsdp2
13+
from lightx2v_train.runtime.parallel import apply_parallel, set_parallel_gradient_sync
1514
from lightx2v_train.schedulers.flow_matching import RectifiedFlowMatchingScheduler
1615
from lightx2v_train.utils.utils import get_running_dtype
1716

@@ -121,13 +120,11 @@ def _build_lr_scheduler(self, optimizer, num_training_steps=None, num_warmup_ste
121120
def setup(self, resume_ckpt_path=None):
122121
self._setup_trainable_model(self.model)
123122

124-
apply_fsdp2(self.model, self.config)
123+
apply_parallel(self.model, self.config)
125124

126125
if self.gradient_checkpointing:
127126
self.model.enable_gradient_checkpointing()
128127

129-
apply_ddp(self.model, self.config)
130-
131128
if self.infer_every_iters:
132129
self.inferencer = build_inferencer(self.config)
133130
self.inferencer.set_model(self.model)
@@ -225,8 +222,7 @@ def _resolve_resume(self):
225222
return ckpt_path, current_iter
226223

227224
def _set_gradient_sync(self, enabled):
228-
self.model.set_fsdp2_gradient_sync(enabled)
229-
set_ddp_gradient_sync(self.model.denoiser_module(), enabled)
225+
set_parallel_gradient_sync(self.model, enabled)
230226

231227
def run_inference(self, current_iter):
232228
base_output_dir = self.infer_config.get("output_dir", "./output_infer")

lightx2v_train/lightx2v_train/trainers/dmd.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212

1313
from lightx2v_train.model_zoo import build_model
1414
from lightx2v_train.runtime.checkpoint import prune_checkpoints
15-
from lightx2v_train.runtime.ddp import apply_ddp, set_ddp_gradient_sync
1615
from lightx2v_train.runtime.distributed import barrier, get_world_size, is_distributed, is_main_process, reduce_mean
17-
from lightx2v_train.runtime.fsdp import apply_fsdp2
16+
from lightx2v_train.runtime.parallel import apply_parallel, set_parallel_gradient_sync
1817
from lightx2v_train.schedulers import DMDFlowMatchingScheduler
1918
from lightx2v_train.schedulers.flow_matching import CausalForcingFlowMatchScheduler
2019
from lightx2v_train.utils.registry import TRAINER_REGISTER
@@ -71,10 +70,9 @@ def setup(self, resume_ckpt_path=None):
7170
self.fake_model = build_model(fake_model_config)
7271
self.fake_model.load_components(transformer_only=True, reference_model=self.model)
7372
self._setup_trainable_model(self.fake_model)
74-
apply_fsdp2(self.fake_model, self.config)
73+
apply_parallel(self.fake_model, self.config)
7574
if self.gradient_checkpointing:
7675
self.fake_model.enable_gradient_checkpointing()
77-
apply_ddp(self.fake_model, self.config)
7876

7977
teacher_model_config = copy.deepcopy(self.config)
8078
teacher_model_config["model"] = copy.deepcopy(base_model_config)
@@ -83,7 +81,7 @@ def setup(self, resume_ckpt_path=None):
8381
self.teacher_model.load_components(transformer_only=True, reference_model=self.model)
8482
self.teacher_model.transformer.requires_grad_(False)
8583
self.teacher_model.transformer.eval()
86-
apply_fsdp2(self.teacher_model, self.config)
84+
apply_parallel(self.teacher_model, self.config)
8785
self.teacher_model.transformer.eval()
8886

8987
self.fake_trainable_params = list(self.fake_model.trainable_parameters())
@@ -445,12 +443,10 @@ def train(self):
445443
logger.info("[train] finished iter={}/{}", current_iter, max_train_iters)
446444

447445
def _set_student_gradient_sync(self, enabled):
448-
self.model.set_fsdp2_gradient_sync(enabled)
449-
set_ddp_gradient_sync(self.model.denoiser_module(), enabled)
446+
set_parallel_gradient_sync(self.model, enabled)
450447

451448
def _set_fake_gradient_sync(self, enabled):
452-
self.fake_model.set_fsdp2_gradient_sync(enabled)
453-
set_ddp_gradient_sync(self.fake_model.denoiser_module(), enabled)
449+
set_parallel_gradient_sync(self.fake_model, enabled)
454450

455451
def _set_gradient_sync(self, enabled):
456452
self._set_student_gradient_sync(enabled)

lightx2v_train/lightx2v_train/trainers/dopsd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from lightx2v_train.infer.dopsd_trajectory_viz import save_student_teacher_trajectory_grid
1717
from lightx2v_train.runtime.checkpoint import find_latest_checkpoint, parse_checkpoint_iteration, prune_checkpoints
1818
from lightx2v_train.runtime.distributed import barrier, get_rank, get_world_size, is_distributed, is_main_process, reduce_mean
19-
from lightx2v_train.runtime.fsdp import apply_fsdp2
19+
from lightx2v_train.runtime.parallel import apply_parallel
2020
from lightx2v_train.utils.registry import TRAINER_REGISTER
2121
from lightx2v_train.utils.utils import get_running_dtype
2222

@@ -102,7 +102,7 @@ def setup(self, resume_ckpt_path=None):
102102
)
103103
self.model.set_dual_lora_trainable(self.student_adapter, self.teacher_adapter)
104104

105-
apply_fsdp2(self.model, self.config)
105+
apply_parallel(self.model, self.config)
106106

107107
if self.gradient_checkpointing:
108108
self.model.enable_gradient_checkpointing()

0 commit comments

Comments
 (0)