Skip to content

Commit 2fd6a5d

Browse files
committed
fix: revert layer offload iteration
1 parent 3b8057c commit 2fd6a5d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

OmniGen/transformer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ def prefetch_layer(self, layer_idx: int, device: torch.device):
3333
"Starts prefetching the next layer cache"
3434
with torch.cuda.stream(self.prefetch_stream):
3535
# Prefetch next layer tensors to GPU
36-
self.layers[layer_idx] = self.layers[layer_idx].to(device, non_blocking=True)
36+
for name, param in self.layers[layer_idx].named_parameters():
37+
param.data = param.data.to(device, non_blocking=True)
3738

3839
def evict_previous_layer(self, layer_idx: int):
3940
"Moves the previous layer cache to the CPU"
4041
prev_layer_idx = layer_idx - 1
41-
self.layers[prev_layer_idx] = self.layers[prev_layer_idx].to("cpu")
42-
42+
for name, param in self.layers[prev_layer_idx].named_parameters():
43+
param.data = param.data.to("cpu")
44+
4345
def get_offload_layer(self, layer_idx: int, device: torch.device):
4446
# init stream
4547
if not hasattr(self, "prefetch_stream"):

0 commit comments

Comments
 (0)