Skip to content

Commit e3337a0

Browse files
committed
Create a script to cache the processed dataset
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent e8d9b0e commit e3337a0

2 files changed

Lines changed: 148 additions & 5 deletions

File tree

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Script to pre-generate processed video dataset for Qwen3-Omni quantization."""
18+
19+
import argparse
20+
import os
21+
22+
import torch
23+
from transformers import AutoProcessor
24+
25+
from modelopt.torch.utils.video_dataset_utils import (
26+
Qwen3OmniVideoProcessor,
27+
get_video_dataset_dataloader,
28+
)
29+
30+
31+
def main():
    """Pre-generate and cache a processed video dataset for Qwen3-Omni quantization.

    Parses the command-line options, constructs a ``Qwen3OmniVideoProcessor``
    around the Hugging Face processor for the chosen model, then runs the
    requested number of samples through ``get_video_dataset_dataloader``,
    which persists the processed dataset into ``--cache-dir``. Temporary
    files created during video processing are cleaned up afterwards.
    """
    arg_parser = argparse.ArgumentParser(description="Generate processed video dataset cache")
    arg_parser.add_argument(
        "--model-name",
        type=str,
        default="Qwen/Qwen3-Omni-30B-A3B-Thinking",
        help="Model name or path for loading the processor",
    )
    arg_parser.add_argument(
        "--dataset-name",
        type=str,
        default="finevideo",
        help="Name of the video dataset to process",
    )
    arg_parser.add_argument(
        "--num-samples",
        type=int,
        default=512,
        help="Number of samples to process",
    )
    arg_parser.add_argument(
        "--cache-dir",
        type=str,
        required=True,
        help="Directory to save the processed dataset cache",
    )
    arg_parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
        help="Data type for processing",
    )
    arg_parser.add_argument(
        "--no-audio",
        action="store_true",
        help="Disable audio extraction from videos",
    )
    cli = arg_parser.parse_args()

    extract_audio = not cli.no_audio

    # Resolve the torch dtype from its name; choices above guarantee the
    # attribute exists (torch.float16 / torch.bfloat16 / torch.float32).
    torch_dtype = getattr(torch, cli.dtype)

    print(f"Loading processor from {cli.model_name}...")
    hf_processor = AutoProcessor.from_pretrained(cli.model_name, trust_remote_code=True)

    print(f"Creating Qwen3OmniVideoProcessor (use_audio={extract_audio}, dtype={cli.dtype})...")
    video_processor = Qwen3OmniVideoProcessor(
        tokenizer=hf_processor,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch_dtype,
        use_audio_in_video=extract_audio,
    )

    print(f"Processing {cli.num_samples} samples from {cli.dataset_name}...")
    print(f"Cache directory: {cli.cache_dir}")

    # The dataloader itself is discarded: calling the helper with cache_dir
    # set is what triggers processing and writes the dataset to disk.
    _ = get_video_dataset_dataloader(
        dataset_name=cli.dataset_name,
        processor=video_processor,
        batch_size=1,
        num_samples=cli.num_samples,
        cache_dir=cli.cache_dir,
    )

    # Remove temporary video/audio files created while preprocessing.
    video_processor.cleanup()

    # Mirrors the cache-path naming convention used inside
    # get_video_dataset_dataloader — presumably kept in sync; verify there.
    cache_path = os.path.join(cli.cache_dir, f"{cli.dataset_name}_n{cli.num_samples}_processed")
    print(f"\nDone! Processed dataset saved to: {cache_path}")


if __name__ == "__main__":
    main()

modelopt/torch/utils/video_dataset_utils.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def get_video_dataset_dataloader(
9191
processor: "Qwen3OmniVideoProcessor" = None,
9292
batch_size: int = 1,
9393
num_samples: int = 512,
94+
cache_dir: str | None = None,
9495
) -> DataLoader:
9596
"""Get a dataloader with the dataset name and processor of the target model.
9697
@@ -99,17 +100,47 @@ def get_video_dataset_dataloader(
99100
processor: Processor used for encoding video and text data.
100101
batch_size: Batch size of the returned dataloader.
101102
num_samples: Number of samples from the dataset.
103+
cache_dir: Directory to cache the processed dataset. Defaults to a temp directory.
104+
If the cache exists, it will be loaded instead of reprocessing.
102105
103106
Returns:
104107
An instance of dataloader.
105108
"""
106109
assert processor is not None, "Please provide a valid processor."
107110

108-
dataset = _get_video_dataset(dataset_name, num_samples=num_samples)
109-
# Apply the preprocessing function to the dataset
110-
processed_dataset = dataset.map(
111-
processor.preprocess_function, batched=False, remove_columns=dataset.column_names
112-
)
111+
# Default cache_dir to temp directory
112+
if cache_dir is None:
113+
cache_dir = os.path.join(tempfile.gettempdir(), "modelopt_video_dataset_cache")
114+
115+
processed_dataset = None
116+
117+
# Try to load from cache
118+
if cache_dir is not None:
119+
from datasets import load_from_disk
120+
121+
cache_path = os.path.join(cache_dir, f"{dataset_name}_n{num_samples}_processed")
122+
if os.path.exists(cache_path):
123+
try:
124+
processed_dataset = load_from_disk(cache_path)
125+
print(f"Loaded processed dataset from cache: {cache_path}")
126+
except Exception as e:
127+
print(f"Failed to load cache from {cache_path}: {e}. Reprocessing...")
128+
processed_dataset = None
129+
130+
# Process dataset if not loaded from cache
131+
if processed_dataset is None:
132+
dataset = _get_video_dataset(dataset_name, num_samples=num_samples)
133+
# Apply the preprocessing function to the dataset
134+
processed_dataset = dataset.map(
135+
processor.preprocess_function, batched=False, remove_columns=dataset.column_names
136+
)
137+
138+
# Save to cache if cache_dir is provided
139+
if cache_dir is not None:
140+
os.makedirs(cache_dir, exist_ok=True)
141+
# Use num_shards=1 to avoid off-by-one sharding bug with complex nested structures
142+
processed_dataset.save_to_disk(cache_path, num_shards=1)
143+
print(f"Saved processed dataset to cache: {cache_path}")
113144

114145
# Create DataLoader with the custom collate function
115146
return DataLoader(

0 commit comments

Comments
 (0)