Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
# `record_stream` only keeps the allocator from reusing the underlying block until
# the consumer stream is done with it — it is NOT a barrier that orders the copy
# before the consumer's reads. Cross-stream synchronization is provided separately
# by `_gate_default_stream_on_transfer`.
if self.record_stream:
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
Expand Down Expand Up @@ -248,6 +251,25 @@ def _check_disk_offload_torchao(self):
"setting `offload_to_disk_path`."
)

def _gate_default_stream_on_transfer(self):
"""Block the default stream on the transfer stream completing.

Without this barrier, the first op on the default stream (e.g. the first
matmul of the loaded module) can begin executing before the non-blocking
copies on the transfer stream have completed, producing pre-copy state
reads. The PyTorch streams contract assigns this synchronization to the
user; see
https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-streams.
No-op when streams aren't enabled.
"""
if self.stream is None:
return
current_default = self._torch_accelerator_module.current_stream()
if hasattr(current_default, "wait_stream"):
current_default.wait_stream(self.stream)
else:
self.stream.synchronize()

def _onload_from_disk(self):
self._check_disk_offload_torchao()

Expand Down Expand Up @@ -277,6 +299,8 @@ def _onload_from_disk(self):
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]

self._gate_default_stream_on_transfer()

def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
Expand All @@ -292,6 +316,8 @@ def _onload_from_memory(self):
else:
self._process_tensors_from_modules(None)

self._gate_default_stream_on_transfer()

def _offload_to_disk(self):
self._check_disk_offload_torchao()

Expand Down
Loading