Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lmms_eval/models/chat/internvl_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def _collate(x):
images_kwargs["min_patches"] = self.min_patches
if self.max_patches is not None:
images_kwargs["max_patches"] = self.max_patches
+ if self.num_frames is not None or self.fps is not None:
+ # InternVL only applies num_frames/fps when frame sampling is explicitly enabled.
+ videos_kwargs["do_sample_frames"] = True
if self.num_frames is not None:
videos_kwargs["num_frames"] = self.num_frames
if self.fps is not None:
Expand All @@ -260,6 +263,8 @@ def _collate(x):
if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

+ if len(visuals) == 0:
+ visuals = None
if len(videos) == 0:
videos = None
inputs = self.processor(
Expand All @@ -275,7 +280,7 @@ def _collate(x):
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]

- gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+ gen_kwargs["image_sizes"] = [visual.size for visual in visuals] if visuals is not None else []
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
Expand Down
Loading