1313from lightx2v .utils .registry_factory import RUNNER_REGISTER
1414from lightx2v_platform .base .global_var import AI_DEVICE
1515
16+ torch_device_module = getattr (torch , AI_DEVICE )
17+
1618
1719def calculate_dimensions (target_area , ratio ):
1820 width = math .sqrt (target_area * ratio )
@@ -45,8 +47,11 @@ def load_vae(self):
4547
4648 def init_modules (self ):
4749 logger .info (f"Initializing { self .config ['model_cls' ]} modules..." )
48- self .load_model ()
49- self .model .set_scheduler (self .scheduler )
50+ if not self .config .get ("lazy_load" , False ) and not self .config .get ("unload_modules" , False ):
51+ self .load_model ()
52+ self .model .set_scheduler (self .scheduler )
53+ elif self .config .get ("lazy_load" , False ):
54+ assert self .config .get ("cpu_offload" , False )
5055
5156 task = self .config .get ("task" , "t2i" )
5257 if task == "i2i" :
@@ -59,8 +64,12 @@ def init_modules(self):
5964 @ProfilingContext4DebugL2 ("Run Encoders" )
6065 def _run_input_encoder_local_t2i (self ):
6166 prompt = self .input_info .prompt
67+ if self .config .get ("lazy_load" , False ) or self .config .get ("unload_modules" , False ):
68+ self .text_encoders = self .load_text_encoder ()
6269 text_encoder_output = self .run_text_encoder (prompt , neg_prompt = self .input_info .negative_prompt )
63- torch .cuda .empty_cache ()
70+ if self .config .get ("lazy_load" , False ) or self .config .get ("unload_modules" , False ):
71+ del self .text_encoders [0 ]
72+ torch_device_module .empty_cache ()
6473 gc .collect ()
6574 return {
6675 "text_encoder_output" : text_encoder_output ,
@@ -70,7 +79,11 @@ def _run_input_encoder_local_t2i(self):
7079 @ProfilingContext4DebugL2 ("Run Encoders I2I" )
7180 def _run_input_encoder_local_i2i (self ):
7281 prompt = self .input_info .prompt
82+ if self .config .get ("lazy_load" , False ) or self .config .get ("unload_modules" , False ):
83+ self .text_encoders = self .load_text_encoder ()
7384 text_encoder_output = self .run_text_encoder (prompt , neg_prompt = self .input_info .negative_prompt )
85+ if self .config .get ("lazy_load" , False ) or self .config .get ("unload_modules" , False ):
86+ del self .text_encoders [0 ]
7487
7588 image_path = self .input_info .image_path
7689 from PIL import Image
@@ -108,7 +121,7 @@ def _run_input_encoder_local_i2i(self):
108121 if index == 0 :
109122 self .input_info .target_shape = (image_height , image_width )
110123
111- torch . cuda .empty_cache ()
124+ torch_device_module .empty_cache ()
112125 gc .collect ()
113126
114127 return {
@@ -244,6 +257,9 @@ def set_img_shapes(self):
244257
245258 @ProfilingContext4DebugL1 ("Run VAE Decoder" )
246259 def run_vae_decoder (self , latents ):
260+ if self .config .get ("lazy_load" , False ) or self .config .get ("unload_modules" , False ):
261+ self .vae = self .load_vae ()
262+
247263 B , _ , C = latents .shape
248264
249265 H = int ((self .input_info .latent_image_ids [0 , :, 1 ].max () + 1 ).item ())
@@ -252,14 +268,20 @@ def run_vae_decoder(self, latents):
252268 latents = latents .view (B , H , W , C ).permute (0 , 3 , 1 , 2 )
253269
254270 bn_mean = self .vae .vae .bn .running_mean .view (1 , - 1 , 1 , 1 ).to (latents .device , latents .dtype )
255- bn_std = torch .sqrt (self .vae .vae .bn .running_var .view (1 , - 1 , 1 , 1 ) + self .vae .vae .config .batch_norm_eps )
271+ bn_std = torch .sqrt (self .vae .vae .bn .running_var .view (1 , - 1 , 1 , 1 ) + self .vae .vae .config .batch_norm_eps ). to ( latents . device , latents . dtype )
256272 latents = latents * bn_std + bn_mean
257273
258274 latents = latents .reshape (B , C // 4 , 2 , 2 , H , W )
259275 latents = latents .permute (0 , 1 , 4 , 2 , 5 , 3 )
260276 latents = latents .reshape (B , C // 4 , H * 2 , W * 2 )
261277
262278 images = self .vae .decode (latents , self .input_info )
279+
280+ if self .config .get ("lazy_load" , False ) or self .config .get ("unload_modules" , False ):
281+ del self .vae
282+ torch_device_module .empty_cache ()
283+ gc .collect ()
284+
263285 return images
264286
265287 @ProfilingContext4DebugL1 ("RUN pipeline" )
@@ -279,7 +301,7 @@ def run_pipeline(self, input_info):
279301 image .save (input_info .save_result_path )
280302 logger .info (f"Image saved: { input_info .save_result_path } " )
281303
282- torch . cuda .empty_cache ()
304+ torch_device_module .empty_cache ()
283305 gc .collect ()
284306
285307 if input_info .return_result_tensor :
0 commit comments