@@ -93,6 +93,7 @@ class MemoryEfficientSafeOpen:
9393
9494 Features:
9595 - mmap mode: Zero-copy tensor access via memory-mapped file
96+ - low_memory mode: Direct file reads to minimize OS page cache usage
9697 - Parallel loading: Multi-threaded tensor reads for 2-4x speedup
9798 - Sorted batch reads: Keys sorted by file offset for sequential I/O
9899 - Auto-optimized workers: Adjusts parallelism based on device capabilities
@@ -102,12 +103,15 @@ class MemoryEfficientSafeOpen:
102103 filename: Path to safetensors file
103104 device: Target device (default 'cpu')
104105 mmap_mode: Use memory-mapped file for zero-copy (default True)
106+ low_memory: Use direct file reads to minimize memory (overrides mmap_mode)
105107 """
106108
107- def __init__ (self , filename : str , device : str = 'cpu' , mmap_mode : bool = True ):
109+ def __init__ (self , filename : str , device : str = 'cpu' , mmap_mode : bool = True , low_memory : bool = False ):
108110 self .filename = filename
109111 self .device = device
110- self .mmap_mode = mmap_mode
112+ # low_memory mode forces mmap off to avoid OS page caching
113+ self .low_memory = low_memory
114+ self .mmap_mode = mmap_mode and not low_memory
111115 self .header , self .header_size = self ._read_header ()
112116 self .file = open (filename , "rb" )
113117 self .mmap_obj = None
@@ -123,6 +127,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
123127 self .mmap_obj .close ()
124128 self .file .close ()
125129
130+
126131 def keys (self ) -> List [str ]:
127132 """Return all tensor keys (excluding metadata)."""
128133 return [k for k in self .header .keys () if k != "__metadata__" ]
@@ -156,13 +161,17 @@ def get_tensor(self, key: str) -> torch.Tensor:
156161 else :
157162 tensor_bytes = None
158163 else :
164+ # Non-mmap mode: use pre-allocated bytearray with readinto for minimal copies
159165 tensor_bytes = None
160166 if offset_start != offset_end :
161167 self .file .seek (self .header_size + 8 + offset_start )
162- tensor_bytes = self .file .read (offset_end - offset_start )
168+ # Pre-allocate writable buffer and read directly into it
169+ tensor_bytes = bytearray (offset_end - offset_start )
170+ self .file .readinto (tensor_bytes )
163171
164172 return self ._deserialize_tensor (tensor_bytes , metadata )
165173
174+
166175 def get_tensor_to_gpu (
167176 self ,
168177 key : str ,
@@ -276,13 +285,19 @@ def _deserialize_tensor(self, tensor_bytes, metadata):
276285 if tensor_bytes is None :
277286 byte_tensor = torch .empty (0 , dtype = torch .uint8 )
278287 else :
279- byte_tensor = torch .frombuffer (bytearray (tensor_bytes ), dtype = torch .uint8 )
288+ # Avoid extra copy if already a bytearray (low_memory mode)
289+ if isinstance (tensor_bytes , bytearray ):
290+ byte_tensor = torch .frombuffer (tensor_bytes , dtype = torch .uint8 )
291+ else :
292+ # mmap memoryview needs copy to create writable tensor
293+ byte_tensor = torch .frombuffer (bytearray (tensor_bytes ), dtype = torch .uint8 )
280294
281295 if dtype_str in ["F8_E5M2" , "F8_E4M3" ]:
282296 return self ._convert_float8 (byte_tensor , dtype_str , shape )
283297
284298 return byte_tensor .view (dtype ).reshape (shape )
285299
300+
286301 @staticmethod
287302 def _get_torch_dtype (dtype_str ):
288303 dtype_map = {
0 commit comments