Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
model:
name: qwen_image
pretrained_model_name_or_path: /path/to/Qwen/Qwen-Image
max_sequence_length: 1024
running_dtype: bf16

data:
train:
name: image_dataset
num_workers: 8
prompt_dropout_rate: 0.0
target_area: 1048576 # 1024 * 1024
shuffle: true
data_path:
- /path/to/LightX2V_train_data_examples/dataset_v1/train.jsonl
val:
name: image_dataset
num_workers: 8
shuffle: false
data_path:
- /path/to/LightX2V_train_data_examples/dataset_v1/val.jsonl

scheduler:
num_train_timesteps: 1000
time_shift_settings:
do_time_shift: true
shift_type: linear
# shift function: "linear" => mu/(mu+(1/t-1)^p), "exponential" => exp(mu)/(exp(mu)+(1/t-1)^p)
time_shift_power: 1.0
dynamic_shift: false
time_shift_mu: 3.0

training:
method: dmd_lora
max_train_iters: 1000
gradient_accumulation_iters: 1
gradient_checkpointing: true
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 100
save_total_limit: 10
dmd:
num_inference_steps: 4
fake_update_ratio: 2
image_sizes:
- [1024, 1024]
- [768, 1344]
- [1344, 768]
renoise_sigma_min: 0.02
renoise_sigma_max: 1.0
renoise_discrete_samples: 1000
renoise_shift: 5.0
lora:
rank: 32
alpha: 32
target_modules:
- to_k
- to_q
- to_v
- to_out.0
# - add_q_proj
# - add_k_proj
# - add_v_proj
# - to_add_out
# - img_mlp.net.0.proj
# - img_mlp.net.2
# - txt_mlp.net.0.proj
# - txt_mlp.net.2
student:
optimizer:
learning_rate: 0.0001
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.001
adam_epsilon: 0.00000001
fake:
optimizer:
learning_rate: 0.00002
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.001
adam_epsilon: 0.00000001
teacher:
guidance_scale: 4.0
negative_prompt: " "
cfg_norm: layer_norm
output_dir: ./output_train/qwen_image_dmd_lora

inference:
method: image_infer
default_width: 1024
default_height: 1024
num_inference_steps: 4
cfg_guidance_scale: 4.0
negative_prompt: " "
enable_cfg: false
output_dir: ./output_infer/qwen_image_dmd_lora
infer_every_iters: ${training.save_every_iters}

logging:
rank_zero_only: true
train_log_every_iters: 10
infer_log_every_steps: 10

resume:
auto_resume: true
19 changes: 17 additions & 2 deletions lightx2v_train/lightx2v_train/model_zoo/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ class QwenImageModel(BaseModel):

pipeline_cls = QwenImagePipeline

def load_components(self):
def load_components(self, transformer_only=False, reference_model=None):
if transformer_only:
if reference_model is not None:
self.text_pipeline = reference_model.text_pipeline
self.vae = reference_model.vae
self.vae_scale_factor = reference_model.vae_scale_factor
self.image_processor = reference_model.image_processor
self.transformer = self.load_transformer()
return
model_path = self.config["model"]["pretrained_model_name_or_path"]
self.text_pipeline = QwenImagePipeline.from_pretrained(
model_path,
Expand All @@ -35,13 +43,17 @@ def load_components(self):
torch_dtype=self.running_dtype,
).to(self.device)
self.vae = AutoencoderKLQwenImage.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype)
self.transformer = QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)
self.transformer = self.load_transformer()

self.text_pipeline.text_encoder.requires_grad_(False)
self.vae.requires_grad_(False)
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

def load_transformer(self):
model_path = self.config["model"]["pretrained_model_name_or_path"]
return QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)

def denoiser_module(self):
return self.transformer

Expand Down Expand Up @@ -69,6 +81,9 @@ def encode_to_latent(self, sample):

