11import json
22import os
3- from typing import Any , Dict , Tuple
3+ from typing import Any , Dict
44
5- import numpy as np
65import torch
76from PIL import Image
87
8+ from lmms_engine .datasets .codec_video_mixin import CodecVideoLoadingMixin
99from lmms_engine .datasets .iterable .vision_iterable_dataset import (
1010 VisionSFTIterableDataset ,
1111)
1414
1515
1616@register_dataset ("llava_ov2_iterable" )
17- class LlavaOv2IterableDataset (VisionSFTIterableDataset ):
17+ class LlavaOv2IterableDataset (CodecVideoLoadingMixin , VisionSFTIterableDataset ):
1818 """Iterable dataset for LLaVA-OneVision-2 with codec-stream video input.
1919
2020 Reuses ``VisionSFTIterableDataset`` plumbing but routes video loading
21- through the ``lmms_video_utils`` backend so each video produces a
22- ``CodecVideoOutput`` (canvases + patch_positions + source_pts) that the
23- downstream processor can consume directly instead of re-deriving
24- timestamps from frame index.
21+ through the ``lmms_video_utils`` backend (via ``CodecVideoLoadingMixin``)
22+ so each video produces a ``CodecVideoOutput`` (canvases + patch_positions
23+ + source_pts) that the downstream processor can consume directly instead
24+ of re-deriving timestamps from frame index.
2525 """
2626
2727 def load_from_json (self , data , data_folder = None ) -> Dict [str , torch .Tensor ]:
2828 images_list = []
29- videos = []
30- video_metadata_list = []
3129 kwargs : Dict [str , Any ] = {}
3230 messages = data ["messages" ]
3331 if isinstance (messages , str ):
3432 messages = json .loads (messages )
33+
3534 for message in messages :
3635 for content in message ["content" ]:
3736 if content ["type" ] == "image_url" :
3837 images_list .append (content ["image_url" ]["url" ])
39- elif content ["type" ] == "video_url" :
40- video_url = content ["video_url" ]
41- extra = {k : v for k , v in video_url .items () if k != "url" and v is not None }
42- frames , sample_fps , codec_output = self .load_videos (
43- video_url ["url" ],
44- data_folder = data_folder ,
45- fps = self .config .fps ,
46- video_kwargs = extra or None ,
47- )
48- videos .append (frames )
49- video_metadata_list .append (codec_output )
50- kwargs ["fps" ] = sample_fps
38+
39+ videos , video_metadata_list , sample_fps = self .collect_codec_video_inputs (messages , data_folder = data_folder )
40+ if sample_fps is not None :
41+ kwargs ["fps" ] = sample_fps
5142
5243 hf_messages = TrainUtilities .convert_open_to_hf (messages )
5344 if data_folder is not None :
@@ -58,23 +49,8 @@ def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
5849 images = None
5950 if len (videos ) == 0 :
6051 videos = None
61- video_metadata_list = None
62- if video_metadata_list is not None :
52+ else :
6353 kwargs ["video_metadata" ] = video_metadata_list
6454
6555 inputs = self .processor .process (images = images , hf_messages = hf_messages , videos = videos , ** kwargs )
6656 return inputs
67-
68- def load_videos (
69- self ,
70- video_path : str ,
71- data_folder = None ,
72- fps : int = 1 ,
73- video_kwargs = None ,
74- ) -> Tuple [np .ndarray , float , Any ]:
75- assert (
76- self .config .video_backend == "lmms_video_utils"
77- ), "LlavaOv2IterableDataset only supports lmms_video_utils backend"
78- if data_folder is not None :
79- video_path = os .path .join (data_folder , video_path )
80- return self .load_video_lmms_video_utils (video_path , fps , video_kwargs = video_kwargs )
0 commit comments