@@ -17,21 +17,27 @@ def __init__(
         use_gradient_checkpointing=True,
         use_gradient_checkpointing_offload=False,
         extra_inputs=None,
+        enable_fp8_training=False,
     ):
         super().__init__()
         # Load models
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
         model_configs = []
         if model_paths is not None:
             model_paths = json.loads(model_paths)
-            model_configs += [ModelConfig(path=path) for path in model_paths]
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
         if model_id_with_origin_paths is not None:
             model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
 
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
         self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
 
+        # Enable FP8
+        if enable_fp8_training:
+            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
         # Reset training scheduler (do it in each training step)
         self.pipe.scheduler.set_timesteps(1000, training=True)
 
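The new `enable_fp8_training` flag threads a `float8_e4m3fn` `offload_dtype` into every `ModelConfig` and then calls `_enable_fp8_lora_training` on the pipeline, so frozen base weights are stored in FP8 while training math still runs in bf16. The pipeline-side implementation is not part of this diff; the snippet below is only a rough sketch of the usual store-in-FP8 / upcast-for-compute pattern, and `FP8Linear` is a hypothetical helper name.

```python
# Minimal sketch, NOT the DiffSynth implementation: keep the frozen base weight
# in float8_e4m3fn and upcast to bfloat16 only for the matmul.
import torch
import torch.nn.functional as F

class FP8Linear(torch.nn.Module):  # hypothetical helper, for illustration only
    def __init__(self, linear: torch.nn.Linear):
        super().__init__()
        # FP8 storage roughly halves memory versus bf16 for the frozen weights.
        self.register_buffer("weight", linear.weight.detach().to(torch.float8_e4m3fn))
        self.bias = None if linear.bias is None else linear.bias.detach().to(torch.bfloat16)

    def forward(self, x):
        # Upcast to the compute dtype for the actual matmul; gradients flow only
        # through x (and any LoRA branch added on top), never into the FP8 weight.
        return F.linear(x, self.weight.to(torch.bfloat16), self.bias)
```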
@@ -43,7 +49,8 @@ def __init__(
         model = self.add_lora_to_model(
             getattr(self.pipe, lora_base_model),
             target_modules=lora_target_modules.split(","),
-            lora_rank=lora_rank
+            lora_rank=lora_rank,
+            upcast_dtype=self.pipe.torch_dtype,
         )
         if lora_checkpoint is not None:
             state_dict = load_state_dict(lora_checkpoint)
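`add_lora_to_model` now also receives `upcast_dtype=self.pipe.torch_dtype` (bf16 here), which suggests the injected LoRA matrices are kept in the compute dtype rather than inheriting the FP8 storage dtype of the frozen layer they wrap. A hedged sketch of that forward path, assuming a standard LoRA formulation (`lora_A`, `lora_B`, and `scale` are illustrative names, not the repository's attributes):

```python
# Hedged sketch of a LoRA forward over an FP8 base weight; not taken from the repo.
import torch
import torch.nn.functional as F

def lora_forward(x, fp8_weight, lora_A, lora_B, scale=1.0):
    # Frozen path: upcast the FP8 base weight to the activation dtype (bf16).
    base = F.linear(x, fp8_weight.to(x.dtype))
    # Trainable path: the low-rank update stays entirely in upcast_dtype (bf16),
    # so its gradients and optimizer states never touch FP8.
    return base + scale * F.linear(F.linear(x, lora_A), lora_B)
```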
@@ -126,6 +133,7 @@ def forward(self, data, inputs=None):
         use_gradient_checkpointing=args.use_gradient_checkpointing,
         use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
         extra_inputs=args.extra_inputs,
+        enable_fp8_training=args.enable_fp8_training,
     )
     model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
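The launcher now forwards `args.enable_fp8_training` into the training module. The argument parser itself is outside this diff; if it follows the usual boolean-switch pattern, the matching definition would look roughly like this (hypothetical, shown only to make the wiring explicit):

```python
# Hypothetical parser addition (the actual parser lives outside this diff):
# a boolean switch matching the args.enable_fp8_training attribute used above.
parser.add_argument(
    "--enable_fp8_training",
    default=False,
    action="store_true",
    help="Store frozen base weights in float8_e4m3fn to reduce training memory.",
)
```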