Skip to content

Commit d41dc35

Browse files
committed
Add video dataset utils
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 02fb601 commit d41dc35

File tree

1 file changed

+332
-0
lines changed

1 file changed

+332
-0
lines changed
Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Utility functions for getting samples and forward loop function for video datasets."""
17+
18+
import os
19+
import tempfile
20+
from typing import Any
21+
22+
import torch
23+
from torch.utils.data import DataLoader
24+
25+
from .image_processor import BaseImageProcessor
26+
27+
# Registry of supported video datasets: maps a short dataset name to the
# keyword arguments passed to ``datasets.load_dataset``.
SUPPORTED_VIDEO_DATASET_CONFIG: dict[str, dict[str, Any]] = {
    "finevideo": {
        "config": {
            "path": "HuggingFaceFV/finevideo",
            "split": "train",
            "streaming": True,
        }
    },
}
33+
34+
# Public API of this module (names exported on star-import).
__all__ = [
    "Qwen3OmniVideoProcessor",
    "get_supported_video_datasets",
    "get_video_dataset_dataloader",
]
39+
40+
41+
def _get_video_dataset(dataset_name: str, num_samples: int):
    """Fetch the first ``num_samples`` train-split samples of a supported dataset.

    Args:
        dataset_name: Name of the dataset to load.
        num_samples: Number of samples to load from the dataset.

    Returns:
        A Hugging Face Dataset.

    Raises:
        NotImplementedError: If ``dataset_name`` is not in the supported registry.
    """
    # Guard clause: fail fast on unknown dataset names.
    if dataset_name not in SUPPORTED_VIDEO_DATASET_CONFIG:
        raise NotImplementedError(
            f"dataset {dataset_name} is not supported. Please use one of the following:"
            f" {get_supported_video_datasets()}."
        )

    from datasets import Dataset, load_dataset

    config = SUPPORTED_VIDEO_DATASET_CONFIG[dataset_name]["config"]
    dataset = load_dataset(**config)

    if config.get("streaming", False):
        # Streaming datasets cannot be indexed; materialize the prefix via take().
        return Dataset.from_list(list(dataset.take(num_samples)))
    return dataset.select(range(num_samples))
70+
71+
72+
def get_supported_video_datasets() -> list[str]:
    """Return the names of all video datasets this module supports.

    Returns:
        A list of strings, one per supported dataset.

    Example usage:

    .. code-block:: python

        from modelopt.torch.utils import get_supported_video_datasets

        print("Supported video datasets:", get_supported_video_datasets())
    """
    # Unpacking a dict yields its keys in insertion order.
    return [*SUPPORTED_VIDEO_DATASET_CONFIG]
87+
88+
89+
def get_video_dataset_dataloader(
    dataset_name: str = "finevideo",
    processor: "Qwen3OmniVideoProcessor | None" = None,
    batch_size: int = 1,
    num_samples: int = 512,
    cache_dir: str | None = None,
) -> DataLoader:
    """Get a dataloader with the dataset name and processor of the target model.

    Args:
        dataset_name: Name of the dataset to load.
        processor: Processor used for encoding video and text data. Required.
        batch_size: Batch size of the returned dataloader.
        num_samples: Number of samples from the dataset.
        cache_dir: Directory to cache the processed dataset. Defaults to a temp directory.
            If the cache exists, it will be loaded instead of reprocessing.

    Returns:
        An instance of dataloader.
    """
    assert processor is not None, "Please provide a valid processor."

    # Default cache_dir to a stable location inside the system temp directory.
    # After this point cache_dir is never None, so no further None checks are
    # needed (the original had dead `if cache_dir is not None` branches that
    # also left `cache_path` potentially unbound).
    if cache_dir is None:
        cache_dir = os.path.join(tempfile.gettempdir(), "modelopt_video_dataset_cache")
    cache_path = os.path.join(cache_dir, f"{dataset_name}_n{num_samples}_processed.pt")

    # Try to load from cache (use torch.save/load to avoid Arrow 32-bit offset overflow).
    processed_dataset = _load_processed_dataset_cache(cache_path)

    # Process dataset from scratch if not loaded from cache.
    if processed_dataset is None:
        processed_dataset = _process_and_cache_video_dataset(
            dataset_name, processor, num_samples, cache_dir, cache_path
        )

    # Create DataLoader with the processor's custom collate function.
    return DataLoader(
        processed_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=processor.collate_function,
    )


def _load_processed_dataset_cache(cache_path: str):
    """Return the cached processed dataset at ``cache_path``, or None if absent/unreadable."""
    if not os.path.exists(cache_path):
        return None
    try:
        from datasets import Dataset

        # torch.load instead of Arrow to avoid 32-bit offset overflow on large samples.
        processed_samples = torch.load(cache_path, weights_only=False)
        dataset = Dataset.from_list(processed_samples)
        print(f"Loaded processed dataset from cache: {cache_path}")
        return dataset
    except Exception as e:
        # Best-effort cache: fall back to reprocessing on any load failure.
        print(f"Failed to load cache from {cache_path}: {e}. Reprocessing...")
        return None


def _process_and_cache_video_dataset(
    dataset_name: str, processor, num_samples: int, cache_dir: str, cache_path: str
):
    """Download, preprocess, and cache ``num_samples`` samples of ``dataset_name``."""
    from datasets import Dataset

    dataset = _get_video_dataset(dataset_name, num_samples=num_samples)

    # Process samples manually to avoid Arrow 32-bit offset overflow
    # (dataset.map() uses Arrow internally which can't handle large nested lists).
    total = len(dataset)  # hoisted: loop-invariant
    processed_samples = []
    for i, sample in enumerate(dataset):
        processed_samples.append(processor.preprocess_function(sample))
        if (i + 1) % 10 == 0:
            print(f"Processed {i + 1}/{total} samples...")

    processed_dataset = Dataset.from_list(processed_samples)

    # Save to cache using torch.save to avoid Arrow 32-bit offset overflow.
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(processed_samples, cache_path)
    print(f"Saved processed dataset to cache: {cache_path}")

    return processed_dataset
161+
162+
163+
class Qwen3OmniVideoProcessor(BaseImageProcessor):
    """Video processor for Qwen3-Omni multimodal model with finevideo dataset support."""

    def __init__(self, tokenizer, device="cuda", dtype=None, use_audio_in_video=True):
        """Constructor.

        Args:
            tokenizer: The Qwen3OmniMoeProcessor for tokenizing and processing inputs.
            device: Device to move tensors to.
            dtype: dtype for float tensors (e.g., torch.bfloat16). If None, uses default.
            use_audio_in_video: Whether to extract and use audio from video files.

        Raises:
            ImportError: If the optional ``qwen_omni_utils`` package is not installed.
        """
        super().__init__(tokenizer, device)
        self.dtype = dtype
        self.use_audio_in_video = use_audio_in_video
        # Scratch directory for videos decoded from raw bytes; removed by cleanup().
        self._temp_dir = tempfile.mkdtemp(prefix="qwen3omni_video_")
        self._video_counter = 0
        # qwen_omni_utils is an external helper needed for multimodal processing.
        try:
            from qwen_omni_utils import process_mm_info

            self.process_mm_info = process_mm_info
        except ImportError as e:
            # Chain the original error so the root cause stays visible in the traceback.
            raise ImportError(
                "qwen_omni_utils is required for Qwen3OmniVideoProcessor. "
                "Please install it from https://github.com/QwenLM/Qwen3-Omni"
            ) from e

    def _save_video_bytes_to_file(self, video_bytes: bytes) -> str:
        """Save video bytes to a temporary file and return the path.

        Args:
            video_bytes: Raw video bytes (e.g., from finevideo's 'mp4' field).

        Returns:
            Path to the temporary video file.
        """
        video_path = os.path.join(self._temp_dir, f"video_{self._video_counter}.mp4")
        self._video_counter += 1
        with open(video_path, "wb") as f:
            f.write(video_bytes)
        return video_path

    def preprocess_function(self, examples):
        """Preprocess function for Qwen3-Omni with video support.

        Handles both standard video paths and raw video bytes (finevideo format).
        Returns a dict with a fixed key set (missing modalities stay None);
        tensors are converted to lists for Arrow serialization and converted
        back to tensors in ``collate_function``.
        """
        # Get question/prompt - finevideo has metadata in 'json' field
        if "json" in examples and examples["json"] is not None:
            metadata = examples["json"]
            # Try to get a meaningful question from metadata
            category = metadata.get("content_fine_category", "")
            question = (
                f"Describe what is happening in this video in detail. Category hint: {category}"
            )
        else:
            question = examples.get("question", "Describe this video in detail.")

        # Build conversation in Qwen format
        content = []

        # Handle video - check for raw bytes (finevideo format) or path
        video_path = None
        if examples.get("mp4") is not None:
            # finevideo format: raw video bytes in 'mp4' field
            video_path = self._save_video_bytes_to_file(examples["mp4"])
        elif examples.get("video") is not None:
            # Standard format: video path or URL
            video_path = examples["video"]

        if video_path is not None:
            content.append({"type": "video", "video": video_path})

        content.append({"type": "text", "text": question})

        conversation = [{"role": "user", "content": content}]
        text = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
        )

        # Extract multimodal info using qwen_omni_utils
        audios, images, videos = self.process_mm_info(
            conversation, use_audio_in_video=self.use_audio_in_video
        )

        # Process inputs with the processor
        values = self.tokenizer(
            text=text,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=self.use_audio_in_video,
        )
        # Define all possible keys to ensure consistent schema for Arrow serialization
        all_keys = [
            "input_ids",
            "attention_mask",
            "pixel_values_videos",
            "video_grid_thw",
            "video_second_per_grid",
            "feature_attention_mask",
            "input_features",
        ]

        # Convert tensors to lists for Arrow serialization compatibility
        # Tensor conversion back happens in collate_function
        result = dict.fromkeys(all_keys)  # Initialize all keys to None
        for key, val in values.items():
            if val is not None and hasattr(val, "tolist"):
                result[key] = val.tolist()
            elif val is not None:
                result[key] = val

        return result

    def _to_long_tensor(self, data):
        """Build a LongTensor from a nested list and move it to the target device."""
        return torch.LongTensor(data).to(self.device)

    def _to_float_tensor(self, data):
        """Build a tensor, cast to ``self.dtype`` if set, and move to the target device."""
        tensor = torch.tensor(data)
        if self.dtype is not None:
            tensor = tensor.to(self.dtype)
        return tensor.to(self.device)

    def collate_function(self, batch):
        """Collate function to process inputs during data loading."""
        result = {}

        # Take first item from batch (batch_size handling)
        first = batch[0]

        # Integer token ids and attention mask
        if first.get("input_ids") is not None:
            result["input_ids"] = self._to_long_tensor(first["input_ids"])
        if first.get("attention_mask") is not None:
            result["attention_mask"] = self._to_long_tensor(first["attention_mask"])

        # Pixel values for video frames (float; honors self.dtype)
        if first.get("pixel_values_videos") is not None:
            result["pixel_values_videos"] = self._to_float_tensor(first["pixel_values_videos"])

        # Video grid thw (tile height width info)
        if first.get("video_grid_thw") is not None:
            result["video_grid_thw"] = self._to_long_tensor(first["video_grid_thw"])

        # Video second per grid (temporal info for rope); deliberately NOT cast
        # to self.dtype — matches the original behavior of using default dtype.
        if first.get("video_second_per_grid") is not None:
            result["video_second_per_grid"] = torch.tensor(first["video_second_per_grid"]).to(
                self.device
            )

        # Audio features if present
        if first.get("feature_attention_mask") is not None:
            result["feature_attention_mask"] = self._to_long_tensor(first["feature_attention_mask"])
        if first.get("input_features") is not None:
            result["input_features"] = self._to_float_tensor(first["input_features"])

        # Pass use_audio_in_video flag to model.generate() for Qwen3Omni
        result["use_audio_in_video"] = self.use_audio_in_video

        return result

    def cleanup(self):
        """Clean up temporary video files."""
        import shutil

        if os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir)

0 commit comments

Comments
 (0)