def encode_condition(self, sample):
prompt = sample["prompt"]
return self.encode_prompt_condition(prompt)

def encode_prompt_condition(self, prompt):
prompt_embed, prompt_embed_mask = self.text_pipeline.encode_prompt(
prompt=prompt,
device=self.device,
Expand Down
4 changes: 4 additions & 0 deletions lightx2v_train/lightx2v_train/schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .dmd_scheduler import DMDFlowMatchingScheduler
from .flow_matching import RectifiedFlowMatchingScheduler

__all__ = ["DMDFlowMatchingScheduler", "RectifiedFlowMatchingScheduler"]
55 changes: 55 additions & 0 deletions lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch

from .flow_matching import RectifiedFlowMatchingScheduler


class DMDFlowMatchingScheduler(RectifiedFlowMatchingScheduler):
def __init__(self, config, dmd_config={}):
super().__init__(config)
self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0))
self.renoise_sigma_min = float(dmd_config.get("renoise_sigma_min", dmd_config.get("sigma_min", 0.02)))
self.renoise_sigma_max = float(dmd_config.get("renoise_sigma_max", dmd_config.get("sigma_max", 1.0)))
self.renoise_discrete_samples = int(dmd_config.get("renoise_discrete_samples", dmd_config.get("discrete_samples", 1000)))

@staticmethod
def linear_shift(mu, t):
return mu / (mu + (1 / t - 1))

def set_timesteps(self, num_inference_steps, sigmas=None, latent_hw=None, device=None):
super().set_timesteps(num_inference_steps, sigmas=sigmas, latent_hw=latent_hw)
if device is not None:
self.infer_sigmas = self.infer_sigmas.to(device)
self.infer_timesteps = self.infer_timesteps.to(device)
self.sigmas = self.infer_sigmas
self.timesteps = self.infer_timesteps

def sigma_at(self, step_idx, batch_size, device=None, dtype=None):
sigma = self.sigmas[int(step_idx)].expand(int(batch_size))
if device is not None or dtype is not None:
sigma = sigma.to(device=device, dtype=dtype)
return sigma

def sample_renoise_sigma(self, batch_size, device=None, dtype=None):
device = device or self.device
raw = torch.rand((int(batch_size),), device=device, dtype=torch.float32)
if self.renoise_discrete_samples > 0:
raw = torch.ceil(raw * self.renoise_discrete_samples) / self.renoise_discrete_samples
raw = torch.clamp(raw, 1e-7, 1 - 1e-7)
sigma = torch.clamp(self.linear_shift(self.renoise_shift, raw), self.renoise_sigma_min, self.renoise_sigma_max)
if dtype is not None:
sigma = sigma.to(dtype=dtype)
return sigma

def add_noise(self, latent, noise, sigmas):
sigmas = sigmas.to(device=latent.device)
sigmas = self._expand_to_ndim(sigmas, latent.ndim)
return ((1.0 - sigmas) * latent + sigmas * noise).to(dtype=latent.dtype)

def step_by_index(self, velocity, step_idx, sample):
sigma = self.sigma_at(step_idx, sample.shape[0], device=sample.device)
sigma_next = self.sigma_at(int(step_idx) + 1, sample.shape[0], device=sample.device)
sigma = self._expand_to_ndim(sigma, sample.ndim)
sigma_next = self._expand_to_ndim(sigma_next, sample.ndim)
next_sample = sample + (sigma_next - sigma) * velocity
x0 = sample - sigma * velocity
return next_sample.to(sample.dtype), x0.to(sample.dtype)
3 changes: 2 additions & 1 deletion lightx2v_train/lightx2v_train/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lightx2v_train.utils.registry import build_trainer

from .dmd_lora import DmdLoraTrainer
from .lora import LoraTrainer

__all__ = ["build_trainer", "LoraTrainer"]
__all__ = ["build_trainer", "DmdLoraTrainer", "LoraTrainer"]
Loading
Loading