Skip to content

Commit 81aeefa

Browse files
author
The TensorFlow Datasets Authors
committed
Internal change
PiperOrigin-RevId: 861923747
1 parent 78e38e6 commit 81aeefa

File tree

10 files changed

+431
-347
lines changed

10 files changed

+431
-347
lines changed

tensorflow_datasets/datasets/robonet/robonet_dataset_builder.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -165,28 +165,29 @@ def _build_pcollection(self, pipeline, filedir):
165165
"""Generate examples as dicts."""
166166
beam = tfds.core.lazy_imports.apache_beam
167167

168-
def _process_example(filename):
169-
"""Converts one video from hdf5 format."""
170-
h5py = tfds.core.lazy_imports.h5py
171-
with h5py.File(filename) as hf:
172-
video_bytes = hf['env']['cam0_video']['frames'][:].tobytes()
173-
states = hf['env']['state'][:].astype(np.float32)
174-
states = np.pad(
175-
states, ((0, 0), (0, STATES_DIM - states.shape[1])), 'constant'
176-
)
177-
actions = hf['policy']['actions'][:].astype(np.float32)
178-
actions = np.pad(
179-
actions, ((0, 0), (0, ACTIONS_DIM - actions.shape[1])), 'constant'
180-
)
181-
182-
basename = os.path.basename(filename)
183-
features = {
184-
'video': video_bytes,
185-
'actions': actions,
186-
'states': states,
187-
'filename': basename,
188-
}
189-
return basename, features
190-
191168
filenames = tf.io.gfile.glob(os.path.join(filedir, '*.hdf5'))
192169
return pipeline | beam.Create(filenames) | beam.Map(_process_example)
170+
171+
172+
def _process_example(filename):
  """Converts one video from hdf5 format."""
  h5py = tfds.core.lazy_imports.h5py
  with h5py.File(filename) as hf:
    video_bytes = hf['env']['cam0_video']['frames'][:].tobytes()
    # Pad state/action vectors with trailing zeros up to the fixed widths
    # expected by the dataset features.
    raw_states = hf['env']['state'][:].astype(np.float32)
    states = np.pad(
        raw_states,
        ((0, 0), (0, STATES_DIM - raw_states.shape[1])),
        'constant',
    )
    raw_actions = hf['policy']['actions'][:].astype(np.float32)
    actions = np.pad(
        raw_actions,
        ((0, 0), (0, ACTIONS_DIM - raw_actions.shape[1])),
        'constant',
    )

  # The file basename doubles as the example key and the `filename` feature.
  basename = os.path.basename(filename)
  return basename, {
      'video': video_bytes,
      'actions': actions,
      'states': states,
      'filename': basename,
  }

