Skip to content

Commit 0d4af87

Browse files
Merge pull request #3789 from AI-Hypercomputer:eval_data_shard
PiperOrigin-RevId: 908922221
2 parents bb4ae4f + 8caa3fc commit 0d4af87

1 file changed

Lines changed: 327 additions & 0 deletions

File tree

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Script to reshard a TFDS dataset into a specific number of shards.
17+
This is useful when the number of hosts for dataloading is larger than the number of shards in the dataset.
18+
19+
Example (num_workers, buffer_size, dataset_name are optional):
20+
21+
For Single Split Dataset:
22+
23+
python3 tools/data_generation/reshard_tfds.py \
24+
--src_dir gs://your-bucket/origin_folder \
25+
--dst_dir gs://your-bucket/new_folder \
26+
--num_shards 2048 \
27+
--split train \
28+
--num_workers 16 \
29+
--buffer_size 33554432
30+
31+
For Multiple Splits Dataset:
32+
33+
python3 tools/data_generation/reshard_tfds.py \
34+
--src_dir gs://your-bucket/origin_folder \
35+
--dst_dir gs://your-bucket/new_folder \
36+
--num_shards 2048 \
37+
--split train,validation \
38+
--num_workers 16 \
39+
--buffer_size 33554432
40+
41+
"""
42+
43+
import argparse
44+
import os
45+
import json
46+
import multiprocessing
47+
import queue
48+
import threading
49+
50+
import tensorflow as tf
51+
import tensorflow_datasets as tfds
52+
from tqdm import tqdm
53+
54+
55+
def parse_args(argv=None):
    """Parses command-line arguments for the resharding script.

    Args:
      argv: Optional list of argument strings to parse. When None (the default),
        argparse reads the arguments of the running process.

    Returns:
      An argparse.Namespace with the attributes src_dir, dst_dir, num_shards, split,
      dataset_name, num_workers and buffer_size.

    Raises:
      SystemExit: Raised by argparse (exit status 2) for malformed arguments, including
        a --num_shards or --num_workers value that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Reshard a TFDS dataset.")
    parser.add_argument("--src_dir", type=str, required=True, help="Source TFDS directory (e.g., gs://bucket/c4/en/3.0.1)")
    parser.add_argument("--dst_dir", type=str, required=True, help="Destination directory")
    parser.add_argument("--num_shards", type=int, default=2048, help="Number of shards for the output (default: 2048)")
    parser.add_argument("--split", type=str, default="train", help="Split(s) to reshard, comma-separated (default: train)")
    parser.add_argument(
        "--dataset_name", type=str, default=None, help="Optional dataset name. If not set, inferred from metadata."
    )
    parser.add_argument("--num_workers", type=int, default=16, help="Optional number of workers (default: 16)")
    parser.add_argument(
        "--buffer_size",
        type=int,
        default=32 * 1024 * 1024,
        help="Optional buffer size in bytes for TFRecordDataset (default: 32MB)",
    )
    args = parser.parse_args(argv)

    # Both counts end up as modulo divisors in the resharding code, so a zero or
    # negative value must be rejected here with a clear usage error.
    for flag in ("num_shards", "num_workers"):
        value = getattr(args, flag)
        if value < 1:
            parser.error(f"--{flag} must be a positive integer, got {value}")

    return args
73+
74+
75+
def get_shard_path(dst_dir, dataset_name, split, shard_index, total_shards):
    """Returns the destination path of one output shard, using the TFDS file naming.

    The file name has the form `<dataset_name>-<split>.tfrecord-XXXXX-of-YYYYY`, where
    the shard index and the shard total are zero-padded to at least five digits.
    """
    name_template = "{}-{}.tfrecord-{:05d}-of-{:05d}"
    file_name = name_template.format(dataset_name, split, shard_index, total_shards)
    return os.path.join(dst_dir, file_name)
