Skip to content

Commit 826bab5

Browse files
authored
Merge pull request #859 from krahets/main
Fix batch decoding for Wan-Video-VAE
2 parents 5b6d112 + 419d47c commit 826bab5

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

diffsynth/models/wan_video_vae.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,6 @@ def single_decode(self, hidden_state, device):
12161216

12171217

12181218
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
1219-
12201219
videos = [video.to("cpu") for video in videos]
12211220
hidden_states = []
12221221
for video in videos:
@@ -1234,11 +1233,18 @@ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(1
12341233

12351234

12361235
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
1237-
if tiled:
1238-
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
1239-
else:
1240-
video = self.single_decode(hidden_states, device)
1241-
return video
1236+
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
1237+
videos = []
1238+
for hidden_state in hidden_states:
1239+
hidden_state = hidden_state.unsqueeze(0)
1240+
if tiled:
1241+
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
1242+
else:
1243+
video = self.single_decode(hidden_state, device)
1244+
video = video.squeeze(0)
1245+
videos.append(video)
1246+
videos = torch.stack(videos)
1247+
return videos
12421248

12431249

12441250
@staticmethod

0 commit comments

Comments
 (0)