# Namespace under which all Beam metrics (counters) of this module are reported.
_BEAM_NAMESPACE = "TFDS_WIT_KAGGLE"

def _get_csv_reader(filename, *, counter):
  """Returns a csv.reader over `filename`, transparently handling gzip.

  Args:
    filename: Path-like object (must expose `.suffix`) pointing to a
      tab-separated file, optionally gzip-compressed (".gz").
    counter: Callable mapping a metric name to a Beam counter with `.inc()`.

  Returns:
    A `csv.reader` yielding the rows of the (decompressed) file.
  """
  if filename.suffix != ".gz":
    counter("normal_csv_files").inc()
    stream = tf.io.gfile.GFile(filename, "r")
  else:
    counter("gz_csv_files").inc()
    raw = tf.io.gfile.GFile(filename, "rb")
    stream = gzip.open(raw, "rt", newline="")
  # Rows can carry very large serialized fields; raise the parser's limit.
  # NOTE(review): the value must fit in a C long — sys.maxsize may exceed
  # that on some platforms; confirm target platforms before relying on it.
  csv.field_size_limit(sys.maxsize)
  return csv.reader(stream, delimiter="\t")
79+
80+
def _read_pixel_rows(filename, *, counter):
  r"""Yields `[image_url, (image_pixel, metadata_url)]` per usable row.

  The file contains image_url \t image_pixel \t metadata_url. Rows with an
  empty image_url are skipped and counted separately.
  """
  for row in _get_csv_reader(filename, counter=counter):
    counter("pixel_rows").inc()
    url, pixels, metadata = row
    if not url:
      counter("pixel_rows_no_image_url").inc()
      continue
    yield [url, (pixels, metadata)]
91+
92+
def _read_resnet_rows(filename, *, counter):
  r"""Yields `[image_url, resnet_embedding]` per usable row.

  The file contains image_url \t resnet_embedding. Rows with an empty
  image_url are skipped and counted separately.
  """
  for row in _get_csv_reader(filename, counter=counter):
    counter("resnet_rows").inc()
    url, embedding = row
    if not url:
      counter("resnet_rows_no_image_url").inc()
      continue
    yield [url, embedding]
103+
104+
def _read_samples_rows(folder_path, *, builder_config, counter):
  """Yields `[image_url, sample]` pairs from the tsv files in `folder_path`.

  Contains samples: train and test have different fields, so the columns to
  keep are taken from `builder_config.split_specific_features`.

  Args:
    folder_path: Path to a directory of tab-separated sample files.
    builder_config: Config whose `split_specific_features` keys select which
      columns of each row are emitted into the sample dict.
    counter: Callable mapping a metric name to a Beam counter with `.inc()`.

  Yields:
    `[image_url, sample]` lists, where `sample` maps each split-specific
    feature key to that row's value. Rows with an empty image_url are
    skipped and counted as "samples_rows_no_image_url".
  """
  # Hoisted out of the per-file loop: the limit is process-global, and the
  # feature keys do not change between files/rows.
  # Rows can carry very large serialized fields; raise the parser's limit.
  # (Value must be smaller than the C long maximum value.)
  csv.field_size_limit(sys.maxsize)
  feature_keys = builder_config.split_specific_features.keys()
  for filename in tf.io.gfile.listdir(folder_path):
    file_path = folder_path / filename
    # Context manager closes each file once consumed; the previous version
    # leaked one open handle per file in the directory.
    with tf.io.gfile.GFile(file_path, "r") as f:
      csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
      for row in csv_reader:
        counter("samples_rows").inc()
        sample = {key: row[key] for key in feature_keys}
        image_url = row["image_url"]
        if image_url:
          yield [image_url, sample]
        else:
          counter("samples_rows_no_image_url").inc()
