Commit b6da77e

qwen-image split training

1 parent 260e322 commit b6da77e

7 files changed: +221 −14 lines

diffsynth/pipelines/qwen_image.py

Lines changed: 6 additions & 3 deletions

@@ -174,9 +174,12 @@ def _enable_fp8_lora_training(self, dtype):
             computation_dtype=self.torch_dtype,
             computation_device="cuda",
         )
-        enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
-        enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
-        enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
+        if self.text_encoder is not None:
+            enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
+        if self.dit is not None:
+            enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
+        if self.vae is not None:
+            enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
 
 
     def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
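These None checks matter for the split workflow introduced below: the data-processing stage loads only the text encoder and VAE (so self.dit is None), while the training stage loads only the DiT. A minimal sketch of the guard pattern; setup_loaded_components is a hypothetical helper name, not part of the diff:

    # Hypothetical sketch: apply a setup function only to pipeline
    # components that were actually loaded; absent components are None.
    def setup_loaded_components(pipe, setup_fn):
        for name in ("text_encoder", "dit", "vae"):
            component = getattr(pipe, name, None)
            if component is not None:
                setup_fn(component)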

diffsynth/trainers/unified_dataset.py

Lines changed: 2 additions & 2 deletions

@@ -214,7 +214,7 @@ def __init__(self, map_location="cpu"):
         self.map_location = map_location
 
     def __call__(self, data):
-        return torch.load(data, map_location=self.map_location)
+        return torch.load(data, map_location=self.map_location, weights_only=False)
 
 
@@ -306,7 +306,7 @@ def load_metadata(self, metadata_path):
 
     def __getitem__(self, data_id):
         if self.load_from_cache:
-            data = self.cached_data[data_id % len(self.data)].copy()
+            data = self.cached_data[data_id % len(self.cached_data)]
             data = self.cached_data_operator(data)
         else:
             data = self.data[data_id % len(self.data)].copy()
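Both fixes serve the cache path: cached samples are arbitrary pickled dicts rather than pure tensor checkpoints, so torch.load must opt out of weights_only (the default since PyTorch 2.6), and the wrap-around index must use the cache length rather than the raw dataset length. A small self-contained illustration of the weights_only issue; Meta is a hypothetical stand-in for non-tensor objects in a cached sample:

    import torch

    class Meta:  # hypothetical stand-in for non-tensor cached objects
        def __init__(self, prompt):
            self.prompt = prompt

    sample = {"latents": torch.randn(4, 4), "meta": Meta("a cat")}
    torch.save(sample, "sample.pth")
    # weights_only=True (the default since PyTorch 2.6) refuses to unpickle
    # arbitrary classes such as Meta, so the cache loader disables it.
    data = torch.load("sample.pth", map_location="cpu", weights_only=False)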

diffsynth/trainers/utils.py

Lines changed: 27 additions & 9 deletions

@@ -417,6 +417,13 @@ def export_trainable_state_dict(self, state_dict, remove_prefix=None):
             state_dict_[name] = param
         state_dict = state_dict_
         return state_dict
+
+
+    def transfer_data_to_device(self, data, device):
+        for key in data:
+            if isinstance(data[key], torch.Tensor):
+                data[key] = data[key].to(device)
+        return data
 
 
@@ -484,7 +491,10 @@ def launch_training_task(
     for data in tqdm(dataloader):
         with accelerator.accumulate(model):
             optimizer.zero_grad()
-            loss = model(data)
+            if dataset.load_from_cache:
+                loss = model({}, inputs=data)
+            else:
+                loss = model(data)
             accelerator.backward(loss)
             optimizer.step()
             model_logger.on_step_end(accelerator, model, save_steps)
@@ -494,16 +504,24 @@
     model_logger.on_training_end(accelerator, model, save_steps)
 
 
-def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
+def launch_data_process_task(
+    dataset: torch.utils.data.Dataset,
+    model: DiffusionTrainingModule,
+    model_logger: ModelLogger,
+    num_workers: int = 8,
+):
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator()
     model, dataloader = accelerator.prepare(model, dataloader)
-    os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
-    for data_id, data in enumerate(tqdm(dataloader)):
-        with torch.no_grad():
-            inputs = model.forward_preprocess(data)
-            inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
-            torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
+
+    for data_id, data in tqdm(enumerate(dataloader)):
+        with accelerator.accumulate(model):
+            with torch.no_grad():
+                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
+                os.makedirs(folder, exist_ok=True)
+                save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
+                data = model(data)
+                torch.save(data, save_path)
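In the rewritten launch_data_process_task, each accelerate process writes its cached samples into a subfolder named after accelerator.process_index, so ranks never collide on file names. A minimal sketch of that layout logic; cache_path is a hypothetical helper mirroring the save-path code above, not part of the diff:

    import os

    # Hypothetical helper: one subfolder per process index,
    # one .pth file per dataloader step.
    def cache_path(output_path, process_index, data_id):
        folder = os.path.join(output_path, str(process_index))
        os.makedirs(folder, exist_ok=True)
        return os.path.join(folder, f"{data_id}.pth")

    # cache_path("./models/train/Qwen-Image_lora_cache", 0, 7)
    # -> "./models/train/Qwen-Image_lora_cache/0/7.pth"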

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+accelerate launch examples/qwen_image/model_training/train_data_process.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
+  --output_path "./models/train/Qwen-Image_lora_cache" \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8
+
+accelerate launch examples/qwen_image/model_training/train.py \
+  --dataset_base_path models/train/Qwen-Image_lora_cache \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Qwen-Image_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --find_unused_parameters \
+  --enable_fp8_training
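The two commands split LoRA training into a preprocessing pass and a training pass: the first loads only the text encoder and VAE and serializes preprocessed inputs to ./models/train/Qwen-Image_lora_cache, the second loads only the DiT and trains the LoRA from that cache. A hypothetical sanity check to run between the stages, assuming the per-process cache layout described above:

    import glob, torch

    # Hypothetical check: confirm every cached sample under each
    # process folder loads back as a dict before launching stage 2.
    for path in sorted(glob.glob("models/train/Qwen-Image_lora_cache/*/*.pth")):
        sample = torch.load(path, map_location="cpu", weights_only=False)
        assert isinstance(sample, dict), path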

examples/qwen_image/model_training/train.py

Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@ def forward_preprocess(self, data):
 
     def forward(self, data, inputs=None):
         if inputs is None: inputs = self.forward_preprocess(data)
+        else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
         models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
         loss = self.pipe.training_loss(**models, **inputs)
         return loss
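With this one-line addition, forward has two entry paths: fresh samples go through forward_preprocess, which requires the text encoder and VAE, while cached samples skip preprocessing entirely and only need their tensors moved to the compute device. This is why the training loop in utils.py calls model({}, inputs=data) when dataset.load_from_cache is set.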
Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+import torch, os, json
+from diffsynth import load_state_dict
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from diffsynth.pipelines.flux_image_new import ControlNetInput
+from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser
+from diffsynth.trainers.unified_dataset import UnifiedDataset
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class QwenImageTrainingModule(DiffusionTrainingModule):
+    def __init__(
+        self,
+        model_paths=None, model_id_with_origin_paths=None,
+        tokenizer_path=None, processor_path=None,
+        trainable_models=None,
+        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+        use_gradient_checkpointing=True,
+        use_gradient_checkpointing_offload=False,
+        extra_inputs=None,
+        enable_fp8_training=False,
+    ):
+        super().__init__()
+        # Load models
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+
+        tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
+        processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
+        self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Enable FP8
+        if enable_fp8_training:
+            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Reset the training scheduler (done again in each training step)
+        self.pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(self.pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=self.pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(self.pipe, lora_base_model, model)
+
+        # Store other configs
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
+
+    def forward_preprocess(self, data):
+        # CFG-sensitive parameters
+        inputs_posi = {"prompt": data["prompt"]}
+        inputs_nega = {"negative_prompt": ""}
+
+        # CFG-insensitive parameters
+        inputs_shared = {
+            # If you use this pipeline for inference,
+            # fill in the input parameters here.
+            "input_image": data["image"],
+            "height": data["image"].size[1],
+            "width": data["image"].size[0],
+            # Please do not modify the following parameters
+            # unless you clearly know what this will cause.
+            "cfg_scale": 1,
+            "rand_device": self.pipe.device,
+            "use_gradient_checkpointing": self.use_gradient_checkpointing,
+            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+            "edit_image_auto_resize": True,
+        }
+
+        # Extra inputs
+        controlnet_input, blockwise_controlnet_input = {}, {}
+        for extra_input in self.extra_inputs:
+            if extra_input.startswith("blockwise_controlnet_"):
+                blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
+            elif extra_input.startswith("controlnet_"):
+                controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
+            else:
+                inputs_shared[extra_input] = data[extra_input]
+        if len(controlnet_input) > 0:
+            inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
+        if len(blockwise_controlnet_input) > 0:
+            inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
+
+        # Pipeline units will automatically process the input parameters.
+        for unit in self.pipe.units:
+            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
+        return {**inputs_shared, **inputs_posi}
+
+    def forward(self, data, inputs=None):
+        if inputs is None: inputs = self.forward_preprocess(data)
+        return inputs
+
+
+if __name__ == "__main__":
+    parser = qwen_image_parser()
+    args = parser.parse_args()
+    dataset = UnifiedDataset(
+        base_path=args.dataset_base_path,
+        metadata_path=args.dataset_metadata_path,
+        repeat=1,  # Set repeat = 1
+        data_file_keys=args.data_file_keys.split(","),
+        main_data_operator=UnifiedDataset.default_image_operator(
+            base_path=args.dataset_base_path,
+            max_pixels=args.max_pixels,
+            height=args.height,
+            width=args.width,
+            height_division_factor=16,
+            width_division_factor=16,
+        )
+    )
+    model = QwenImageTrainingModule(
+        model_paths=args.model_paths,
+        model_id_with_origin_paths=args.model_id_with_origin_paths,
+        tokenizer_path=args.tokenizer_path,
+        processor_path=args.processor_path,
+        trainable_models=args.trainable_models,
+        lora_base_model=args.lora_base_model,
+        lora_target_modules=args.lora_target_modules,
+        lora_rank=args.lora_rank,
+        lora_checkpoint=args.lora_checkpoint,
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+        extra_inputs=args.extra_inputs,
+        enable_fp8_training=args.enable_fp8_training,
+    )
+    model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
+    launch_data_process_task(
+        dataset, model, model_logger,
+        num_workers=args.dataset_num_workers,
+    )
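The key difference from the regular training module is the final method: here forward returns the preprocessed input dict instead of a loss, so launch_data_process_task can serialize it directly. A minimal, hypothetical sketch of that contract (CachingModule and its keys are illustrative, not from the diff):

    import torch

    # Hypothetical stand-in: a caching module's forward() returns the
    # preprocessed inputs so the launcher can torch.save() them.
    class CachingModule(torch.nn.Module):
        def preprocess(self, data):
            return {"latents": torch.randn(4, 4), "prompt": data["prompt"]}

        def forward(self, data, inputs=None):
            return inputs if inputs is not None else self.preprocess(data)

    module = CachingModule()
    torch.save(module({"prompt": "a cat"}), "0.pth")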

test.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+import torch
+
+
+data = torch.load("models/train/Qwen-Image_lora_cache/0/0.pth", map_location="cpu", weights_only=False)
+for i in data:
+    print(i)
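test.py only prints the key names of one cached sample. A slightly richer hypothetical variant also reports tensor shapes and dtypes:

    import torch

    # Hypothetical extension of test.py: report shapes and dtypes, not just keys.
    data = torch.load("models/train/Qwen-Image_lora_cache/0/0.pth", map_location="cpu", weights_only=False)
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            print(key, tuple(value.shape), value.dtype)
        else:
            print(key, type(value).__name__)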
