Skip to content

Commit 2d47776

Browse files
authored
Merge pull request #9 from silveroxides/fix/loading-fixes
Fix models getting stuck in memory when merging
2 parents 6c4d505 + 4d49477 commit 2d47776

5 files changed

Lines changed: 194 additions & 219 deletions

File tree

nodes/lora_merger.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def merge_multi_loras(
4848

4949
prepare_for_large_operation(total_size_gb * 2.5, torch.device(device))
5050

51-
handlers = [MemoryEfficientSafeOpen(p) for p in lora_paths]
51+
handlers = [MemoryEfficientSafeOpen(p, low_memory=True) for p in lora_paths]
5252

5353
try:
5454
# 1. Analyze all LoRAs
@@ -190,6 +190,12 @@ def merge_multi_loras(
190190
del merged_down, merged_up
191191
for d, _ in downs: del d
192192
for u, _ in ups: del u
193+
downs.clear()
194+
ups.clear()
195+
import gc
196+
gc.collect()
197+
if torch.cuda.is_available():
198+
torch.cuda.empty_cache()
193199

194200
pbar.update(1)
195201

@@ -412,7 +418,7 @@ def merge_multi_loras_dare(
412418
print(f"[LoRA Multi-Merge DARE] Drop rate: {drop_rate}, Trim quantile: {trim_quantile}")
413419

414420
prepare_for_large_operation(total_size_gb * 2.5, torch.device(device))
415-
handlers = [MemoryEfficientSafeOpen(p) for p in lora_paths]
421+
handlers = [MemoryEfficientSafeOpen(p, low_memory=True) for p in lora_paths]
416422

417423
rng = torch.Generator(device=device).manual_seed(seed)
418424

@@ -523,6 +529,16 @@ def process_ties_dare(tensors, weights, dim_to_pad):
523529
output_sd[f"{core}.lora_up.weight"] = merged_up.to(save_dtype).cpu().contiguous()
524530
output_sd[f"{core}.alpha"] = torch.tensor(float(max_rank), dtype=save_dtype)
525531

532+
del merged_down, merged_up
533+
for d, _ in downs: del d
534+
for u, _ in ups: del u
535+
downs.clear()
536+
ups.clear()
537+
import gc
538+
gc.collect()
539+
if torch.cuda.is_available():
540+
torch.cuda.empty_cache()
541+
526542
pbar.update(1)
527543

528544
# Final Summary

nodes/lora_resize.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,12 @@ def resize_lora_file(
461461
if alpha_suffix:
462462
output_sd[f"{new_block_name}{alpha_suffix}"] = torch.tensor(result["new_alpha"]).to(save_dtype)
463463

464+
del result
465+
import gc
466+
gc.collect()
467+
if torch.cuda.is_available():
468+
torch.cuda.empty_cache()
469+
464470
pbar.update(1)
465471

466472
if verbose and fro_list:
@@ -854,6 +860,10 @@ def extract_core_layer_lora(block_name: str) -> str:
854860
del cpu_base
855861
stats["copied"] += 1
856862

863+
import gc
864+
gc.collect()
865+
if torch.cuda.is_available():
866+
torch.cuda.empty_cache()
857867

858868
pbar.update(1)
859869

nodes/merger.py

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,17 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
7777
print(f"[Merger] Preparing memory for {total_size_gb:.2f}GB merge operation...")
7878
prepare_for_large_operation(total_size_gb * 1.2, torch.device(process_device))
7979

80-
lazy_load = recipe_params.get('lazy_load', True)
8180
handlers = {}
8281
for name in model_names.values():
8382
if name and name != "None":
8483
path = folder_paths.get_full_path(model_type, name)
8584
if not path:
8685
raise FileNotFoundError(f"Model '{name}' not found.")
87-
handlers[name] = MemoryEfficientSafeOpen(path, low_memory=lazy_load)
86+
handlers[name] = MemoryEfficientSafeOpen(path, low_memory=True)
8887

8988
primary_handler = handlers[primary_model_name]
9089
all_keys = primary_handler.keys()
91-
metadata = primary_handler.header.get("__metadata__", {})
90+
metadata = primary_handler.metadata()
9291

9392
# Convert mismatch_mode string to enum
9493
mismatch_mode_str = recipe_params.get('mismatch_mode', 'skip')
@@ -125,72 +124,78 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
125124
discarded_keys = 0
126125
error_keys = []
127126

128-
for key in tqdm(all_keys, desc="Merging layers", unit="layers"):
129-
# Check discard patterns first - skip entirely
130-
if _matches_any_pattern(key, discard_patterns):
131-
discarded_keys += 1
132-
pbar.update(1)
133-
continue
134-
135-
# Pre-load Model A's tensor with pinned memory for CUDA
136-
cpu_tensor = primary_handler.get_tensor(key)
137-
if process_device == 'cuda':
138-
tensor_a = transfer_to_gpu_pinned(cpu_tensor, process_device, process_dtype)
139-
else:
140-
tensor_a = cpu_tensor.to(device=process_device, dtype=process_dtype)
141-
del cpu_tensor
142-
143-
# Check exclude patterns - use Model A only, no merge
144-
if _matches_any_pattern(key, exclude_patterns):
145-
merged_state_dict[key] = tensor_a.to(save_torch_dtype).cpu().clone()
146-
excluded_keys += 1
127+
with torch.no_grad():
128+
for key in tqdm(all_keys, desc="Merging layers", unit="layers"):
129+
# Check discard patterns first - skip entirely
130+
if _matches_any_pattern(key, discard_patterns):
131+
discarded_keys += 1
132+
pbar.update(1)
133+
continue
134+
135+
# Pre-load Model A's tensor with pinned memory for CUDA
136+
cpu_tensor = primary_handler.get_tensor(key)
137+
if process_device == 'cuda':
138+
tensor_a = transfer_to_gpu_pinned(cpu_tensor, process_device, process_dtype)
139+
else:
140+
tensor_a = cpu_tensor.to(device=process_device, dtype=process_dtype)
141+
del cpu_tensor
142+
143+
# Check exclude patterns - use Model A only, no merge
144+
if _matches_any_pattern(key, exclude_patterns):
145+
merged_state_dict[key] = tensor_a.detach().to(save_torch_dtype).cpu().clone()
146+
excluded_keys += 1
147+
pbar.update(1)
148+
continue
149+
150+
# Pass tensor_a metadata to recipes for zeros mode and fallback
151+
recipe_params['_tensor_a'] = tensor_a
152+
recipe_params['_tensor_a_shape'] = tensor_a.shape
153+
recipe_params['_tensor_a_dtype'] = tensor_a.dtype
154+
155+
try:
156+
recipe = calc_mode_class.create_recipe(key=key, **recipe_params)
157+
result = recipe.merge()
158+
except MissingTensorError as e:
159+
if mismatch_mode == MissingTensorBehavior.ERROR:
160+
raise ValueError(f"Layer mismatch error (mismatch_mode='error'): {e}")
161+
result = None
162+
error_keys.append(key)
163+
164+
# Handle None result (mismatch occurred with skip mode)
165+
if result is None:
166+
result = tensor_a
167+
skipped_keys += 1
168+
169+
if isinstance(result, dict):
170+
for r_key, r_tensor in result.items():
171+
merged_state_dict[r_key] = r_tensor.detach().to(save_torch_dtype).cpu().clone()
172+
else:
173+
# Ensure compatibility with Model A's architecture.
174+
# If alignment_mode is 'pad/crop', we crop results that were padded.
175+
# If alignment_mode is 'interpolate', resizing happened during operators.
176+
if alignment_mode == 'pad/crop':
177+
target_shape = recipe_params['_tensor_a_shape']
178+
if result.shape != target_shape:
179+
slices = tuple(slice(0, min(res_s, tgt_s)) for res_s, tgt_s in zip(result.shape, target_shape))
180+
result = result[slices]
181+
182+
merged_state_dict[key] = result.detach().to(save_torch_dtype).cpu().clone()
183+
184+
# Clean up references to allow GC immediately.
185+
# Local loop variables must be explicitly deleted to prevent PyTorch from keeping tensors in VRAM.
186+
recipe.clean()
187+
del recipe_params['_tensor_a']
188+
del tensor_a
189+
del recipe
190+
del result
191+
192+
if recipe_params.get('force_clear_cache', True):
193+
import gc
194+
gc.collect()
195+
if torch.cuda.is_available():
196+
torch.cuda.empty_cache()
197+
147198
pbar.update(1)
148-
continue
149-
150-
# Pass tensor_a metadata to recipes for zeros mode and fallback
151-
recipe_params['_tensor_a'] = tensor_a
152-
recipe_params['_tensor_a_shape'] = tensor_a.shape
153-
recipe_params['_tensor_a_dtype'] = tensor_a.dtype
154-
155-
try:
156-
recipe = calc_mode_class.create_recipe(key=key, **recipe_params)
157-
result = recipe.merge()
158-
except MissingTensorError as e:
159-
if mismatch_mode == MissingTensorBehavior.ERROR:
160-
raise ValueError(f"Layer mismatch error (mismatch_mode='error'): {e}")
161-
result = None
162-
error_keys.append(key)
163-
164-
# Handle None result (mismatch occurred with skip mode)
165-
if result is None:
166-
result = tensor_a
167-
skipped_keys += 1
168-
169-
if isinstance(result, dict):
170-
for r_key, r_tensor in result.items():
171-
merged_state_dict[r_key] = r_tensor.to(save_torch_dtype).cpu().clone()
172-
else:
173-
# Ensure compatibility with Model A's architecture.
174-
# If alignment_mode is 'pad/crop', we crop results that were padded.
175-
# If alignment_mode is 'interpolate', resizing happened during operators.
176-
if alignment_mode == 'pad/crop':
177-
target_shape = recipe_params['_tensor_a_shape']
178-
if result.shape != target_shape:
179-
slices = tuple(slice(0, min(res_s, tgt_s)) for res_s, tgt_s in zip(result.shape, target_shape))
180-
result = result[slices]
181-
182-
merged_state_dict[key] = result.to(save_torch_dtype).cpu().clone()
183-
184-
# Clean up tensor_a reference to allow GC
185-
del recipe_params['_tensor_a']
186-
187-
if recipe_params.get('force_clear_cache', True):
188-
import gc
189-
gc.collect()
190-
if torch.cuda.is_available():
191-
torch.cuda.empty_cache()
192-
193-
pbar.update(1)
194199

195200
# Log summary
196201
if excluded_keys > 0:

nodes/merger_ops.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from enum import Enum
55

66

7+
from .merger_utils import transfer_to_gpu_pinned
8+
79
class MissingTensorBehavior(Enum):
810
"""Controls behavior when a tensor key is missing from a model."""
911
ERROR = "error" # Raise exception (strict mode)
@@ -127,12 +129,21 @@ def oper(self, *args) -> torch.Tensor:
127129
raise NotImplementedError
128130

129131
def recurse(self, operation):
130-
source_tensors = [source_oper.merge() for source_oper in operation.sources]
131-
return operation.oper(*source_tensors)
132+
self._source_tensors = [source_oper.merge() for source_oper in operation.sources]
133+
return operation.oper(*self._source_tensors)
132134

133135
def merge(self):
134136
return self.merge_func(self)
135137

138+
def clean(self):
139+
if hasattr(self, '_source_tensors'):
140+
for t in self._source_tensors:
141+
del t
142+
del self._source_tensors
143+
for source in self.sources:
144+
if hasattr(source, 'clean'):
145+
source.clean()
146+
136147
class LoadTensor(Operation):
137148
def __init__(self, key, model_name, handlers, device, dtype,
138149
on_missing=MissingTensorBehavior.ERROR, fallback_shape=None, fallback_dtype=None):
@@ -165,7 +176,17 @@ def merge(self) -> torch.Tensor:
165176
dtype = self.fallback_dtype if self.fallback_dtype else self.dtype
166177
return torch.zeros(self.fallback_shape, device=self.device, dtype=dtype)
167178

168-
return handler.get_tensor(self.key).to(device=self.device, dtype=self.dtype)
179+
cpu_tensor = handler.get_tensor(self.key)
180+
if self.device == 'cuda':
181+
self._tensor = transfer_to_gpu_pinned(cpu_tensor, self.device, self.dtype)
182+
else:
183+
self._tensor = cpu_tensor.to(device=self.device, dtype=self.dtype)
184+
del cpu_tensor
185+
return self._tensor
186+
187+
def clean(self):
188+
if hasattr(self, '_tensor'):
189+
del self._tensor
169190

170191
class Multiply(Operation):
171192
def __init__(self, key, alpha, *sources):
@@ -409,6 +430,10 @@ def __init__(self, key, tensor):
409430
def merge(self) -> torch.Tensor:
410431
return self.tensor
411432

433+
def clean(self):
434+
if hasattr(self, 'tensor'):
435+
del self.tensor
436+
412437
class WeightSum(CalcMode):
413438
name = 'Weight-Sum'
414439
description = 'A * (1 - α) + B * α. Simple linear interpolation.'

0 commit comments

Comments
 (0)