tensorflow_datasets/rl_unplugged/rlu_rwrl/rlu_rwrl.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,31 @@ def tf_feature_to_tfds_feature(
245245
raise ValueError(f'Unsupported type {type(nested)}')
246246

247247

248+
def _generate_examples_one_file_fn(
    path,
    feature_description,
    tf_example_to_step_ds_fn,
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
  """Yields examples from one file."""
  key_prefix = os.path.basename(path)
  # Dataset of tf.Examples containing full episodes.
  example_ds = tf.data.TFRecordDataset(filenames=str(path))
  # Dataset of episodes, each represented as a dataset of steps.
  episode_ds = example_ds.map(
      functools.partial(
          tf_example_to_step_ds_fn,
          feature_description=feature_description,
      ),
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  # Keys are '<file basename>/<episode index within the file>'.
  for episode_id, episode in enumerate(tfds.as_numpy(episode_ds)):
    yield f'{key_prefix}/{episode_id}', episode
271+
272+
248273
class RluRwrl(rlu_common.RLUBuilder):
249274
"""DatasetBuilder for rlu_rwrl dataset."""
250275

@@ -368,26 +393,8 @@ def _generate_examples(self, paths):
368393

369394
feature_description = tf_example_to_feature_description(example_item)
370395

371-
def _generate_examples_one_file(
372-
path,
373-
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
374-
"""Yields examples from one file."""
375-
counter = 0
376-
key_prefix = os.path.basename(path)
377-
# Dataset of tf.Examples containing full episodes.
378-
example_ds = tf.data.TFRecordDataset(filenames=str(path))
379-
# Dataset of episodes, each represented as a dataset of steps.
380-
episode_ds = example_ds.map(
381-
functools.partial(
382-
self.tf_example_to_step_ds,
383-
feature_description=feature_description,
384-
),
385-
num_parallel_calls=tf.data.experimental.AUTOTUNE,
386-
)
387-
episode_ds = tfds.as_numpy(episode_ds)
388-
for e in episode_ds:
389-
episode_id = counter
390-
yield f'{key_prefix}/{episode_id}', e
391-
counter += 1
392-
393-
return beam.Create(file_paths) | beam.FlatMap(_generate_examples_one_file)
396+
return beam.Create(file_paths) | beam.FlatMap(
397+
_generate_examples_one_file_fn,
398+
feature_description=feature_description,
399+
tf_example_to_step_ds_fn=self.tf_example_to_step_ds,
400+
)

tensorflow_datasets/robotics/dataset_importer_builder.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import abc
21+
import functools
2122
import os
2223
from typing import Any
2324

@@ -32,6 +33,24 @@
3233

3334

3435

36+
def _dataset_importer_converter_fn(example, decode_fn, keys_to_strip):
  """Beam converter function for DatasetImporterBuilder."""
  # Decode the RLDS Episode and transform it to numpy.
  example_out = dict(example)
  decoded_steps = tf.data.Dataset.from_tensor_slices(
      example_out['steps']
  ).map(decode_fn)
  # Materialize the (finite) steps dataset as an in-memory list.
  example_out['steps'] = list(iter(decoded_steps.take(-1)))
  example_out = dataset_utils.as_numpy(example_out)
  # The tfds_id becomes the example key and is removed from the payload,
  # along with any other keys the importer is configured to strip.
  example_id = example_out.pop('tfds_id').decode('utf-8')
  for key in keys_to_strip:
    example_out.pop(key, None)
  yield example_id, example_out
52+
53+
3554
class DatasetImporterBuilder(
3655
tfds.core.GeneratorBasedBuilder, skip_registration=True
3756
):
@@ -118,24 +137,11 @@ def _generate_examples(
118137

119138
decode_fn = builder.info.features['steps'].feature.decode_example
120139

121-
def converter_fn(example):
122-
# Decode the RLDS Episode and transform it to numpy.
123-
example_out = dict(example)
124-
example_out['steps'] = tf.data.Dataset.from_tensor_slices(
125-
example_out['steps']
126-
).map(decode_fn)
127-
steps = list(iter(example_out['steps'].take(-1)))
128-
example_out['steps'] = steps
129-
130-
example_out = dataset_utils.as_numpy(example_out)
131-
132-
example_id = example_out['tfds_id'].decode('utf-8')
133-
del example_out['tfds_id']
134-
for key in self.KEYS_TO_STRIP:
135-
if key in example_out:
136-
del example_out[key]
137-
138-
yield example_id, example_out
140+
converter_fn = functools.partial(
141+
_dataset_importer_converter_fn,
142+
decode_fn=decode_fn,
143+
keys_to_strip=self.KEYS_TO_STRIP,
144+
)
139145

140146
return f'read_tfds_dataset@{split}' >> beam_utils.ReadFromTFDS(
141147
builder=builder,

tensorflow_datasets/structured/covid19/covid19.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
response, weather, and more.
2121
"""
2222

23+
import functools
2324
import numpy as np
2425
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
2526
import tensorflow_datasets.public_api as tfds
@@ -48,6 +49,29 @@
4849
_BATCH_SIZE = 10000
4950

5051

52+
def _cast_according_to_column(feature_type, v):
  """Casts `v` to match its column's feature type.

  Numeric values destined for a `tf.string` column are converted to their
  string representation; every other value passes through unchanged.
  """
  needs_str = feature_type == tf.string and isinstance(v, (float, int))
  return str(v) if needs_str else v
56+
57+
58+
def _load_shard(index: int, dl_manager, archive_path, columns, features):
  """Load a shard of the dataset."""
  pd = tfds.core.lazy_imports.pandas
  # There is only one file so by using the for we guarantee that the file
  # will be closed.
  for _, file in dl_manager.iter_archive(archive_path):
    # Read only this shard's window of rows from the CSV.
    df = pd.read_csv(file, skiprows=index, nrows=_BATCH_SIZE)
    examples = []
    for offset, row in df.iterrows():
      example = {
          column: _cast_according_to_column(features[column].dtype, value)
          for column, value in zip(columns, row.values)
      }
      # Keys are the absolute row index within the full CSV.
      examples.append((index + offset, example))
    return examples
73+
74+
5175
class Covid19(tfds.core.GeneratorBasedBuilder):
5276
"""DatasetBuilder for covid19 dataset."""
5377

@@ -787,31 +811,18 @@ def _generate_examples(
787811
pd = tfds.core.lazy_imports.pandas
788812
beam = tfds.core.lazy_imports.apache_beam
789813

790-
def cast_according_to_column(feature_type, v):
791-
if feature_type == tf.string and isinstance(v, (float, int)):
792-
return str(v)
793-
return v
794-
795814
file_handles = dl_manager.iter_archive(archive_path)
796815
_, file = next(file_handles)
797816

798817
columns = pd.read_csv(file, nrows=1).columns
799-
800-
def load_shard(index: int):
801-
# There is only one file so by using the for we guarantee that the file
802-
# will be closed.
803-
for _, file in dl_manager.iter_archive(archive_path):
804-
df = pd.read_csv(file, skiprows=index, nrows=_BATCH_SIZE)
805-
features = self.info.features
806-
result = []
807-
for i, row in df.iterrows():
808-
example = {
809-
k: cast_according_to_column(features[k].dtype, v)
810-
for k, v in zip(columns, row.values)
811-
}
812-
result.append((index + i, example))
813-
return result
818+
features = self.info.features
814819

815820
return beam.Create(list(range(0, _N_RECORDS, _BATCH_SIZE))) | beam.FlatMap(
816-
load_shard
821+
functools.partial(
822+
_load_shard,
823+
dl_manager=dl_manager,
824+
archive_path=archive_path,
825+
columns=columns,
826+
features=features,
827+
)
817828
)

tensorflow_datasets/structured/web_graph/web_graph.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,24 @@
8282
"""
8383

8484

85+
def _get_int_feature(example: tf.train.Example, feature_name: str) -> List[int]:
  """Returns the int64 list stored under `feature_name` in `example`."""
  feature = example.features.feature[feature_name]
  return feature.int64_list.value
87+
88+
89+
def _process_example(example: bytes, is_test=False):
  """Process a single example."""
  parsed = tf.train.Example.FromString(example)
  row_tag = _get_int_feature(parsed, 'row_tag')[0]
  col_tag = np.array(_get_int_feature(parsed, 'col_tag'), dtype=np.int64)
  # Ground-truth tags only exist in the test split; other splits get an
  # empty array so the feature shape stays consistent.
  gt_values = _get_int_feature(parsed, 'gt_tag') if is_test else []
  gt_tag = np.array(gt_values, dtype=np.int64)
  return row_tag, {'row_tag': row_tag, 'col_tag': col_tag, 'gt_tag': gt_tag}
101+
102+
85103
@dataclasses.dataclass
86104
class WebGraphConfig(tfds.core.BuilderConfig):
87105
"""Palmer Penguins dataset builder config."""
@@ -225,23 +243,6 @@ def _generate_examples(self, pipeline, files, split: str):
225243
"""Yields examples."""
226244
beam = tfds.core.lazy_imports.apache_beam
227245

228-
def _get_int_feature(
229-
example: tf.train.Example, feature_name: str
230-
) -> List[int]:
231-
return example.features.feature[feature_name].int64_list.value
232-
233-
def _process_example(example: bytes, is_test=False):
234-
example = tf.train.Example.FromString(example)
235-
row_tag = _get_int_feature(example, 'row_tag')[0]
236-
col_tag = np.array(_get_int_feature(example, 'col_tag'), dtype=np.int64)
237-
if is_test:
238-
gt_tag = _get_int_feature(example, 'gt_tag')
239-
else:
240-
gt_tag = []
241-
gt_tag = np.array(gt_tag, dtype=np.int64)
242-
return_dict = {'row_tag': row_tag, 'col_tag': col_tag, 'gt_tag': gt_tag}
243-
return row_tag, return_dict
244-
245246
return (
246247
pipeline
247248
| f'{split}_create' >> beam.Create(files)

tensorflow_datasets/text/c4.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,28 @@
349349
]
350350

351351

352+
def _download_wet_file(path, dl_dir):
  """Download WET file if it doesn't already exist."""
  url = f"{_DOWNLOAD_HOST}/{path}"
  out_path = epath.Path(dl_dir) / path
  if out_path.exists():
    c4_utils.get_counter_inc_fn("download_wet_url")("exists")
    return out_path
  # Download into a uniquely named scratch directory so a failed or
  # concurrent download never leaves a partial file at `out_path`.
  tmp_dir = epath.Path(f"{os.fspath(out_path)}.incomplete{uuid.uuid4().hex}")
  try:
    tmp_dir.mkdir(parents=True, exist_ok=True)
    downloader = tfds.download.download_manager.get_downloader()
    with downloader.tqdm():
      # TODO(slebedev): Investigate why pytype infers Promise[Future[...]].
      dl_path = downloader.download(url, tmp_dir).get().path  # type: ignore
    epath.Path(dl_path).rename(out_path)
  finally:
    # Always remove the scratch directory, even when the download fails.
    tmp_dir.rmtree(missing_ok=True)
  c4_utils.get_counter_inc_fn("download_wet_url")("downloaded")
  return out_path
372+
373+
352374
class C4Config(tfds.core.BuilderConfig):
353375
"""BuilderConfig for C4 dataset."""
354376

@@ -605,30 +627,6 @@ def _get_pages_pcollection(self, pipeline, file_paths, dl_manager):
605627
"""Build PCollection of un-split page content."""
606628
beam = tfds.core.lazy_imports.apache_beam
607629

608-
def download_wet_file(path, dl_dir):
609-
url = f"{_DOWNLOAD_HOST}/{path}"
610-
out_path = epath.Path(dl_dir) / path
611-
612-
if out_path.exists():
613-
c4_utils.get_counter_inc_fn("download_wet_url")("exists")
614-
return out_path
615-
616-
tmp_dir = epath.Path(
617-
f"{os.fspath(out_path)}.incomplete{uuid.uuid4().hex}"
618-
)
619-
try:
620-
tmp_dir.mkdir(parents=True, exist_ok=True)
621-
downloader = tfds.download.download_manager.get_downloader()
622-
with downloader.tqdm():
623-
# TODO(slebedev): Investigate why pytype infers Promise[Future[...]].
624-
dl_path = downloader.download(url, tmp_dir).get().path # type: ignore
625-
dl_path = epath.Path(dl_path)
626-
dl_path.rename(out_path)
627-
finally:
628-
tmp_dir.rmtree(missing_ok=True)
629-
c4_utils.get_counter_inc_fn("download_wet_url")("downloaded")
630-
return out_path
631-
632630
wet_file_paths = (
633631
pipeline
634632
| "create_wet_path_urls" >> beam.Create(file_paths["wet_path_urls"])
@@ -640,7 +638,7 @@ def download_wet_file(path, dl_dir):
640638
| "filter_corrupt_wet_files"
641639
>> beam.Filter(lambda p: p not in _KNOWN_CORRUPT_WET_FILES)
642640
| beam.Map(
643-
download_wet_file,
641+
_download_wet_file,
644642
dl_dir=os.path.join(dl_manager.download_dir, "c4_wet_files"),
645643
)
646644
)

0 commit comments

Comments (0)