Skip to content

Commit 144365b

Browse files
committed
merge data processing into the training script
1 parent cb8de6b commit 144365b

File tree

6 files changed

+35
-147
lines changed

6 files changed

+35
-147
lines changed

diffsynth/trainers/utils.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,14 +520,26 @@ def launch_training_task(
520520
dataset: torch.utils.data.Dataset,
521521
model: DiffusionTrainingModule,
522522
model_logger: ModelLogger,
523-
optimizer: torch.optim.Optimizer,
524-
scheduler: torch.optim.lr_scheduler.LRScheduler,
523+
learning_rate: float = 1e-5,
524+
weight_decay: float = 1e-2,
525525
num_workers: int = 8,
526526
save_steps: int = None,
527527
num_epochs: int = 1,
528528
gradient_accumulation_steps: int = 1,
529529
find_unused_parameters: bool = False,
530+
args = None,
530531
):
532+
if args is not None:
533+
learning_rate = args.learning_rate
534+
weight_decay = args.weight_decay
535+
num_workers = args.dataset_num_workers
536+
save_steps = args.save_steps
537+
num_epochs = args.num_epochs
538+
gradient_accumulation_steps = args.gradient_accumulation_steps
539+
find_unused_parameters = args.find_unused_parameters
540+
541+
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
542+
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
531543
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
532544
accelerator = Accelerator(
533545
gradient_accumulation_steps=gradient_accumulation_steps,
@@ -557,8 +569,12 @@ def launch_data_process_task(
557569
model: DiffusionTrainingModule,
558570
model_logger: ModelLogger,
559571
num_workers: int = 8,
572+
args = None,
560573
):
561-
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
574+
if args is not None:
575+
num_workers = args.dataset_num_workers
576+
577+
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
562578
accelerator = Accelerator()
563579
model, dataloader = accelerator.prepare(model, dataloader)
564580

@@ -568,7 +584,7 @@ def launch_data_process_task(
568584
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
569585
os.makedirs(folder, exist_ok=True)
570586
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
571-
data = model(data)
587+
data = model(data, return_inputs=True)
572588
torch.save(data, save_path)
573589

574590

@@ -671,4 +687,5 @@ def qwen_image_parser():
671687
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
672688
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
673689
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
690+
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
674691
return parser

examples/flux/model_training/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525

2626
# Training mode
2727
self.switch_pipe_to_training_mode(
28-
self, self.pipe, trainable_models,
28+
self.pipe, trainable_models,
2929
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
3030
enable_fp8_training=False,
3131
)

examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
accelerate launch examples/qwen_image/model_training/train_data_process.py \
1+
accelerate launch examples/qwen_image/model_training/train.py \
22
--dataset_base_path data/example_image_dataset \
33
--dataset_metadata_path data/example_image_dataset/metadata.csv \
44
--max_pixels 1048576 \
55
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
66
--output_path "./models/train/Qwen-Image_lora_cache" \
77
--use_gradient_checkpointing \
8-
--dataset_num_workers 8
8+
--dataset_num_workers 8 \
9+
--task data_process
910

1011
accelerate launch examples/qwen_image/model_training/train.py \
1112
--dataset_base_path models/train/Qwen-Image_lora_cache \

examples/qwen_image/model_training/train.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from diffsynth import load_state_dict
33
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
44
from diffsynth.pipelines.flux_image_new import ControlNetInput
5-
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser
5+
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
66
from diffsynth.trainers.unified_dataset import UnifiedDataset
77
os.environ["TOKENIZERS_PARALLELISM"] = "false"
88

@@ -29,7 +29,7 @@ def __init__(
2929

3030
# Training mode
3131
self.switch_pipe_to_training_mode(
32-
self, self.pipe, trainable_models,
32+
self.pipe, trainable_models,
3333
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
3434
enable_fp8_training=enable_fp8_training,
3535
)
@@ -81,9 +81,10 @@ def forward_preprocess(self, data):
8181
return {**inputs_shared, **inputs_posi}
8282

8383

84-
def forward(self, data, inputs=None):
84+
def forward(self, data, inputs=None, return_inputs=False):
8585
if inputs is None: inputs = self.forward_preprocess(data)
8686
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
87+
if return_inputs: return inputs
8788
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
8889
loss = self.pipe.training_loss(**models, **inputs)
8990
return loss
@@ -123,13 +124,8 @@ def forward(self, data, inputs=None):
123124
enable_fp8_training=args.enable_fp8_training,
124125
)
125126
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
126-
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
127-
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
128-
launch_training_task(
129-
dataset, model, model_logger, optimizer, scheduler,
130-
num_epochs=args.num_epochs,
131-
gradient_accumulation_steps=args.gradient_accumulation_steps,
132-
save_steps=args.save_steps,
133-
find_unused_parameters=args.find_unused_parameters,
134-
num_workers=args.dataset_num_workers,
135-
)
127+
launcher_map = {
128+
"sft": launch_training_task,
129+
"data_process": launch_data_process_task
130+
}
131+
launcher_map[args.task](dataset, model, model_logger, args=args)

examples/qwen_image/model_training/train_data_process.py

Lines changed: 0 additions & 126 deletions
This file was deleted.

examples/wanvideo/model_training/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626

2727
# Training mode
2828
self.switch_pipe_to_training_mode(
29-
self, self.pipe, trainable_models,
29+
self.pipe, trainable_models,
3030
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
3131
enable_fp8_training=False,
3232
)

0 commit comments

Comments (0)