Skip to content

Commit af62827

Browse files
committed
dev
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent e93dc4e commit af62827

1 file changed

Lines changed: 215 additions & 0 deletions

File tree

Lines changed: 215 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,215 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Add Llama-Nemotron-VLM-Dataset-v1 conversations to a conversation dataset (VLM format)."""
17+
18+
import argparse
19+
import re
20+
from pathlib import Path
21+
22+
from datasets import load_dataset
23+
from huggingface_hub import snapshot_download
24+
from tqdm import tqdm
25+
from utils import (
26+
dataset_splits_explanation,
27+
id_for_conversation,
28+
update_dataset_file_with_conversations,
29+
)
30+
31+
# Split names offered by the dataset: two captioning shards, ten OCR shards and
# nine VQA shards, listed in that order.
AVAILABLE_SPLITS = [
    *(f"captioning_{n}" for n in range(1, 3)),
    *(f"ocr_{n}" for n in range(1, 11)),
    *(f"vqa_{n}" for n in range(1, 10)),
]

# Repository id passed to both `snapshot_download` and `load_dataset`.
DATASET_REPO = "nvidia/Llama-Nemotron-VLM-Dataset-v1"
57+
58+
59+
def parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        Namespace with ``dataset_split`` (str, default ``"ocr_1"``),
        ``output_split_name`` (str or ``None``) and ``output_dir`` (``Path``,
        default ``input_conversations/``).
    """
    split_help = (
        "Split of the Llama-Nemotron-VLM-Dataset-v1 to load. Default is 'ocr_1'.\n"
        f"Available splits: {', '.join(AVAILABLE_SPLITS)}"
    )
    output_split_help = (
        dataset_splits_explanation("llama-nemotron-vlm-v1-<dataset_split>")
        + "\nIf not provided, defaults to 'llama-nemotron-vlm-v1-<dataset_split>'."
    )

    cli = argparse.ArgumentParser(
        description="Load Llama-Nemotron-VLM-Dataset-v1 conversations in VLM format."
    )
    cli.add_argument("--dataset-split", type=str, default="ocr_1", help=split_help)
    cli.add_argument("--output-split-name", type=str, default=None, help=output_split_help)
    cli.add_argument(
        "--output-dir",
        type=Path,
        default=Path("input_conversations/"),
        help="Path to save conversations and images. Default is 'input_conversations/'.",
    )
    return cli.parse_args()
89+
90+
91+
def download_images(dataset_split: str, image_dir: Path) -> None:
    """Download the image folder of one dataset split and unpack its tar archives.

    Files matching ``<dataset_split>_images/*`` are fetched from ``DATASET_REPO``
    with ``snapshot_download`` into ``image_dir``. If the target folder already
    exists and is non-empty, nothing is downloaded. Every ``*.tar`` file found
    directly in the target folder afterwards is extracted in place and then
    deleted. An extraction failure is printed and does not abort the run.

    Args:
        dataset_split: Name of the dataset split whose images are wanted.
        image_dir: Directory that will contain the ``<dataset_split>_images``
            folder; created (with parents) if missing.
    """
    import tarfile

    images_folder = f"{dataset_split}_images"
    target_dir = image_dir / images_folder

    if target_dir.exists() and any(target_dir.iterdir()):
        print(f"Images already exist at {target_dir}, skipping download.")
        return

    print(f"Downloading images for {dataset_split}...")
    image_dir.mkdir(parents=True, exist_ok=True)

    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir=str(image_dir),
        allow_patterns=[f"{images_folder}/*"],
    )

    # The archives come from the network, so use the "data" extraction filter
    # whenever the running Python provides it: it refuses members that would end
    # up outside the destination directory. Older interpreters without filter
    # support keep the previous behavior.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    # The folder may not exist if the pattern matched nothing. The archive list is
    # materialized up front because the loop extracts into, and deletes from, the
    # very directory being scanned.
    tar_paths = list(target_dir.glob("*.tar")) if target_dir.exists() else []
    for tar_path in tar_paths:
        print(f"Found tar archive: {tar_path}, extracting...")
        try:
            with tarfile.open(tar_path, "r:*") as tar:
                tar.extractall(path=target_dir, **extract_kwargs)
            print(f"Extracted {tar_path}")
            # Remove the archive only after a successful extraction.
            tar_path.unlink()
        except Exception as e:
            # Deliberately best-effort: report the problem and continue.
            print(f"Error extracting {tar_path}: {e}")
    print(f"Downloaded images to {target_dir}")
