diff --git a/lightx2v_train/configs/train/dmd/flux2_klein_dmd_lora.yaml b/lightx2v_train/configs/train/dmd/flux2_klein_dmd_lora.yaml
new file mode 100644
index 000000000..a17e81b3b
--- /dev/null
+++ b/lightx2v_train/configs/train/dmd/flux2_klein_dmd_lora.yaml
@@ -0,0 +1,99 @@
+model:
+  name: flux2_klein
+  pretrained_model_name_or_path: /path/to/FLUX.2-klein-4B
+  max_sequence_length: 1024
+  running_dtype: bf16
+
+data:
+  train:
+    name: image_dataset
+    num_workers: 8
+    prompt_dropout_rate: 0.0
+    target_area: 1048576 # 1024 * 1024
+    shuffle: true
+    data_path:
+      - /path/to/LightX2V_train_data_examples/dataset_v1/train.jsonl
+  val:
+    name: image_dataset
+    num_workers: 8
+    shuffle: false
+    data_path:
+      - /path/to/LightX2V_train_data_examples/dataset_v1/val.jsonl
+
+scheduler:
+  num_train_timesteps: 1000
+  time_shift_settings:
+    do_time_shift: true
+    shift_type: exponential
+    time_shift_power: 1.0
+    dynamic_shift: false
+    time_shift_mu: 3.0
+
+training:
+  method: dmd
+  train_type: lora
+  max_train_iters: 1000
+  gradient_accumulation_iters: 1
+  gradient_checkpointing: true
+  max_grad_norm: 1.0
+  lr_scheduler: constant
+  lr_warmup_iters: 10
+  save_every_iters: 100
+  save_total_limit: 10
+  dmd:
+    num_inference_steps: 4
+    fake_update_ratio: 2
+    image_sizes:
+      - [1024, 1024]
+      - [768, 1344]
+      - [1344, 768]
+    renoise_sigma_min: 0.02
+    renoise_sigma_max: 1.0
+    renoise_discrete_samples: 1000
+    renoise_shift: 5.0
+  lora:
+    rank: 32
+    alpha: 32
+    target_modules:
+      - to_k
+      - to_q
+      - to_v
+      - to_out.0
+  student:
+    optimizer:
+      learning_rate: 0.0001
+      adam_beta1: 0.9
+      adam_beta2: 0.999
+      weight_decay: 0.001
+      adam_epsilon: 0.00000001
+  fake:
+    optimizer:
+      learning_rate: 0.00002
+      adam_beta1: 0.9
+      adam_beta2: 0.999
+      weight_decay: 0.001
+      adam_epsilon: 0.00000001
+  teacher:
+    guidance_scale: 4.0
+    negative_prompt: " "
+    cfg_norm: layer_norm
+  output_dir: ./output_train/flux2_klein_dmd_lora
+
+inference:
+  method: image_infer
+  default_width: 1024
+  default_height: 1024
+  num_inference_steps: 4
+  cfg_guidance_scale: 4.0
+  negative_prompt: " "
+  enable_cfg: false
+  output_dir: ./output_infer/flux2_klein_dmd_lora
+  infer_every_iters: ${training.save_every_iters}
+
+logging:
+  rank_zero_only: true
+  train_log_every_iters: 10
+  infer_log_every_steps: 10
+
+resume:
+  auto_resume: true
diff --git a/lightx2v_train/lightx2v_train/model_zoo/base.py b/lightx2v_train/lightx2v_train/model_zoo/base.py
index 1de8aa0a3..16bbd77b4 100644
--- a/lightx2v_train/lightx2v_train/model_zoo/base.py
+++ b/lightx2v_train/lightx2v_train/model_zoo/base.py
@@ -27,6 +27,9 @@ def __init__(self, config):
     def load_components(self, transformer_only=False, reference_model=None):
         raise NotImplementedError
 
+    def dmd_latent_shape(self, batch_size, height, width):
+        raise NotImplementedError(f"{self.__class__.__name__} must define dmd_latent_shape().")
+
     def denoiser_module(self):
         raise NotImplementedError(f"{self.__class__.__name__} must define denoiser_module().")
 
diff --git a/lightx2v_train/lightx2v_train/model_zoo/flux2_dev.py b/lightx2v_train/lightx2v_train/model_zoo/flux2_dev.py
index e92abcba8..02cace3db 100644
--- a/lightx2v_train/lightx2v_train/model_zoo/flux2_dev.py
+++ b/lightx2v_train/lightx2v_train/model_zoo/flux2_dev.py
@@ -108,7 +108,7 @@ def encode_prompt_condition(self, prompt):
         )
         return {"prompt_embed": prompt_embed, "text_ids": text_ids}
 
-    def prepare_denoiser_input(self, noisy_latent):
+    def prepare_denoiser_input(self, noisy_latent, condition=None):
         h, w = noisy_latent.shape[2], noisy_latent.shape[3]
         packed = Flux2Pipeline._pack_latents(noisy_latent)
         img_ids = Flux2Pipeline._prepare_latent_ids(noisy_latent).to(self.device)
diff --git a/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py b/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
index e60e3a134..d6b0f616c 100644
--- a/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
+++ b/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
@@ -21,7 +21,14 @@ class Flux2KleinDenoiserInput:
 class Flux2KleinModel(BaseModel):
     pipeline_cls = Flux2KleinPipeline
 
-    def load_components(self):
+    def load_components(self, transformer_only=False, reference_model=None):
+        if transformer_only:
+            if reference_model is not None:
+                self.text_pipeline = reference_model.text_pipeline
+                self.vae = reference_model.vae
+                self.image_processor = reference_model.image_processor
+            self.transformer = self.load_transformer()
+            return
         model_path = self.config["model"]["pretrained_model_name_or_path"]
         self.text_pipeline = Flux2KleinPipeline.from_pretrained(
             model_path,
@@ -81,7 +88,11 @@ def encode_to_latent(self, sample):
         latent = self.vae.encode(image).latent_dist.sample()
         return self._normalize_patch_latents(latent)
 
-    def encode_prompt_text(self, prompt):
+    def encode_condition(self, sample):
+        prompt = sample["prompt"]
+        return self.encode_prompt_condition(prompt)
+
+    def encode_prompt_condition(self, prompt):
         model_config = self.config["model"]
         prompt_embed, text_ids = self.text_pipeline.encode_prompt(
             prompt=prompt,
@@ -92,8 +103,18 @@ def encode_prompt_text(self, prompt):
         )
         return {"prompt_embed": prompt_embed, "text_ids": text_ids}
 
-    def encode_condition(self, sample):
-        return self.encode_prompt_text(sample["prompt"])
+    def encode_prompt_text(self, prompt):
+        return self.encode_prompt_condition(prompt)
+
+    def dmd_latent_shape(self, batch_size, height, width):
+        latent_h = 2 * (int(height) // (self.vae_scale_factor * 2))
+        latent_w = 2 * (int(width) // (self.vae_scale_factor * 2))
+        return (
+            batch_size,
+            self.transformer.config.in_channels,
+            latent_h // 2,
+            latent_w // 2,
+        )
 
     def prepare_denoiser_input(self, noisy_latent, condition=None):
         h, w = noisy_latent.shape[2], noisy_latent.shape[3]