79+
80+
81+
def reshard_raw_bytes_worker(
    worker_id, num_workers, src_files, dst_dir, dataset_name, split, total_shards, buffer_size, progress_queue
):
    """Copies this worker's slice of the source records into the target shards.

    Records are read as raw TFRecord bytes (no parsing). The worker keeps the slice of
    the dataset selected by `ds.shard(num_workers, worker_id)` and assigns each record
    to a target shard round-robin by its global index. When `total_shards` is a
    multiple of `num_workers`, every target shard index is congruent to `worker_id`
    modulo `num_workers`, so no two workers write the same shard file.

    Args:
      worker_id: Index of this worker in [0, num_workers).
      num_workers: Total number of workers splitting the dataset.
      src_files: Source TFRecord file paths passed to tf.data.TFRecordDataset.
      dst_dir: Directory that receives the output shard files.
      dataset_name: Dataset name used as the output file name prefix.
      split: Split name used in the output file names.
      total_shards: Number of output shards across all workers.
      buffer_size: Read buffer size in bytes for TFRecordDataset.
      progress_queue: Queue that receives the number of records processed, in
        batches of 1000 plus one final remainder.

    Returns:
      A dict mapping each target shard index this worker wrote to the number of
      records written to it. Shards that received no record are absent.
    """
    # Active writers and the number of records written to each, keyed by shard index
    writers = {}
    shard_lengths = {}

    def get_writer(shard_idx):
        """Helper to lazily initialize a TFRecordWriter for a given target shard."""
        if shard_idx not in writers:
            path = get_shard_path(dst_dir, dataset_name, split, shard_idx, total_shards)
            writers[shard_idx] = tf.io.TFRecordWriter(path)
            shard_lengths[shard_idx] = 0
        return writers[shard_idx]

    # Read raw bytes from the source TFRecord files.
    # A large buffer size (default 32MB) is used to improve I/O throughput, especially on GCS.
    ds = tf.data.TFRecordDataset(src_files, compression_type=None, buffer_size=buffer_size)

    # Shard the dataset so this worker only processes its designated portion of the data
    ds = ds.shard(num_workers, worker_id)

    try:
        # i stays at -1 for an empty slice so the remainder computed below is 0
        i = -1
        for i, record_bytes in enumerate(ds):
            # Global index of this record among all records processed by all workers
            i_global = i * num_workers + worker_id

            # Round-robin distribution over the target shards
            target_shard_idx = i_global % total_shards

            writer = get_writer(target_shard_idx)
            writer.write(record_bytes.numpy())
            shard_lengths[target_shard_idx] += 1

            # Send progress update every 1000 records
            if (i + 1) % 1000 == 0:
                progress_queue.put(1000)

        # Send any remaining progress
        remainder = (i + 1) % 1000
        if remainder > 0:
            progress_queue.put(remainder)
    finally:
        # Close every writer even when reading or writing failed, so that open
        # files are released and already-written records are flushed.
        for writer in writers.values():
            writer.close()

    return shard_lengths
136+
137+
138+
def progress_listener(q, total_examples):
    """Feeds progress updates from `q` into one tqdm bar until "DONE" is received.

    Args:
      q: Queue carrying record counts to add to the bar, followed by the string "DONE".
      total_examples: Total passed to tqdm; the caller may pass None when it is unknown.
    """
    bar = tqdm(total=total_examples, desc="Resharding Progress", unit=" records", unit_scale=True)
    finished = False
    while not finished:
        try:
            # Poll with a short timeout rather than blocking indefinitely
            message = q.get(timeout=0.1)
        except queue.Empty:
            # Nothing arrived within the poll window; keep waiting
            continue
        if message == "DONE":
            finished = True
        else:
            bar.update(message)
    bar.close()
151+
152+
153+
def _split_num_examples(info_json, split_name):
    """Returns the example count recorded for `split_name` in the metadata, or 0 if absent."""
    # The "splits" entry is handled both as a list of dicts and as a dict keyed by name
    splits_meta = info_json.get("splits", {})
    if isinstance(splits_meta, list):
        split_item = next((s for s in splits_meta if s.get("name") == split_name), None)
    else:
        split_item = splits_meta.get(split_name)
    if not split_item:
        return 0
    return int(split_item.get("numExamples", split_item.get("num_examples", 0)))


def _find_split_files(src_dir, dataset_name, split_name):
    """Returns the sorted source TFRecord files of a split, trying both naming patterns."""
    for prefix in (f"{dataset_name}-{split_name}", split_name):
        src_files = sorted(tf.io.gfile.glob(os.path.join(src_dir, f"{prefix}.tfrecord*")))
        if src_files:
            return src_files
    raise FileNotFoundError(f"Could not find TFRecord files for split '{split_name}' in {src_dir}")


def _record_split_metadata(info_json, split_name, num_shards, shard_lengths, total_count):
    """Stores the new shard count, shard lengths and example count of a split in `info_json`."""
    # All values are stored as strings, as in the original metadata update
    entry = {
        "shardLengths": [str(n) for n in shard_lengths],
        "numShards": str(num_shards),
    }
    splits_meta = info_json.setdefault("splits", {})

    if isinstance(splits_meta, list):
        split_item = next((s for s in splits_meta if s.get("name") == split_name), None)
        if split_item is None:
            splits_meta.append({"name": split_name, **entry, "numExamples": str(total_count)})
        else:
            split_item.update(entry)
            split_item["numExamples"] = str(total_count)
        return

    if split_name not in splits_meta:
        splits_meta[split_name] = {**entry, "numExamples": str(total_count)}
        return

    split_item = splits_meta[split_name]
    split_item.update(entry)
    # Keep whichever spelling of the example-count key the entry already uses
    count_key = "numExamples" if "numExamples" in split_item else "num_examples"
    split_item[count_key] = str(total_count)