122+
123+
124+
def parse_content_for_vlm(text: str, image_filename: str | None) -> list[dict]:
125+
"""
126+
Parse text content to VLM format, preserving original <image> placeholder in text.
127+
128+
Returns a list of content parts:
129+
- {"type": "image", "image": "<filename>"} for each <image> placeholder
130+
- {"type": "text", "text": "<original_text>"} with <image> preserved
131+
"""
132+
content_parts = []
133+
134+
# Add image entries for each <image> placeholder
135+
num_images = len(re.findall(r"<image>", text, flags=re.IGNORECASE))
136+
if image_filename:
137+
content_parts += [{"type": "image", "image": image_filename} for _ in range(num_images)]
138+
139+
# Add the original text with <image> preserved
140+
content_parts.append({"type": "text", "text": text.replace("<image>", "")})
141+
142+
return content_parts
143+
144+
145+
# Role names used by the source data mapped to the names written to the output.
_ROLE_ALIASES = {"human": "user", "gpt": "assistant"}


def _convert_message(msg: dict, image_filename: str | None) -> dict | None:
    """Convert one raw message to ``{"role", "content"}``, or ``None`` to skip it."""
    # `or` chaining instead of nested .get() defaults: a key that is present but
    # None/empty must fall through to the next candidate rather than crash on
    # .lower().
    role = (msg.get("from") or msg.get("role") or "").lower()
    if not role:
        return None
    role = _ROLE_ALIASES.get(role, role)

    # First candidate key holding a non-None value; a None value must not be
    # stringified into the literal text "None".
    raw_content = next(
        (msg[key] for key in ("value", "content", "text") if msg.get(key) is not None),
        "",
    )
    raw_content = raw_content.strip() if isinstance(raw_content, str) else str(raw_content)
    if not raw_content:
        return None

    if "<image>" in raw_content.lower():
        content = parse_content_for_vlm(raw_content, image_filename)
    else:
        content = [{"type": "text", "text": raw_content}]

    return {"role": role, "content": content} if content else None


async def main(args: argparse.Namespace) -> None:
    """Download a split's images, convert its conversations and hand them to the dataset writer.

    Each dataset entry's messages are converted to ``{"role", "content"}`` form.
    Entries that yield at least one message are collected under a generated
    conversation id and passed to ``update_dataset_file_with_conversations``.

    Declared ``async`` although nothing is awaited; the signature is kept because
    the script entry point runs it through ``asyncio.run``.

    Args:
        args: Parsed command-line arguments. Reads ``dataset_split``,
            ``output_dir`` and ``output_split_name``; the latter is set on
            ``args`` to ``llama-nemotron-vlm-v1-<dataset_split>`` when ``None``.
    """
    if args.output_split_name is None:
        args.output_split_name = f"llama-nemotron-vlm-v1-{args.dataset_split}"

    # Images live in an "images" folder inside the output directory and are
    # fetched before the conversations are loaded.
    image_dir = args.output_dir / "images"
    download_images(args.dataset_split, image_dir)

    ds = load_dataset(
        DATASET_REPO,
        split=args.dataset_split,
        streaming=False,
        verification_mode="no_checks",
    )

    input_conversations = []
    progress = tqdm(enumerate(ds), desc=f"Loading split {args.dataset_split}", total=len(ds))
    for i, entry in progress:
        conversations = entry.get("conversations", [])
        if not conversations or not isinstance(conversations, list):
            continue

        image_filename = entry.get("image", None)
        converted = (_convert_message(msg, image_filename) for msg in conversations)
        processed_conversations = [message for message in converted if message is not None]
        if not processed_conversations:
            continue

        # Id layout: <prefix>-<split>-<zero-padded index>[_<entry id>]_<conversation hash>
        prompt_id = f"llama-nemotron-vlm-v1-{args.dataset_split}-{i:06}"
        entry_id = entry.get("id", "")
        if entry_id:
            prompt_id = f"{prompt_id}_{entry_id}"
        prompt_id = f"{prompt_id}_" + id_for_conversation(processed_conversations)
        input_conversations.append(
            {"conversation_id": prompt_id, "conversations": processed_conversations}
        )

    print(f"Loaded {len(input_conversations)} conversations from split {args.dataset_split}.")
    update_dataset_file_with_conversations(
        input_conversations, args.output_dir, args.output_split_name
    )
209+
210+
211+
if __name__ == "__main__":
    # Script entry point: parse the command line and drive the async main().
    import asyncio

    asyncio.run(main(parse_args()))

0 commit comments

Comments
 (0)