Skip to content

Commit 5b6d112

Browse files
authored
Merge pull request #843 from wuutiing/main
add read gifs as video support
2 parents febdaf6 + f641800 commit 5b6d112

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

diffsynth/trainers/utils.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __init__(
154154
height_division_factor=16, width_division_factor=16,
155155
data_file_keys=("video",),
156156
image_file_extension=("jpg", "jpeg", "png", "webp"),
157-
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
157+
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"),
158158
repeat=1,
159159
args=None,
160160
):
@@ -259,8 +259,53 @@ def get_num_frames(self, reader):
259259
num_frames -= 1
260260
return num_frames
261261

262-
262+
def _load_gif(self, file_path):
263+
gif_img = Image.open(file_path)
264+
frame_count = 0
265+
delays, frames = [], []
266+
while True:
267+
delay = gif_img.info.get('duration', 100) # ms
268+
delays.append(delay)
269+
rgb_frame = gif_img.convert("RGB")
270+
croped_frame = self.crop_and_resize(rgb_frame, *self.get_height_width(rgb_frame))
271+
frames.append(croped_frame)
272+
frame_count += 1
273+
try:
274+
gif_img.seek(frame_count)
275+
except:
276+
break
277+
# delays canbe used to calculate framerates
278+
# i guess it is better to sample images with stable interval,
279+
# and using minimal_interval as the interval,
280+
# and framerate = 1000 / minimal_interval
281+
if any((delays[0] != i) for i in delays):
282+
minimal_interval = min([i for i in delays if i > 0])
283+
# make a ((start,end),frameid) struct
284+
start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))]
285+
_frames = []
286+
# according gemini-code-assist, make it more efficient to locate
287+
# where to sample the frame
288+
last_match = 0
289+
for i in range(sum(delays) // minimal_interval):
290+
current_time = minimal_interval * i
291+
for idx, ((start, end), frame_idx) in enumerate(start_end_idx_map[last_match:]):
292+
if start <= current_time < end:
293+
_frames.append(frames[frame_idx])
294+
last_match = idx + last_match
295+
break
296+
frames = _frames
297+
num_frames = len(frames)
298+
if num_frames > self.num_frames:
299+
num_frames = self.num_frames
300+
else:
301+
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
302+
num_frames -= 1
303+
frames = frames[:num_frames]
304+
return frames
305+
263306
def load_video(self, file_path):
307+
if file_path.lower().endswith(".gif"):
308+
return self._load_gif(file_path)
264309
reader = imageio.get_reader(file_path)
265310
num_frames = self.get_num_frames(reader)
266311
frames = []

0 commit comments

Comments
 (0)