
Commit cb8de6b

move training code to base trainer
1 parent 958ebf1 commit cb8de6b

5 files changed (+79, -129 lines)

diffsynth/trainers/utils.py

Lines changed: 49 additions & 1 deletion
@@ -1,4 +1,6 @@
 import imageio, os, torch, warnings, torchvision, argparse, json
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 from PIL import Image
 import pandas as pd
@@ -424,7 +426,53 @@ def transfer_data_to_device(self, data, device):
             if isinstance(data[key], torch.Tensor):
                 data[key] = data[key].to(device)
         return data
-
+
+
+    def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+        return model_configs
+
+
+    def switch_pipe_to_training_mode(
+        self,
+        pipe,
+        trainable_models,
+        lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
+        enable_fp8_training=False,
+    ):
+        # Scheduler
+        pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Enable FP8 if pipeline supports
+        if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
+            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(pipe, lora_base_model, model)


 class ModelLogger:
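
The argument formats the new parse_model_configs helper consumes can be read off the parsing logic above: model_paths is a JSON-encoded list of local checkpoint paths, and model_id_with_origin_paths is a comma-separated list of model_id:origin_file_pattern pairs. A minimal standalone sketch of those two formats (the concrete paths and the file pattern below are made-up placeholders, not values from this commit):

    import json

    # model_paths: a JSON-encoded list of local checkpoint files (paths are hypothetical)
    model_paths = '["models/dit.safetensors", "models/text_encoder.safetensors"]'
    print(json.loads(model_paths))  # -> ['models/dit.safetensors', 'models/text_encoder.safetensors']

    # model_id_with_origin_paths: comma-separated "model_id:origin_file_pattern" pairs
    # (the origin_file_pattern below is a placeholder)
    model_id_with_origin_paths = "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"
    for item in model_id_with_origin_paths.split(","):
        model_id, origin_file_pattern = item.split(":")[0], item.split(":")[1]
        print(model_id, origin_file_pattern)

Inside a trainer subclass, the same two strings are passed to self.parse_model_configs(...), which wraps each entry in a ModelConfig and, when enable_fp8_training is set, adds offload_dtype=torch.float8_e4m3fn to every config.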

examples/flux/model_training/train.py

Lines changed: 7 additions & 28 deletions
@@ -20,37 +20,16 @@ def __init__(
     ):
         super().__init__()
         # Load models
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
         self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

-        # Reset training scheduler
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=False,
+        )

-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

examples/qwen_image/model_training/train.py

Lines changed: 8 additions & 36 deletions
@@ -22,46 +22,18 @@ def __init__(
     ):
         super().__init__()
         # Load models
-        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
-
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
         self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=enable_fp8_training,
+        )

-        # Enable FP8
-        if enable_fp8_training:
-            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
-
-        # Reset training scheduler (do it in each training step)
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank,
-                upcast_dtype=self.pipe.torch_dtype,
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

examples/qwen_image/model_training/train_data_process.py

Lines changed: 8 additions & 36 deletions
@@ -22,46 +22,18 @@ def __init__(
     ):
         super().__init__()
         # Load models
-        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
-
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
         self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=enable_fp8_training,
+        )

-        # Enable FP8
-        if enable_fp8_training:
-            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
-
-        # Reset training scheduler (do it in each training step)
-        self.pipe.scheduler.set_timesteps(1000, training=True)
-
-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank,
-                upcast_dtype=self.pipe.torch_dtype,
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

examples/wanvideo/model_training/train.py

Lines changed: 7 additions & 28 deletions
@@ -21,37 +21,16 @@ def __init__(
     ):
         super().__init__()
         # Load models
-        model_configs = []
-        if model_paths is not None:
-            model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path) for path in model_paths]
-        if model_id_with_origin_paths is not None:
-            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
         self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

-        # Reset training scheduler
-        self.pipe.scheduler.set_timesteps(1000, training=True)
+        # Training mode
+        self.switch_pipe_to_training_mode(
+            self.pipe, trainable_models,
+            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
+            enable_fp8_training=False,
+        )

-        # Freeze untrainable models
-        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
-
-        # Add LoRA to the base models
-        if lora_base_model is not None:
-            model = self.add_lora_to_model(
-                getattr(self.pipe, lora_base_model),
-                target_modules=lora_target_modules.split(","),
-                lora_rank=lora_rank
-            )
-            if lora_checkpoint is not None:
-                state_dict = load_state_dict(lora_checkpoint)
-                state_dict = self.mapping_lora_state_dict(state_dict)
-                load_result = model.load_state_dict(state_dict, strict=False)
-                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
-                if len(load_result[1]) > 0:
-                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
-            setattr(self.pipe, lora_base_model, model)
-
         # Store other configs
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
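
Taken together, after this commit each pipeline-specific trainer's __init__ collapses to the same three steps: parse the model configs, build the pipeline, and switch it into training mode. A condensed sketch of that pattern, modeled on the Flux example above (the import lines, the base-class name DiffusionTrainingModule, the class name, and the argument defaults are assumptions for illustration, not part of this diff):

    import torch
    from diffsynth import FluxImagePipeline                       # import path assumed
    from diffsynth.trainers.utils import DiffusionTrainingModule  # base-class name assumed

    class FluxTrainingModule(DiffusionTrainingModule):  # class name is illustrative
        def __init__(
            self,
            model_paths=None, model_id_with_origin_paths=None,
            trainable_models=None,
            lora_base_model=None, lora_target_modules=None, lora_rank=32, lora_checkpoint=None,
            use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        ):
            super().__init__()
            # 1. Turn the CLI-style strings into ModelConfig entries (shared base-trainer helper).
            model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
            # 2. Build the pipeline on CPU in bfloat16.
            self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
            # 3. Scheduler reset, freezing, optional FP8, and LoRA injection now live in the base trainer.
            self.switch_pipe_to_training_mode(
                self.pipe, trainable_models,
                lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
                enable_fp8_training=False,
            )
            # Store other configs
            self.use_gradient_checkpointing = use_gradient_checkpointing
            self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload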
