Skip to content

Commit c1abf50

Browse files
author
The TensorFlow Datasets Authors
committed
TFDS: Refactor WIT (Wikipedia Image Text) builder.
Moves nested CSV reading functions to the top-level to support Cloudpickle serialization. PiperOrigin-RevId: 874648641
1 parent 7552998 commit c1abf50

1 file changed

Lines changed: 131 additions & 109 deletions

File tree

tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py

Lines changed: 131 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,117 @@
6565
_BEAM_NAMESPACE = "TFDS_WIT_KAGGLE"
6666

6767

68+
def _get_csv_reader(filename, *, counter):
  """Opens `filename` (plain or gzip-compressed) and returns a TSV csv.reader.

  Args:
    filename: path-like object (must expose `.suffix`) pointing at a
      tab-separated file; a ".gz" suffix triggers on-the-fly gzip
      decompression.
    counter: Beam metrics counter factory (name -> counter), used to count how
      many compressed vs. plain files were opened.

  Returns:
    A `csv.reader` yielding one list of string fields per input line.
  """
  if filename.suffix == ".gz":
    counter("gz_csv_files").inc()
    g = tf.io.gfile.GFile(filename, "rb")
    # Wrap the binary handle so gzip decompresses lazily; newline="" lets the
    # csv module do its own newline handling, as its docs require.
    f = gzip.open(g, "rt", newline="")
  else:
    counter("normal_csv_files").inc()
    f = tf.io.gfile.GFile(filename, "r")
  # Raise the process-global csv field size limit so huge fields (e.g.
  # base64-encoded images) parse. Value must be smaller than the C long
  # maximum value, which sys.maxsize satisfies on 64-bit platforms.
  csv.field_size_limit(sys.maxsize)
  # NOTE(review): the underlying file handle is intentionally left open — the
  # returned reader streams from it; it is closed when garbage-collected.
  return csv.reader(f, delimiter="\t")
79+
80+
81+
def _read_pixel_rows(filename, *, counter):
  r"""Yields `[image_url, (image_pixels, metadata_url)]` pairs from a TSV file.

  Each input line contains image_url \t image_pixel \t metadata_url. Rows
  with an empty image_url are counted and skipped.
  """
  for row in _get_csv_reader(filename, counter=counter):
    counter("pixel_rows").inc()
    image_url, image_representation, metadata_url = row
    if not image_url:
      counter("pixel_rows_no_image_url").inc()
      continue
    yield [image_url, (image_representation, metadata_url)]
91+
92+
93+
def _read_resnet_rows(filename, *, counter):
  r"""Yields `[image_url, resnet_embedding]` pairs from a TSV file.

  Each input line contains image_url \t resnet_embedding. Rows with an empty
  image_url are counted and skipped.
  """
  for row in _get_csv_reader(filename, counter=counter):
    counter("resnet_rows").inc()
    image_url, image_representation = row
    if not image_url:
      counter("resnet_rows_no_image_url").inc()
      continue
    yield [image_url, image_representation]
103+
104+
105+
def _read_samples_rows(folder_path, *, builder_config, counter):
  """Yields `[image_url, sample]` pairs from every TSV file in a folder.

  Train and test splits carry different fields; the subset of columns kept
  per row is taken from `builder_config.split_specific_features`. Rows with
  an empty image_url are counted and skipped.

  Args:
    folder_path: path-like directory containing sample TSV files.
    builder_config: config exposing `split_specific_features` (mapping of
      feature name -> feature spec).
    counter: Beam metrics counter factory (name -> counter).

  Yields:
    `[image_url, sample]` pairs, where `sample` maps each configured feature
    key to the row's raw string value.
  """
  # Raise the process-global csv field size limit once (it is not per-file)
  # so very large fields parse. Value must be smaller than the C long
  # maximum value.
  csv.field_size_limit(sys.maxsize)
  for filename in tf.io.gfile.listdir(folder_path):
    file_path = folder_path / filename
    # Context manager guarantees each handle is closed after the file is
    # consumed (the previous version leaked one handle per file).
    with tf.io.gfile.GFile(file_path, "r") as f:
      csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
      for row in csv_reader:
        counter("samples_rows").inc()
        sample = {
            feature_key: row[feature_key]
            for feature_key in builder_config.split_specific_features.keys()
        }
        image_url = row["image_url"]
        if image_url:
          yield [image_url, sample]
        else:
          counter("samples_rows_no_image_url").inc()
