-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathiterator.py
More file actions
130 lines (110 loc) · 4.24 KB
/
iterator.py
File metadata and controls
130 lines (110 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright 2024 RecML authors <recommendations-ml@google.com>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loading and preprocessing for feeding Jax models."""
from collections.abc import Callable
import os
from typing import Any
import clu.data as clu_data
from etils import epath
import numpy as np
import tensorflow as tf
import jax
# Alias for `clu.data.DatasetIterator`, the base class of `TFDatasetIterator`.
Iterator = clu_data.DatasetIterator
class TFDatasetIterator(clu_data.DatasetIterator):
  """An iterator for TF Datasets that supports postprocessing.

  Batches are pulled from a `tf.data.Dataset` and optionally transformed by a
  user-supplied postprocessor. Every leaf that exposes `_numpy()` or `numpy()`
  is then converted to NumPy, and resulting arrays are marked read-only.
  `tf.SparseTensor`, `tf.RaggedTensor` and `np.ndarray` leaves pass through
  unchanged.

  When constructed with `checkpoint=True`, the TF iterator position can be
  saved and restored through a `tf.train.Checkpoint`.
  """

  def __init__(
      self,
      dataset: tf.data.Dataset,
      postprocessor: Callable[..., Any] | None = None,
      checkpoint: bool = False,
  ):
    """Initializes the iterator.

    Args:
      dataset: The TF Dataset to iterate over.
      postprocessor: An optional postprocessor to apply to each batch. This is
        useful for sending embedded ID features to a separate accelerator.
      checkpoint: Whether to checkpoint the iterator state. When False,
        `save` and `restore` are no-ops.
    """
    self._dataset = dataset
    self._iterator = iter(dataset)
    self._postprocessor = postprocessor
    # A batch (already postprocessed) pulled early by `element_spec`. The
    # next `__next__` call hands it out.
    self._prefetched_batch = None
    # Computed lazily and cached by the `element_spec` property.
    self._element_spec = None
    self._checkpoint = None
    if checkpoint:
      self._checkpoint = tf.train.Checkpoint(ds=self._iterator)

  def _next_postprocessed(self) -> Any:
    """Pulls one batch from the TF iterator and applies the postprocessor."""
    batch = next(self._iterator)
    if self._postprocessor is not None:
      batch = self._postprocessor(batch)
    return batch

  @staticmethod
  def _maybe_to_numpy(
      x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
  ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
    """Converts a dense tensor leaf to NumPy when it supports conversion.

    Sparse, ragged and NumPy leaves are returned as-is, as are leaves with
    neither a `_numpy` nor a `numpy` attribute.
    """
    if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
      return x
    if hasattr(x, "_numpy"):
      numpy = x._numpy()  # pylint: disable=protected-access
    elif hasattr(x, "numpy"):
      numpy = x.numpy()
    else:
      return x
    if isinstance(numpy, np.ndarray):
      # Tensors are expected to be immutable, so we disable writes.
      numpy.setflags(write=False)
    return numpy

  @staticmethod
  def _to_element_spec(
      x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
  ) -> clu_data.ArraySpec:
    """Builds a `clu.data.ArraySpec` describing a single batch leaf."""
    if isinstance(x, tf.SparseTensor):
      # Only the leading dimension of a sparse tensor is reported; all other
      # dimensions are marked unknown (None). A tuple display is required
      # here: `tuple(a, *rest)` is a TypeError.
      return clu_data.ArraySpec(
          dtype=x.dtype.as_numpy_dtype,
          shape=(x.shape[0], *(None for _ in x.shape[1:])),
      )
    if isinstance(x, tf.RaggedTensor):
      return clu_data.ArraySpec(
          dtype=x.dtype.as_numpy_dtype,  # pylint: disable=attribute-error
          shape=tuple(x.shape.as_list()),  # pylint: disable=attribute-error
      )
    if isinstance(x, tf.Tensor):
      return clu_data.ArraySpec(
          dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
      )
    return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

  def __next__(self) -> clu_data.Element:
    """Returns the next batch.

    If `element_spec` prefetched a batch, that batch is returned first. It has
    already been postprocessed, so the postprocessor is not applied again.

    Raises:
      StopIteration: When the underlying TF iterator is exhausted.
    """
    if self._prefetched_batch is not None:
      batch = self._prefetched_batch
      self._prefetched_batch = None
    else:
      batch = self._next_postprocessed()
    return jax.tree.map(self._maybe_to_numpy, batch)

  @property
  def element_spec(self) -> clu_data.ElementSpec:
    """Structure of `ArraySpec`s describing each postprocessed batch.

    The first access pulls and postprocesses one batch and derives the spec
    from it. That batch is kept and returned by the next `__next__` call, so
    no data is skipped. The spec is cached after the first computation.
    """
    if self._element_spec is not None:
      return self._element_spec
    batch = self._next_postprocessed()
    self._prefetched_batch = batch
    self._element_spec = jax.tree.map(self._to_element_spec, batch)
    return self._element_spec

  def reset(self):
    """Restarts iteration from the beginning of the dataset."""
    self._iterator = iter(self._dataset)
    # Discard any batch that `element_spec` prefetched from the old iterator.
    # Otherwise it would be returned as the first batch after the reset.
    self._prefetched_batch = None
    if self._checkpoint is not None:
      # Rebind the checkpoint to the fresh iterator.
      self._checkpoint = tf.train.Checkpoint(ds=self._iterator)

  def save(self, filename: epath.Path):
    """Writes the iterator state to `filename`; no-op if checkpointing is off.

    NOTE(review): a batch prefetched by `element_spec` but not yet returned by
    `__next__` has already been consumed from the TF iterator. The saved
    position therefore lies after it, and a restore will skip that batch.
    """
    if self._checkpoint is not None:
      self._checkpoint.write(os.fspath(filename))

  def restore(self, filename: epath.Path):
    """Restores iterator state from `filename`; no-op if checkpointing is off.

    After reading, `assert_consumed()` is called on the restore status, so a
    checkpoint that does not fully match the iterator causes an error.
    """
    if self._checkpoint is not None:
      status = self._checkpoint.read(os.fspath(filename))
      # The restored position supersedes any batch that `element_spec`
      # prefetched from the pre-restore position.
      self._prefetched_batch = None
      status.assert_consumed()