@@ -77,10 +77,63 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
         self.model_fn = model_fn_qwen_image


-    def load_lora(self, module, path, alpha=1):
-        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
-        loader.load(module, lora, alpha=alpha)
+    def load_lora(
+        self,
+        module: torch.nn.Module,
+        lora_config: Union[ModelConfig, str] = None,
+        alpha=1,
+        hotload=False,
+        state_dict=None,
+    ):
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            lora = state_dict
+        if hotload:
+            for name, module in module.named_modules():
+                if isinstance(module, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        module.lora_B_weights.append(lora[lora_b_name])
+        else:
+            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+            loader.load(module, lora, alpha=alpha)
+
+
+    def clear_lora(self):
+        for name, module in self.named_modules():
+            if isinstance(module, AutoWrappedLinear):
+                if hasattr(module, "lora_A_weights"):
+                    module.lora_A_weights.clear()
+                if hasattr(module, "lora_B_weights"):
+                    module.lora_B_weights.clear()
+
+
+    def enable_lora_magic(self):
+        if self.dit is not None:
+            if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
+                dtype = next(iter(self.dit.parameters())).dtype
+                enable_vram_management(
+                    self.dit,
+                    module_map={
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config=dict(
+                        offload_dtype=dtype,
+                        offload_device=self.device,
+                        onload_dtype=dtype,
+                        onload_device=self.device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=None,
+                )


     def training_loss(self, **inputs):
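
For context, a minimal usage sketch of the new methods, assuming an already-constructed Qwen-Image pipeline instance (called pipe here) and placeholder LoRA file names; none of the calls below are part of this diff, they only illustrate how the pieces fit together. enable_lora_magic() wraps the DiT's torch.nn.Linear layers in AutoWrappedLinear, load_lora(..., hotload=True) stacks alpha-scaled A/B matrices onto those wrappers without rewriting the base weights, and clear_lora() drops the stacked weights so another LoRA can be hot-swapped.

# Hypothetical usage sketch (not part of this commit); "pipe" stands for an
# already-built Qwen-Image pipeline and the .safetensors paths are placeholders.

# Wrap the DiT's Linear layers in AutoWrappedLinear so they can hold
# per-layer lora_A_weights / lora_B_weights lists.
pipe.enable_lora_magic()

# Hot-load a LoRA: A/B matrices are appended to the wrapped layers,
# pre-scaled by alpha, while the base weights stay untouched.
pipe.load_lora(pipe.dit, "style_lora.safetensors", alpha=1.0, hotload=True)

# Swap to a different LoRA without reloading the model.
pipe.clear_lora()
pipe.load_lora(pipe.dit, "other_lora.safetensors", alpha=0.8, hotload=True)

# Without hotload=True the call falls back to GeneralLoRALoader, which merges
# the LoRA into the module weights as before.
pipe.load_lora(pipe.dit, "style_lora.safetensors", alpha=1.0)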