Skip to content

Commit 9aee939

Browse files
Fix duplicate safetensors.load_file call in _onload_from_disk when stream is None (#13851)
Fix duplicate safetensors.load_file call in _onload_from_disk when stream is None.

Signed-off-by: Gagan Dhakrey <gagandhakrey@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 6dbf6e0 commit 9aee939

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

src/diffusers/hooks/group_offloading.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -259,17 +259,16 @@ def _onload_from_disk(self):
259259
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
260260

261261
with context:
262-
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
263-
device = str(self.onload_device) if self.stream is None else "cpu"
264-
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
265-
266262
if self.stream is not None:
263+
# Load to CPU first, pin memory, then async copy to the target device
264+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
267265
for key, tensor_obj in self.key_to_tensor.items():
268266
pinned_tensor = loaded_tensors[key].pin_memory()
269267
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
270268
if self.record_stream:
271269
tensor_obj.data.record_stream(current_stream)
272270
else:
271+
# Load directly to the target device
273272
onload_device = (
274273
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
275274
)

0 commit comments

Comments
 (0)