def main():
    """Main execution flow for reading metadata, sharding data, and updating dataset info.

    Reads `dataset_info.json` from the source directory, reshards every requested split
    with a pool of worker processes, then writes an updated `dataset_info.json` (and a
    copy of `features.json`, if present) to the destination directory.

    Raises:
      ValueError: If the requested shard or worker count is not positive.
      FileNotFoundError: If the source metadata file or a split's TFRecord files are missing.
    """
    args = parse_args()

    # Both counts are used as modulo divisors below and in the workers
    if args.num_shards < 1 or args.num_workers < 1:
        raise ValueError(
            f"num_shards and num_workers must be positive, got {args.num_shards} and {args.num_workers}"
        )

    # Create destination directory if it doesn't exist
    if not tf.io.gfile.exists(args.dst_dir):
        tf.io.gfile.makedirs(args.dst_dir)

    target_splits = [s.strip() for s in args.split.split(",") if s.strip()]

    # Load source metadata once
    print(f"Loading metadata from {args.src_dir}...")
    info_path = os.path.join(args.src_dir, "dataset_info.json")
    if not tf.io.gfile.exists(info_path):
        raise FileNotFoundError(f"Required metadata file not found: {info_path}")

    with tf.io.gfile.GFile(info_path, "r") as f:
        info_json = json.load(f)

    dataset_name = args.dataset_name or info_json.get("name")
    if not dataset_name:
        try:
            # Attempt to determine the dataset name using the TFDS standard builder
            builder = tfds.builder_from_directory(args.src_dir)
            dataset_name = builder.name
        except Exception as e:  # pylint: disable=broad-exception-caught
            print(f"Warning: Could not load metadata via tfds.builder_from_directory: {e}")
            print("Warning: Dataset name could not be determined, and output filenames will use 'unknown'.")
            dataset_name = "unknown"

    # Use the largest worker count not above the requested one that divides the
    # number of shards: with round-robin assignment this gives every worker a
    # disjoint set of target shards. The value is the same for every split.
    num_workers = next(n for n in range(args.num_workers, 0, -1) if args.num_shards % n == 0)
    if num_workers != args.num_workers:
        print(f"Adjusted num_workers to {num_workers} to be a factor of {args.num_shards}")

    # Use a multiprocessing Manager to share a queue between workers and the main process
    with multiprocessing.Manager() as manager:
        for split_name in target_splits:
            print(f"\n--- Processing split: {split_name} ---")
            num_examples = _split_num_examples(info_json, split_name)
            src_files = _find_split_files(args.src_dir, dataset_name, split_name)

            print(f"Found {len(src_files)} source files for split '{split_name}' ({num_examples} examples).")
            print(f"Resharding into {args.num_shards} shards using {num_workers} workers...")

            progress_queue = manager.Queue()

            # Start the listener thread in the background to consume progress updates
            listener_thread = threading.Thread(
                target=progress_listener, args=(progress_queue, num_examples if num_examples > 0 else None)
            )
            listener_thread.start()

            tasks = [
                (
                    worker_id,
                    num_workers,
                    src_files,
                    args.dst_dir,
                    dataset_name,
                    split_name,
                    args.num_shards,
                    args.buffer_size,
                    progress_queue,
                )
                for worker_id in range(num_workers)
            ]

            try:
                with multiprocessing.Pool(processes=num_workers) as pool:
                    results = pool.starmap(reshard_raw_bytes_worker, tasks)
            finally:
                # Always stop and join the listener, including when a worker raised;
                # otherwise the thread keeps polling a queue whose manager is shutting down.
                progress_queue.put("DONE")
                listener_thread.join()

            # Aggregate the results (shard lengths) from all workers
            all_shard_lengths = {}
            for worker_lengths in results:
                all_shard_lengths.update(worker_lengths)

            # Verify the total number of examples processed matches the original metadata
            total_count = sum(all_shard_lengths.values())
            print(f"Successfully resharded {total_count} examples for '{split_name}'.")
            if num_examples > 0 and total_count != num_examples:
                print(f"Warning: Total examples {total_count} does not match original {num_examples} for split '{split_name}'.")

            # Update the shard count and lengths in the JSON metadata for this split;
            # shards that received no record are reported with length 0
            shard_lengths_list = [all_shard_lengths.get(i, 0) for i in range(args.num_shards)]
            _record_split_metadata(info_json, split_name, args.num_shards, shard_lengths_list, total_count)

    # Create and save updated dataset_info.json for the new dataset
    print("\nCreating new dataset_info.json...")
    dst_info_path = os.path.join(args.dst_dir, "dataset_info.json")
    with tf.io.gfile.GFile(dst_info_path, "w") as f:
        json.dump(info_json, f, indent=4)

    # Copy features.json if it exists (necessary for some TFDS versions/formats)
    features_path = os.path.join(args.src_dir, "features.json")
    if tf.io.gfile.exists(features_path):
        tf.io.gfile.copy(features_path, os.path.join(args.dst_dir, "features.json"), overwrite=True)

    print(f"Done! Resharded dataset available at {args.dst_dir}")
324+
325+
326+
# Run the resharding flow only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)