Skip to content

Commit 8deb4fd

Browse files
committed
perf: Add low_memory mode to reduce RAM usage in LoRA merge
- MemoryEfficientSafeOpen now supports low_memory=True which disables mmap - Uses direct file reads with readinto() to avoid OS page caching - _deserialize_tensor avoids bytearray copy when input is already bytearray - merge_loras_to_model now uses low_memory=True for base model This should reduce peak RAM from 2x model size to 1x + overhead
1 parent e9f7cda commit 8deb4fd

2 files changed

Lines changed: 22 additions & 6 deletions

File tree

nodes/lora_resize.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,10 +1625,11 @@ def merge_loras_to_model(
16251625
print(f"[LoRA Merge To Model] Merging {len(lora_paths)} LoRAs with weights: {lora_weights}")
16261626
prepare_for_large_operation(total_size_gb * 1.5, torch.device(device))
16271627

1628-
# Open all files
1629-
base_handler = MemoryEfficientSafeOpen(base_model_path)
1628+
# Open all files - use low_memory for base model to avoid OS page caching
1629+
base_handler = MemoryEfficientSafeOpen(base_model_path, low_memory=True)
16301630
lora_handlers = [MemoryEfficientSafeOpen(lp) for lp in lora_paths]
16311631

1632+
16321633
try:
16331634
# Detect format and extract pairs for each LoRA
16341635
lora_infos = []

nodes/merger_utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class MemoryEfficientSafeOpen:
9393
9494
Features:
9595
- mmap mode: Zero-copy tensor access via memory-mapped file
96+
- low_memory mode: Direct file reads to minimize OS page cache usage
9697
- Parallel loading: Multi-threaded tensor reads for 2-4x speedup
9798
- Sorted batch reads: Keys sorted by file offset for sequential I/O
9899
- Auto-optimized workers: Adjusts parallelism based on device capabilities
@@ -102,12 +103,15 @@ class MemoryEfficientSafeOpen:
102103
filename: Path to safetensors file
103104
device: Target device (default 'cpu')
104105
mmap_mode: Use memory-mapped file for zero-copy (default True)
106+
low_memory: Use direct file reads to minimize memory (overrides mmap_mode)
105107
"""
106108

107-
def __init__(self, filename: str, device: str = 'cpu', mmap_mode: bool = True):
109+
def __init__(self, filename: str, device: str = 'cpu', mmap_mode: bool = True, low_memory: bool = False):
108110
self.filename = filename
109111
self.device = device
110-
self.mmap_mode = mmap_mode
112+
# low_memory mode forces mmap off to avoid OS page caching
113+
self.low_memory = low_memory
114+
self.mmap_mode = mmap_mode and not low_memory
111115
self.header, self.header_size = self._read_header()
112116
self.file = open(filename, "rb")
113117
self.mmap_obj = None
@@ -123,6 +127,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
123127
self.mmap_obj.close()
124128
self.file.close()
125129

130+
126131
def keys(self) -> List[str]:
127132
"""Return all tensor keys (excluding metadata)."""
128133
return [k for k in self.header.keys() if k != "__metadata__"]
@@ -156,13 +161,17 @@ def get_tensor(self, key: str) -> torch.Tensor:
156161
else:
157162
tensor_bytes = None
158163
else:
164+
# Non-mmap mode: use pre-allocated bytearray with readinto for minimal copies
159165
tensor_bytes = None
160166
if offset_start != offset_end:
161167
self.file.seek(self.header_size + 8 + offset_start)
162-
tensor_bytes = self.file.read(offset_end - offset_start)
168+
# Pre-allocate writable buffer and read directly into it
169+
tensor_bytes = bytearray(offset_end - offset_start)
170+
self.file.readinto(tensor_bytes)
163171

164172
return self._deserialize_tensor(tensor_bytes, metadata)
165173

174+
166175
def get_tensor_to_gpu(
167176
self,
168177
key: str,
@@ -276,13 +285,19 @@ def _deserialize_tensor(self, tensor_bytes, metadata):
276285
if tensor_bytes is None:
277286
byte_tensor = torch.empty(0, dtype=torch.uint8)
278287
else:
279-
byte_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8)
288+
# Avoid extra copy if already a bytearray (low_memory mode)
289+
if isinstance(tensor_bytes, bytearray):
290+
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
291+
else:
292+
# mmap memoryview needs copy to create writable tensor
293+
byte_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8)
280294

281295
if dtype_str in ["F8_E5M2", "F8_E4M3"]:
282296
return self._convert_float8(byte_tensor, dtype_str, shape)
283297

284298
return byte_tensor.view(dtype).reshape(shape)
285299

300+
286301
@staticmethod
287302
def _get_torch_dtype(dtype_str):
288303
dtype_map = {

0 commit comments

Comments
 (0)