Skip to content

Commit ce0b948

Browse files
committed
support qwen-image fp8 lora training
1 parent c795e35 commit ce0b948

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

diffsynth/pipelines/qwen_image.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,35 @@ def training_loss(self, **inputs):
150150
return loss
151151

152152

153+
def _enable_fp8_lora_training(self, dtype):
154+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
155+
from ..models.qwen_image_dit import RMSNorm
156+
from ..models.qwen_image_vae import QwenImageRMS_norm
157+
module_map = {
158+
RMSNorm: AutoWrappedModule,
159+
torch.nn.Linear: AutoWrappedLinear,
160+
torch.nn.Conv3d: AutoWrappedModule,
161+
torch.nn.Conv2d: AutoWrappedModule,
162+
torch.nn.Embedding: AutoWrappedModule,
163+
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
164+
Qwen2RMSNorm: AutoWrappedModule,
165+
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
166+
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
167+
QwenImageRMS_norm: AutoWrappedModule,
168+
}
169+
model_config = dict(
170+
offload_dtype=dtype,
171+
offload_device="cuda",
172+
onload_dtype=dtype,
173+
onload_device="cuda",
174+
computation_dtype=self.torch_dtype,
175+
computation_device="cuda",
176+
)
177+
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
178+
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
179+
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
180+
181+
153182
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
154183
self.vram_management_enabled = True
155184
if vram_limit is None:

diffsynth/trainers/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,11 +338,15 @@ def trainable_param_names(self):
338338
return trainable_param_names
339339

340340

341-
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
341+
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
342342
if lora_alpha is None:
343343
lora_alpha = lora_rank
344344
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
345345
model = inject_adapter_in_model(lora_config, model)
346+
if upcast_dtype is not None:
347+
for param in model.parameters():
348+
if param.requires_grad:
349+
param.data = param.to(upcast_dtype)
346350
return model
347351

348352

@@ -555,4 +559,5 @@ def qwen_image_parser():
555559
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
556560
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
557561
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
562+
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
558563
return parser

examples/qwen_image/model_training/train.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,27 @@ def __init__(
1717
use_gradient_checkpointing=True,
1818
use_gradient_checkpointing_offload=False,
1919
extra_inputs=None,
20+
enable_fp8_training=False,
2021
):
2122
super().__init__()
2223
# Load models
24+
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
2325
model_configs = []
2426
if model_paths is not None:
2527
model_paths = json.loads(model_paths)
26-
model_configs += [ModelConfig(path=path) for path in model_paths]
28+
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
2729
if model_id_with_origin_paths is not None:
2830
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
29-
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
31+
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
3032

3133
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
3234
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
3335
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
3436

37+
# Enable FP8
38+
if enable_fp8_training:
39+
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
40+
3541
# Reset training scheduler (do it in each training step)
3642
self.pipe.scheduler.set_timesteps(1000, training=True)
3743

@@ -43,7 +49,8 @@ def __init__(
4349
model = self.add_lora_to_model(
4450
getattr(self.pipe, lora_base_model),
4551
target_modules=lora_target_modules.split(","),
46-
lora_rank=lora_rank
52+
lora_rank=lora_rank,
53+
upcast_dtype=self.pipe.torch_dtype,
4754
)
4855
if lora_checkpoint is not None:
4956
state_dict = load_state_dict(lora_checkpoint)
@@ -126,6 +133,7 @@ def forward(self, data, inputs=None):
126133
use_gradient_checkpointing=args.use_gradient_checkpointing,
127134
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
128135
extra_inputs=args.extra_inputs,
136+
enable_fp8_training=args.enable_fp8_training,
129137
)
130138
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
131139
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)

0 commit comments

Comments
 (0)