124+
125+
126+
def _process_examples(el, *, builder_config, counter):
127+
"""Process examples."""
128+
sample_url, sample_fields = el
129+
# Each image_url can be associated with multiple samples (e.g., multiple
130+
# languages).
131+
for i, sample_info in enumerate(sample_fields["sample_info"]):
132+
sample_id = f"{i}_{sample_url}"
133+
sample = {"image_url": sample_url}
134+
for feature_key in builder_config.split_specific_features.keys():
135+
sample[feature_key] = sample_info[feature_key]
136+
is_boolean_feature = (
137+
builder_config.split_specific_features[feature_key].np_dtype
138+
== np.bool_
139+
)
140+
if is_boolean_feature:
141+
sample[feature_key] = bool_utils.parse_bool(sample[feature_key])
142+
# Test samples don't have gold captions.
143+
if "caption_title_and_reference_description" not in sample_info:
144+
sample["caption_title_and_reference_description"] = ""
145+
146+
# We output image data only if there is at least one image
147+
# representation per image_url.
148+
# Not all of the samples in the competition have corresponding image
149+
# data. In case multiple different image representations are associated
150+
# with the same image_url, we don't know which one is correct and don't
151+
# output any.
152+
if len(set(sample_fields["image_pixels"])) == 1:
153+
sample_image, sample_metadata = sample_fields["image_pixels"][0]
154+
sample["image"] = io.BytesIO(base64.b64decode(sample_image))
155+
sample["metadata_url"] = sample_metadata
156+
else:
157+
if len(set(sample_fields["image_pixels"])) > 1:
158+
counter("image_pixels_multiple").inc()
159+
else:
160+
counter("image_pixels_missing").inc()
161+
sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES))
162+
sample["metadata_url"] = ""
163+
164+
if len(set(sample_fields["image_resnet"])) == 1:
165+
image_resnet = [
166+
float(x) for x in sample_fields["image_resnet"][0].split(",")
167+
]
168+
sample["embedding"] = image_resnet
169+
else:
170+
if len(set(sample_fields["image_resnet"])) > 1:
171+
counter("image_resnet_multiple").inc()
172+
else:
173+
counter("image_resnet_missing").inc()
174+
sample["embedding"] = builder_config.empty_resnet_embedding
175+
176+
yield sample_id, sample
177+
178+
68179
class WitKaggleConfig(tfds.core.BuilderConfig):
69180
"""BuilderConfig for WitKaggle."""
70181

