Skip to content

Commit 71b4573

Browse files
committed
Support counter stats for custom empty values. These don't count top level nulls.
PiperOrigin-RevId: 625828088
1 parent c341ccf commit 71b4573

2 files changed

Lines changed: 293 additions & 0 deletions

File tree

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Module that counts rows with given empty value."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import collections
21+
from typing import Iterable
22+
23+
from absl import logging
24+
import numpy as np
25+
import pyarrow as pa
26+
from tensorflow_data_validation import types
27+
from tensorflow_data_validation.arrow import arrow_util
28+
from tensorflow_data_validation.statistics.generators import stats_generator
29+
from tensorflow_data_validation.utils import stats_util
30+
from tfx_bsl.arrow import array_util
31+
32+
from tensorflow_metadata.proto.v0 import statistics_pb2
33+
34+
35+
class _PartialCounterStats(object):
36+
"""Partial feature stats for dates/times."""
37+
38+
def __init__(self) -> None:
39+
self.counter = collections.Counter(
40+
{'int_-1': 0, 'str_empty': 0, 'list_empty': 0}
41+
)
42+
43+
def __add__(self, other: '_PartialCounterStats') -> '_PartialCounterStats':
44+
"""Merges two partial stats."""
45+
self.counter.update(other.counter)
46+
return self
47+
48+
def update(
49+
self,
50+
values: np.ndarray,
51+
value_type: types.FeatureNameStatisticsType,
52+
is_multivalent: bool = False,
53+
) -> None:
54+
"""Updates the partial statistics using the values.
55+
56+
Args:
57+
values: A numpy array of values in a batch.
58+
value_type: The type of the values.
59+
is_multivalent: If the feature is multivalent.
60+
"""
61+
62+
# Multivalent feature handling.
63+
if is_multivalent:
64+
empty_list = (values == 0).sum()
65+
self.counter.update({'list_empty': empty_list})
66+
elif (
67+
value_type == statistics_pb2.FeatureNameStatistics.STRING
68+
or value_type == statistics_pb2.FeatureNameStatistics.BYTES
69+
):
70+
empty_str = 0
71+
for value in values:
72+
if value is not None and not value:
73+
empty_str += 1
74+
self.counter.update({'str_empty': empty_str})
75+
76+
elif (
77+
value_type == statistics_pb2.FeatureNameStatistics.FLOAT
78+
or value_type == statistics_pb2.FeatureNameStatistics.INT
79+
):
80+
empty_neg_1 = 0
81+
for value in values:
82+
if value == -1:
83+
empty_neg_1 += 1
84+
self.counter.update({'int_-1': empty_neg_1})
85+
else:
86+
logging.warning('Unsupported type: %s , %s', values[0].dtype, value_type)
87+
raise ValueError(
88+
'Attempt to update partial time stats with values of an '
89+
'unsupported type.'
90+
)
91+
92+
93+
class EmptyValueCounterGenerator(stats_generator.CombinerFeatureStatsGenerator):
94+
"""Counts rows with given empty values."""
95+
96+
def __init__(self) -> None:
97+
"""Initializes a EmptyValueCounterGenerator."""
98+
99+
super(EmptyValueCounterGenerator, self).__init__(
100+
'EmptyValueCounterGenerator'
101+
)
102+
103+
def create_accumulator(self) -> _PartialCounterStats:
104+
"""Returns a fresh, empty accumulator.
105+
106+
Returns:
107+
An empty accumulator.
108+
"""
109+
return _PartialCounterStats()
110+
111+
def add_input(
112+
self,
113+
accumulator: _PartialCounterStats,
114+
feature_path: types.FeaturePath,
115+
feature_array: pa.Array,
116+
) -> _PartialCounterStats:
117+
"""Returns result of folding a batch of inputs into the current accumulator.
118+
119+
Args:
120+
accumulator: The current accumulator.
121+
feature_path: The path of the feature.
122+
feature_array: An arrow Array representing a batch of feature values which
123+
should be added to the accumulator.
124+
125+
Returns:
126+
The accumulator after updating the statistics for the batch of inputs.
127+
"""
128+
129+
feature_type = stats_util.get_feature_type_from_arrow_type(
130+
feature_path, feature_array.type
131+
)
132+
# Ignore null array.
133+
if feature_type is None or not feature_array:
134+
return accumulator
135+
136+
nest_level = arrow_util.get_nest_level(feature_array.type)
137+
if nest_level > 1:
138+
# Flatten removes top level nulls.
139+
feature_array = feature_array.flatten()
140+
list_lengths = array_util.ListLengthsFromListArray(feature_array)
141+
accumulator.update(
142+
np.asarray(list_lengths), feature_type, is_multivalent=True
143+
)
144+
elif (
145+
feature_type == statistics_pb2.FeatureNameStatistics.STRING
146+
or feature_type == statistics_pb2.FeatureNameStatistics.BYTES
147+
):
148+
149+
def _maybe_get_utf8(val):
150+
return stats_util.maybe_get_utf8(val) if isinstance(val, bytes) else val
151+
152+
values = np.asarray(array_util.flatten_nested(feature_array)[0])
153+
maybe_utf8 = np.vectorize(_maybe_get_utf8, otypes=[object])(values)
154+
accumulator.update(maybe_utf8, feature_type)
155+
elif (
156+
feature_type == statistics_pb2.FeatureNameStatistics.INT
157+
or feature_type == statistics_pb2.FeatureNameStatistics.FLOAT
158+
):
159+
values = np.asarray(array_util.flatten_nested(feature_array)[0])
160+
accumulator.update(values, feature_type)
161+
else:
162+
logging.warning('Unsupported type: %s', feature_type)
163+
raise ValueError(
164+
'Attempt to update partial time stats with values of an '
165+
'unsupported type.'
166+
)
167+
168+
return accumulator
169+
170+
def merge_accumulators(
171+
self, accumulators: Iterable[_PartialCounterStats]
172+
) -> _PartialCounterStats:
173+
"""Merges several accumulators to a single accumulator value.
174+
175+
Args:
176+
accumulators: The accumulators to merge.
177+
178+
Returns:
179+
The merged accumulator.
180+
"""
181+
it = iter(accumulators)
182+
result = next(it)
183+
for acc in it:
184+
result += acc
185+
return result
186+
187+
def extract_output(
188+
self, accumulator: _PartialCounterStats
189+
) -> statistics_pb2.FeatureNameStatistics:
190+
"""Returns the result of converting accumulator into the output value.
191+
192+
This method will add the time_domain custom stat to the proto if the match
193+
ratio is at least self._match_ratio. The match ratio is determined by
194+
dividing the number of values that have the most common valid format by the
195+
total number of values considered. If this method adds the time_domain
196+
custom stat, it also adds the match ratio and the most common valid format
197+
to the proto as custom stats.
198+
199+
Args:
200+
accumulator: The final accumulator value.
201+
202+
Returns:
203+
A proto representing the result of this stats generator.
204+
"""
205+
result = statistics_pb2.FeatureNameStatistics()
206+
for name, count in accumulator.counter.items():
207+
if count:
208+
result.custom_stats.add(name=name, num=count)
209+
return result
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for empty_value_counter_generator."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
from absl.testing import absltest
21+
import pyarrow as pa
22+
from tensorflow_data_validation.statistics.generators import empty_value_counter_generator
23+
from tensorflow_data_validation.utils import test_util
24+
25+
from tensorflow_metadata.proto.v0 import statistics_pb2
26+
27+
28+
class EmptyValueCounterGeneratorTest(
29+
test_util.CombinerFeatureStatsGeneratorTest
30+
):
31+
32+
def test_empty_value_counter_generator_for_string(self):
33+
input_batches = [
34+
pa.array([["abc"], [""]]),
35+
pa.array([[""], ["def"]]),
36+
pa.array([[""], None]),
37+
]
38+
generator = empty_value_counter_generator.EmptyValueCounterGenerator()
39+
self.assertCombinerOutputEqual(
40+
input_batches,
41+
generator,
42+
statistics_pb2.FeatureNameStatistics(
43+
custom_stats=[
44+
statistics_pb2.CustomStatistic(name="str_empty", num=3),
45+
]
46+
),
47+
)
48+
49+
def test_empty_value_counter_generator_for_ints(self):
50+
input_batches = [
51+
pa.array([[0], [-1], [10]]),
52+
pa.array([[0], [-1], None]),
53+
pa.array([[2], [-1], [-1], [100]]),
54+
]
55+
generator = empty_value_counter_generator.EmptyValueCounterGenerator()
56+
self.assertCombinerOutputEqual(
57+
input_batches,
58+
generator,
59+
statistics_pb2.FeatureNameStatistics(
60+
custom_stats=[
61+
statistics_pb2.CustomStatistic(name="int_-1", num=4),
62+
]
63+
),
64+
)
65+
66+
def test_empty_value_counter_generator_for_lists(self):
67+
input_batches = [
68+
pa.array([[[]], None, [["abc", "foo"]]]),
69+
pa.array([[["foo"]], None, [[]], [[]], [[]], [["", "jk", "tst"]]]),
70+
]
71+
generator = empty_value_counter_generator.EmptyValueCounterGenerator()
72+
self.assertCombinerOutputEqual(
73+
input_batches,
74+
generator,
75+
statistics_pb2.FeatureNameStatistics(
76+
custom_stats=[
77+
statistics_pb2.CustomStatistic(name="list_empty", num=4),
78+
]
79+
),
80+
)
81+
82+
83+
if __name__ == "__main__":
84+
absltest.main()

0 commit comments

Comments
 (0)