124+
125+
126+ def _process_examples (el , * , builder_config , counter ):
127+ """Process examples."""
128+ sample_url , sample_fields = el
129+ # Each image_url can be associated with multiple samples (e.g., multiple
130+ # languages).
131+ for i , sample_info in enumerate (sample_fields ["sample_info" ]):
132+ sample_id = f"{ i } _{ sample_url } "
133+ sample = {"image_url" : sample_url }
134+ for feature_key in builder_config .split_specific_features .keys ():
135+ sample [feature_key ] = sample_info [feature_key ]
136+ is_boolean_feature = (
137+ builder_config .split_specific_features [feature_key ].np_dtype
138+ == np .bool_
139+ )
140+ if is_boolean_feature :
141+ sample [feature_key ] = bool_utils .parse_bool (sample [feature_key ])
142+ # Test samples don't have gold captions.
143+ if "caption_title_and_reference_description" not in sample_info :
144+ sample ["caption_title_and_reference_description" ] = ""
145+
146+ # We output image data only if there is at least one image
147+ # representation per image_url.
148+ # Not all of the samples in the competition have corresponding image
149+ # data. In case multiple different image representations are associated
150+ # with the same image_url, we don't know which one is correct and don't
151+ # output any.
152+ if len (set (sample_fields ["image_pixels" ])) == 1 :
153+ sample_image , sample_metadata = sample_fields ["image_pixels" ][0 ]
154+ sample ["image" ] = io .BytesIO (base64 .b64decode (sample_image ))
155+ sample ["metadata_url" ] = sample_metadata
156+ else :
157+ if len (set (sample_fields ["image_pixels" ])) > 1 :
158+ counter ("image_pixels_multiple" ).inc ()
159+ else :
160+ counter ("image_pixels_missing" ).inc ()
161+ sample ["image" ] = io .BytesIO (base64 .b64decode (_EMPTY_IMAGE_BYTES ))
162+ sample ["metadata_url" ] = ""
163+
164+ if len (set (sample_fields ["image_resnet" ])) == 1 :
165+ image_resnet = [
166+ float (x ) for x in sample_fields ["image_resnet" ][0 ].split ("," )
167+ ]
168+ sample ["embedding" ] = image_resnet
169+ else :
170+ if len (set (sample_fields ["image_resnet" ])) > 1 :
171+ counter ("image_resnet_multiple" ).inc ()
172+ else :
173+ counter ("image_resnet_missing" ).inc ()
174+ sample ["embedding" ] = builder_config .empty_resnet_embedding
175+
176+ yield sample_id , sample
177+
178+
68179class WitKaggleConfig (tfds .core .BuilderConfig ):
69180 """BuilderConfig for WitKaggle."""
70181
@@ -285,119 +396,15 @@ def _generate_examples(
285396 beam = tfds .core .lazy_imports .apache_beam
286397 counter = functools .partial (beam .metrics .Metrics .counter , _BEAM_NAMESPACE )
287398
288- def _get_csv_reader (filename ):
289- if filename .suffix == ".gz" :
290- counter ("gz_csv_files" ).inc ()
291- g = tf .io .gfile .GFile (filename , "rb" )
292- f = gzip .open (g , "rt" , newline = "" )
293- else :
294- counter ("normal_csv_files" ).inc ()
295- f = tf .io .gfile .GFile (filename , "r" )
296- # Limit to 100 MB. Value must be smaller than the C long maximum value.
297- csv .field_size_limit (sys .maxsize )
298- return csv .reader (f , delimiter = "\t " )
299-
300- def _read_pixel_rows (filename ):
301- r"""Contains image_url \t image_pixel \t metadata_url."""
302- reader = _get_csv_reader (filename )
303- for row in reader :
304- counter ("pixel_rows" ).inc ()
305- image_url , image_representation , metadata_url = row
306- if image_url :
307- yield [image_url , (image_representation , metadata_url )]
308- else :
309- counter ("pixel_rows_no_image_url" ).inc ()
310-
311- def _read_resnet_rows (filename ):
312- r"""Contains image_url \t resnet_embedding."""
313- reader = _get_csv_reader (filename )
314- for row in reader :
315- counter ("resnet_rows" ).inc ()
316- image_url , image_representation = row
317- if image_url :
318- yield [image_url , image_representation ]
319- else :
320- counter ("resnet_rows_no_image_url" ).inc ()
321-
322- def _read_samples_rows (folder_path ):
323- """Contains samples: train and test have different fields."""
324- for filename in tf .io .gfile .listdir (folder_path ):
325- file_path = folder_path / filename
326- f = tf .io .gfile .GFile (file_path , "r" )
327- # Limit to 100 MB. Value must be smaller than the C long maximum value.
328- csv .field_size_limit (sys .maxsize )
329- csv_reader = csv .DictReader (f , delimiter = "\t " , quoting = csv .QUOTE_ALL )
330- for row in csv_reader :
331- counter ("samples_rows" ).inc ()
332- sample = {
333- feature_key : row [feature_key ]
334- for feature_key in self .builder_config .split_specific_features .keys ()
335- }
336- image_url = row ["image_url" ]
337- if image_url :
338- yield [image_url , sample ]
339- else :
340- counter ("samples_rows_no_image_url" ).inc ()
341-
342- def _process_examples (el ):
343- sample_url , sample_fields = el
344- # Each image_url can be associated with multiple samples (e.g., multiple
345- # languages).
346- for i , sample_info in enumerate (sample_fields ["sample_info" ]):
347- sample_id = f"{ i } _{ sample_url } "
348- sample = {"image_url" : sample_url }
349- for feature_key in self .builder_config .split_specific_features .keys ():
350- sample [feature_key ] = sample_info [feature_key ]
351- is_boolean_feature = (
352- self .builder_config .split_specific_features [feature_key ].np_dtype
353- == np .bool_
354- )
355- if is_boolean_feature :
356- sample [feature_key ] = bool_utils .parse_bool (sample [feature_key ])
357- # Test samples don't have gold captions.
358- if "caption_title_and_reference_description" not in sample_info :
359- sample ["caption_title_and_reference_description" ] = ""
360-
361- # We output image data only if there is at least one image
362- # representation per image_url.
363- # Not all of the samples in the competition have corresponding image
364- # data. In case multiple different image representations are associated
365- # with the same image_url, we don't know which one is correct and don't
366- # output any.
367- if len (set (sample_fields ["image_pixels" ])) == 1 :
368- sample_image , sample_metadata = sample_fields ["image_pixels" ][0 ]
369- sample ["image" ] = io .BytesIO (base64 .b64decode (sample_image ))
370- sample ["metadata_url" ] = sample_metadata
371- else :
372- if len (set (sample_fields ["image_pixels" ])) > 1 :
373- counter ("image_pixels_multiple" ).inc ()
374- else :
375- counter ("image_pixels_missing" ).inc ()
376- sample ["image" ] = io .BytesIO (base64 .b64decode (_EMPTY_IMAGE_BYTES ))
377- sample ["metadata_url" ] = ""
378-
379- if len (set (sample_fields ["image_resnet" ])) == 1 :
380- image_resnet = [
381- float (x ) for x in sample_fields ["image_resnet" ][0 ].split ("," )
382- ]
383- sample ["embedding" ] = image_resnet
384- else :
385- if len (set (sample_fields ["image_resnet" ])) > 1 :
386- counter ("image_resnet_multiple" ).inc ()
387- else :
388- counter ("image_resnet_missing" ).inc ()
389- sample ["embedding" ] = self .builder_config .empty_resnet_embedding
390-
391- yield sample_id , sample
392-
393399 # Read embeddings and bytes representations from (possibly compressed) csv.
394400 image_resnet_files = [
395401 image_resnet_path / f for f in tf .io .gfile .listdir (image_resnet_path )
396402 ]
397403 resnet_collection = (
398404 pipeline
399405 | "Collection from resnet files" >> beam .Create (image_resnet_files )
400- | "Get embeddings per image" >> beam .FlatMap (_read_resnet_rows )
406+ | "Get embeddings per image"
407+ >> beam .FlatMap (functools .partial (_read_resnet_rows , counter = counter ))
401408 )
402409
403410 image_pixel_files = [
@@ -406,14 +413,22 @@ def _process_examples(el):
406413 pixel_collection = (
407414 pipeline
408415 | "Collection from pixel files" >> beam .Create (image_pixel_files )
409- | "Get pixels per image" >> beam .FlatMap (_read_pixel_rows )
416+ | "Get pixels per image"
417+ >> beam .FlatMap (functools .partial (_read_pixel_rows , counter = counter ))
410418 )
411419
412420 # Read samples from tsv files.
413421 sample_collection = (
414422 pipeline
415423 | "Collection from sample files" >> beam .Create (samples_path )
416- | "Get samples" >> beam .FlatMap (_read_samples_rows )
424+ | "Get samples"
425+ >> beam .FlatMap (
426+ functools .partial (
427+ _read_samples_rows ,
428+ builder_config = self .builder_config ,
429+ counter = counter ,
430+ )
431+ )
417432 )
418433
419434 # Combine the features and yield examples.
@@ -425,5 +440,12 @@ def _process_examples(el):
425440 }
426441 | "Group by image_url" >> beam .CoGroupByKey ()
427442 | "Reshuffle" >> beam .Reshuffle ()
428- | "Process and yield examples" >> beam .FlatMap (_process_examples )
443+ | "Process and yield examples"
444+ >> beam .FlatMap (
445+ functools .partial (
446+ _process_examples ,
447+ builder_config = self .builder_config ,
448+ counter = counter ,
449+ )
450+ )
429451 )
0 commit comments