@@ -285,119 +396,15 @@ def _generate_examples(
285396
beam = tfds.core.lazy_imports.apache_beam
286397
counter = functools.partial(beam.metrics.Metrics.counter, _BEAM_NAMESPACE)
287398

288-
def _get_csv_reader(filename):
289-
if filename.suffix == ".gz":
290-
counter("gz_csv_files").inc()
291-
g = tf.io.gfile.GFile(filename, "rb")
292-
f = gzip.open(g, "rt", newline="")
293-
else:
294-
counter("normal_csv_files").inc()
295-
f = tf.io.gfile.GFile(filename, "r")
296-
# Limit to 100 MB. Value must be smaller than the C long maximum value.
297-
csv.field_size_limit(sys.maxsize)
298-
return csv.reader(f, delimiter="\t")
299-
300-
def _read_pixel_rows(filename):
301-
r"""Contains image_url \t image_pixel \t metadata_url."""
302-
reader = _get_csv_reader(filename)
303-
for row in reader:
304-
counter("pixel_rows").inc()
305-
image_url, image_representation, metadata_url = row
306-
if image_url:
307-
yield [image_url, (image_representation, metadata_url)]
308-
else:
309-
counter("pixel_rows_no_image_url").inc()
310-
311-
def _read_resnet_rows(filename):
312-
r"""Contains image_url \t resnet_embedding."""
313-
reader = _get_csv_reader(filename)
314-
for row in reader:
315-
counter("resnet_rows").inc()
316-
image_url, image_representation = row
317-
if image_url:
318-
yield [image_url, image_representation]
319-
else:
320-
counter("resnet_rows_no_image_url").inc()
321-
322-
def _read_samples_rows(folder_path):
323-
"""Contains samples: train and test have different fields."""
324-
for filename in tf.io.gfile.listdir(folder_path):
325-
file_path = folder_path / filename
326-
f = tf.io.gfile.GFile(file_path, "r")
327-
# Limit to 100 MB. Value must be smaller than the C long maximum value.
328-
csv.field_size_limit(sys.maxsize)
329-
csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
330-
for row in csv_reader:
331-
counter("samples_rows").inc()
332-
sample = {
333-
feature_key: row[feature_key]
334-
for feature_key in self.builder_config.split_specific_features.keys()
335-
}
336-
image_url = row["image_url"]
337-
if image_url:
338-
yield [image_url, sample]
339-
else:
340-
counter("samples_rows_no_image_url").inc()
341-
342-
def _process_examples(el):
343-
sample_url, sample_fields = el
344-
# Each image_url can be associated with multiple samples (e.g., multiple
345-
# languages).
346-
for i, sample_info in enumerate(sample_fields["sample_info"]):
347-
sample_id = f"{i}_{sample_url}"
348-
sample = {"image_url": sample_url}
349-
for feature_key in self.builder_config.split_specific_features.keys():
350-
sample[feature_key] = sample_info[feature_key]
351-
is_boolean_feature = (
352-
self.builder_config.split_specific_features[feature_key].np_dtype
353-
== np.bool_
354-
)
355-
if is_boolean_feature:
356-
sample[feature_key] = bool_utils.parse_bool(sample[feature_key])
357-
# Test samples don't have gold captions.
358-
if "caption_title_and_reference_description" not in sample_info:
359-
sample["caption_title_and_reference_description"] = ""
360-
361-
# We output image data only if there is at least one image
362-
# representation per image_url.
363-
# Not all of the samples in the competition have corresponding image
364-
# data. In case multiple different image representations are associated
365-
# with the same image_url, we don't know which one is correct and don't
366-
# output any.
367-
if len(set(sample_fields["image_pixels"])) == 1:
368-
sample_image, sample_metadata = sample_fields["image_pixels"][0]
369-
sample["image"] = io.BytesIO(base64.b64decode(sample_image))
370-
sample["metadata_url"] = sample_metadata
371-
else:
372-
if len(set(sample_fields["image_pixels"])) > 1:
373-
counter("image_pixels_multiple").inc()
374-
else:
375-
counter("image_pixels_missing").inc()
376-
sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES))
377-
sample["metadata_url"] = ""
378-
379-
if len(set(sample_fields["image_resnet"])) == 1:
380-
image_resnet = [
381-
float(x) for x in sample_fields["image_resnet"][0].split(",")
382-
]
383-
sample["embedding"] = image_resnet
384-
else:
385-
if len(set(sample_fields["image_resnet"])) > 1:
386-
counter("image_resnet_multiple").inc()
387-
else:
388-
counter("image_resnet_missing").inc()
389-
sample["embedding"] = self.builder_config.empty_resnet_embedding
390-
391-
yield sample_id, sample
392-
393399
# Read embeddings and bytes representations from (possibly compressed) csv.
394400
image_resnet_files = [
395401
image_resnet_path / f for f in tf.io.gfile.listdir(image_resnet_path)
396402
]
397403
resnet_collection = (
398404
pipeline
399405
| "Collection from resnet files" >> beam.Create(image_resnet_files)
400-
| "Get embeddings per image" >> beam.FlatMap(_read_resnet_rows)
406+
| "Get embeddings per image"
407+
>> beam.FlatMap(functools.partial(_read_resnet_rows, counter=counter))
401408
)
402409

403410
image_pixel_files = [
@@ -406,14 +413,22 @@ def _process_examples(el):
406413
pixel_collection = (
407414
pipeline
408415
| "Collection from pixel files" >> beam.Create(image_pixel_files)
409-
| "Get pixels per image" >> beam.FlatMap(_read_pixel_rows)
416+
| "Get pixels per image"
417+
>> beam.FlatMap(functools.partial(_read_pixel_rows, counter=counter))
410418
)
411419

412420
# Read samples from tsv files.
413421
sample_collection = (
414422
pipeline
415423
| "Collection from sample files" >> beam.Create(samples_path)
416-
| "Get samples" >> beam.FlatMap(_read_samples_rows)
424+
| "Get samples"
425+
>> beam.FlatMap(
426+
functools.partial(
427+
_read_samples_rows,
428+
builder_config=self.builder_config,
429+
counter=counter,
430+
)
431+
)
417432
)
418433

419434
# Combine the features and yield examples.
@@ -425,5 +440,12 @@ def _process_examples(el):
425440
}
426441
| "Group by image_url" >> beam.CoGroupByKey()
427442
| "Reshuffle" >> beam.Reshuffle()
428-
| "Process and yield examples" >> beam.FlatMap(_process_examples)
443+
| "Process and yield examples"
444+
>> beam.FlatMap(
445+
functools.partial(
446+
_process_examples,
447+
builder_config=self.builder_config,
448+
counter=counter,
449+
)
450+
)
429451
)

0 commit comments

Comments
 (0)