@@ -77,18 +77,17 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
7777 print (f"[Merger] Preparing memory for { total_size_gb :.2f} GB merge operation..." )
7878 prepare_for_large_operation (total_size_gb * 1.2 , torch .device (process_device ))
7979
80- lazy_load = recipe_params .get ('lazy_load' , True )
8180 handlers = {}
8281 for name in model_names .values ():
8382 if name and name != "None" :
8483 path = folder_paths .get_full_path (model_type , name )
8584 if not path :
8685 raise FileNotFoundError (f"Model '{ name } ' not found." )
87- handlers [name ] = MemoryEfficientSafeOpen (path , low_memory = lazy_load )
86+ handlers [name ] = MemoryEfficientSafeOpen (path , low_memory = True )
8887
8988 primary_handler = handlers [primary_model_name ]
9089 all_keys = primary_handler .keys ()
91- metadata = primary_handler .header . get ( "__metadata__" , {} )
90+ metadata = primary_handler .metadata ( )
9291
9392 # Convert mismatch_mode string to enum
9493 mismatch_mode_str = recipe_params .get ('mismatch_mode' , 'skip' )
@@ -125,72 +124,78 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
125124 discarded_keys = 0
126125 error_keys = []
127126
128- for key in tqdm (all_keys , desc = "Merging layers" , unit = "layers" ):
129- # Check discard patterns first - skip entirely
130- if _matches_any_pattern (key , discard_patterns ):
131- discarded_keys += 1
132- pbar .update (1 )
133- continue
134-
135- # Pre-load Model A's tensor with pinned memory for CUDA
136- cpu_tensor = primary_handler .get_tensor (key )
137- if process_device == 'cuda' :
138- tensor_a = transfer_to_gpu_pinned (cpu_tensor , process_device , process_dtype )
139- else :
140- tensor_a = cpu_tensor .to (device = process_device , dtype = process_dtype )
141- del cpu_tensor
142-
143- # Check exclude patterns - use Model A only, no merge
144- if _matches_any_pattern (key , exclude_patterns ):
145- merged_state_dict [key ] = tensor_a .to (save_torch_dtype ).cpu ().clone ()
146- excluded_keys += 1
127+ with torch .no_grad ():
128+ for key in tqdm (all_keys , desc = "Merging layers" , unit = "layers" ):
129+ # Check discard patterns first - skip entirely
130+ if _matches_any_pattern (key , discard_patterns ):
131+ discarded_keys += 1
132+ pbar .update (1 )
133+ continue
134+
135+ # Pre-load Model A's tensor with pinned memory for CUDA
136+ cpu_tensor = primary_handler .get_tensor (key )
137+ if process_device == 'cuda' :
138+ tensor_a = transfer_to_gpu_pinned (cpu_tensor , process_device , process_dtype )
139+ else :
140+ tensor_a = cpu_tensor .to (device = process_device , dtype = process_dtype )
141+ del cpu_tensor
142+
143+ # Check exclude patterns - use Model A only, no merge
144+ if _matches_any_pattern (key , exclude_patterns ):
145+ merged_state_dict [key ] = tensor_a .detach ().to (save_torch_dtype ).cpu ().clone ()
146+ excluded_keys += 1
147+ pbar .update (1 )
148+ continue
149+
150+ # Pass tensor_a metadata to recipes for zeros mode and fallback
151+ recipe_params ['_tensor_a' ] = tensor_a
152+ recipe_params ['_tensor_a_shape' ] = tensor_a .shape
153+ recipe_params ['_tensor_a_dtype' ] = tensor_a .dtype
154+
155+ try :
156+ recipe = calc_mode_class .create_recipe (key = key , ** recipe_params )
157+ result = recipe .merge ()
158+ except MissingTensorError as e :
159+ if mismatch_mode == MissingTensorBehavior .ERROR :
160+ raise ValueError (f"Layer mismatch error (mismatch_mode='error'): { e } " )
161+ result = None
162+ error_keys .append (key )
163+
164+ # Handle None result (mismatch occurred with skip mode)
165+ if result is None :
166+ result = tensor_a
167+ skipped_keys += 1
168+
169+ if isinstance (result , dict ):
170+ for r_key , r_tensor in result .items ():
171+ merged_state_dict [r_key ] = r_tensor .detach ().to (save_torch_dtype ).cpu ().clone ()
172+ else :
173+ # Ensure compatibility with Model A's architecture.
174+ # If alignment_mode is 'pad/crop', we crop results that were padded.
175+ # If alignment_mode is 'interpolate', resizing happened during operators.
176+ if alignment_mode == 'pad/crop' :
177+ target_shape = recipe_params ['_tensor_a_shape' ]
178+ if result .shape != target_shape :
179+ slices = tuple (slice (0 , min (res_s , tgt_s )) for res_s , tgt_s in zip (result .shape , target_shape ))
180+ result = result [slices ]
181+
182+ merged_state_dict [key ] = result .detach ().to (save_torch_dtype ).cpu ().clone ()
183+
184+ # Clean up references to allow GC immediately.
185+ # Local loop variables must be explicitly deleted to prevent PyTorch from keeping tensors in VRAM.
186+ recipe .clean ()
187+ del recipe_params ['_tensor_a' ]
188+ del tensor_a
189+ del recipe
190+ del result
191+
192+ if recipe_params .get ('force_clear_cache' , True ):
193+ import gc
194+ gc .collect ()
195+ if torch .cuda .is_available ():
196+ torch .cuda .empty_cache ()
197+
147198 pbar .update (1 )
148- continue
149-
150- # Pass tensor_a metadata to recipes for zeros mode and fallback
151- recipe_params ['_tensor_a' ] = tensor_a
152- recipe_params ['_tensor_a_shape' ] = tensor_a .shape
153- recipe_params ['_tensor_a_dtype' ] = tensor_a .dtype
154-
155- try :
156- recipe = calc_mode_class .create_recipe (key = key , ** recipe_params )
157- result = recipe .merge ()
158- except MissingTensorError as e :
159- if mismatch_mode == MissingTensorBehavior .ERROR :
160- raise ValueError (f"Layer mismatch error (mismatch_mode='error'): { e } " )
161- result = None
162- error_keys .append (key )
163-
164- # Handle None result (mismatch occurred with skip mode)
165- if result is None :
166- result = tensor_a
167- skipped_keys += 1
168-
169- if isinstance (result , dict ):
170- for r_key , r_tensor in result .items ():
171- merged_state_dict [r_key ] = r_tensor .to (save_torch_dtype ).cpu ().clone ()
172- else :
173- # Ensure compatibility with Model A's architecture.
174- # If alignment_mode is 'pad/crop', we crop results that were padded.
175- # If alignment_mode is 'interpolate', resizing happened during operators.
176- if alignment_mode == 'pad/crop' :
177- target_shape = recipe_params ['_tensor_a_shape' ]
178- if result .shape != target_shape :
179- slices = tuple (slice (0 , min (res_s , tgt_s )) for res_s , tgt_s in zip (result .shape , target_shape ))
180- result = result [slices ]
181-
182- merged_state_dict [key ] = result .to (save_torch_dtype ).cpu ().clone ()
183-
184- # Clean up tensor_a reference to allow GC
185- del recipe_params ['_tensor_a' ]
186-
187- if recipe_params .get ('force_clear_cache' , True ):
188- import gc
189- gc .collect ()
190- if torch .cuda .is_available ():
191- torch .cuda .empty_cache ()
192-
193- pbar .update (1 )
194199
195200 # Log summary
196201 if excluded_keys > 0 :
0 commit comments