Skip to content

Commit 15b0dd1

Browse files
committed
Implements some metrics in natural language stats generator.
PiperOrigin-RevId: 353746552
1 parent ac70603 commit 15b0dd1

3 files changed

Lines changed: 184 additions & 40 deletions

File tree

tensorflow_data_validation/statistics/generators/natural_language_stats_generator.py

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,48 @@
5656

5757
from tensorflow_data_validation import types
5858
from tensorflow_data_validation.statistics.generators import stats_generator
59+
from tensorflow_data_validation.utils import quantiles_util
5960
from tensorflow_data_validation.utils import schema_util
6061
from tensorflow_data_validation.utils import stats_util
6162
from tensorflow_data_validation.utils import vocab_util
6263

64+
from tfx_bsl import sketches
65+
6366
from google.protobuf import any_pb2
6467
from tensorflow_metadata.proto.v0 import schema_pb2
6568
from tensorflow_metadata.proto.v0 import statistics_pb2
6669

6770
_NL_DOMAIN = 'natural_language_domain'
71+
_NUM_MISRAGRIES_SKETCH_BUCKETS = 16384
72+
_QUANTILES_SKETCH_ERROR = 0.01
73+
_QUANTILES_SKETCH_NUM_ELEMENTS = 2 ^ 32
74+
_QUANTILES_SKETCH_NUM_STREAMS = 1
6875

6976

77+
# TODO(b/175875824): Determine if we should remove NL features from the default
78+
# Top-K computation which is largely redundant.
7079
class _PartialNLStats(object):
7180
"""Partial feature stats for natural language."""
7281

73-
def __init__(self,
74-
invalidate=False,
75-
num_in_vocab_tokens: int = 0,
76-
total_num_tokens: int = 0,
77-
sum_in_vocab_token_lengths: int = 0) -> None:
82+
def __init__(
83+
self,
84+
invalidate=False,
85+
num_in_vocab_tokens: int = 0,
86+
total_num_tokens: int = 0,
87+
sum_in_vocab_token_lengths: int = 0,
88+
) -> None:
7889
# True only if this feature should never be considered, e.g: some
7990
# value_lists have inconsistent types or feature doesn't have an
8091
# NL domain.
8192
self.invalidate = invalidate
82-
8393
self.num_in_vocab_tokens = num_in_vocab_tokens
8494
self.total_num_tokens = total_num_tokens
8595
self.sum_in_vocab_token_lengths = sum_in_vocab_token_lengths
96+
self.vocab_token_length_quantiles = sketches.QuantilesSketch(
97+
_QUANTILES_SKETCH_ERROR, _QUANTILES_SKETCH_NUM_ELEMENTS,
98+
_QUANTILES_SKETCH_NUM_STREAMS)
99+
self.token_occurrence_counts = sketches.MisraGriesSketch(
100+
_NUM_MISRAGRIES_SKETCH_BUCKETS)
86101

87102
def __iadd__(self, other: '_PartialNLStats') -> '_PartialNLStats':
88103
"""Merge two partial natual language stats."""
@@ -91,15 +106,28 @@ def __iadd__(self, other: '_PartialNLStats') -> '_PartialNLStats':
91106
self.num_in_vocab_tokens += other.num_in_vocab_tokens
92107
self.total_num_tokens += other.total_num_tokens
93108
self.sum_in_vocab_token_lengths += other.sum_in_vocab_token_lengths
109+
self.vocab_token_length_quantiles.Merge(other.vocab_token_length_quantiles)
110+
self.token_occurrence_counts.Merge(other.token_occurrence_counts)
94111
return self
95112

96113

114+
def _update_accumulator_with_in_vocab_string_tokens(
115+
accumulator: _PartialNLStats, token_list: List[Text]):
116+
accumulator.num_in_vocab_tokens += len(token_list)
117+
accumulator.token_occurrence_counts.AddValues(pa.array(token_list))
118+
119+
token_len_list = [len(t) for t in token_list]
120+
accumulator.sum_in_vocab_token_lengths += sum(token_len_list)
121+
accumulator.vocab_token_length_quantiles.AddValues(pa.array(token_len_list))
122+
123+
97124
def _compute_int_listscalar_statistics(
98125
row: List[int], accumulator: _PartialNLStats,
99126
excluded_string_tokens: Set[Text], excluded_int_tokens: Set[int],
100127
oov_string_tokens: Set[Text], unused_vocab: Optional[Dict[Text, int]],
101128
rvocab: Optional[Dict[int, Text]]):
102129
"""Compute statistics for an integer listscalar."""
130+
filtered_entry_str_list = []
103131
for entry in row:
104132
if entry in excluded_int_tokens:
105133
continue
@@ -110,9 +138,11 @@ def _compute_int_listscalar_statistics(
110138
if entry_str in excluded_string_tokens:
111139
continue
112140
if entry_str not in oov_string_tokens:
113-
accumulator.num_in_vocab_tokens += 1
114-
accumulator.sum_in_vocab_token_lengths += len(entry_str)
141+
filtered_entry_str_list.append(entry_str)
115142
accumulator.total_num_tokens += 1
143+
if filtered_entry_str_list:
144+
_update_accumulator_with_in_vocab_string_tokens(accumulator,
145+
filtered_entry_str_list)
116146

117147

118148
def _compute_str_listscalar_statistics(
@@ -121,16 +151,44 @@ def _compute_str_listscalar_statistics(
121151
oov_string_tokens: Set[Text], vocab: Optional[Dict[Text, int]],
122152
unused_rvocab: Optional[Dict[int, Text]]):
123153
"""Compute statistics for a string listscalar."""
154+
filtered_entry_list = []
124155
for entry in row:
125156
if entry in excluded_string_tokens:
126157
continue
127158
if (vocab is not None and entry in vocab and
128159
vocab[entry] in excluded_int_tokens):
129160
continue
130161
if entry not in oov_string_tokens:
131-
accumulator.num_in_vocab_tokens += 1
132-
accumulator.sum_in_vocab_token_lengths += len(entry)
162+
filtered_entry_list.append(entry)
133163
accumulator.total_num_tokens += 1
164+
if filtered_entry_list:
165+
_update_accumulator_with_in_vocab_string_tokens(accumulator,
166+
filtered_entry_list)
167+
168+
169+
def _populate_token_length_histogram(
170+
nls: statistics_pb2.NaturalLanguageStatistics, accumulator: _PartialNLStats,
171+
num_quantiles_histogram_buckets: int):
172+
"""Populate the token length histogram."""
173+
quantiles = accumulator.vocab_token_length_quantiles.GetQuantiles(
174+
num_quantiles_histogram_buckets)
175+
quantiles = quantiles.flatten().to_pylist()
176+
177+
if quantiles:
178+
quantiles_histogram = quantiles_util.generate_quantiles_histogram(
179+
quantiles, accumulator.num_in_vocab_tokens,
180+
num_quantiles_histogram_buckets)
181+
nls.token_length_histogram.CopyFrom(quantiles_histogram)
182+
183+
184+
def _populate_token_rank_histogram(
185+
nls: statistics_pb2.NaturalLanguageStatistics, accumulator: _PartialNLStats,
186+
num_rank_histogram_buckets: int):
187+
"""Populate the token rank histogram."""
188+
entries = accumulator.token_occurrence_counts.Estimate().to_pylist()
189+
for i, e in enumerate(entries[:num_rank_histogram_buckets]):
190+
nls.rank_histogram.buckets.add(
191+
low_rank=i, high_rank=i, label=e['values'], sample_count=e['counts'])
134192

135193

136194
class NLStatsGenerator(stats_generator.CombinerFeatureStatsGenerator):
@@ -140,17 +198,27 @@ class NLStatsGenerator(stats_generator.CombinerFeatureStatsGenerator):
140198
natural_language_domain.
141199
"""
142200

143-
def __init__(self,
144-
schema: Optional[schema_pb2.Schema] = None,
145-
vocab_paths: Dict[Text, Text] = None) -> None:
201+
def __init__(self, schema: Optional[schema_pb2.Schema],
202+
vocab_paths: Optional[Dict[Text, Text]],
203+
num_quantiles_histogram_buckets: int,
204+
num_rank_histogram_buckets: int) -> None:
146205
"""Initializes a NLStatsGenerator.
147206
148207
Args:
149208
schema: An optional schema for the dataset.
150209
vocab_paths: A dictonary mapping vocab names to vocab paths.
210+
num_quantiles_histogram_buckets: Number of quantiles to use for
211+
histograms.
212+
num_rank_histogram_buckets: Number of buckets to allow for rank
213+
histograms.
151214
"""
152215
self._schema = schema
153216
self._vocab_paths = vocab_paths
217+
self._num_quantiles_histogram_buckets = num_quantiles_histogram_buckets
218+
assert num_rank_histogram_buckets <= _NUM_MISRAGRIES_SKETCH_BUCKETS, (
219+
'num_rank_histogram_buckets cannot be greater than %d' %
220+
_NUM_MISRAGRIES_SKETCH_BUCKETS)
221+
self._num_rank_histogram_buckets = num_rank_histogram_buckets
154222
self._nld_vocabularies = {}
155223
self._nld_excluded_string_tokens = {}
156224
self._nld_excluded_int_tokens = {}
@@ -260,6 +328,10 @@ def merge_accumulators(
260328
result += accumulator
261329
return result
262330

331+
def compact(self, accumulator: _PartialNLStats) -> _PartialNLStats:
332+
accumulator.vocab_token_length_quantiles.Compact()
333+
return accumulator
334+
263335
def extract_output(
264336
self,
265337
accumulator: _PartialNLStats) -> statistics_pb2.FeatureNameStatistics:
@@ -287,6 +359,19 @@ def extract_output(
287359
accumulator.num_in_vocab_tokens)
288360
result.custom_stats.add(
289361
name='nl_avg_token_length', num=nls.avg_token_length)
362+
if self._num_quantiles_histogram_buckets:
363+
_populate_token_length_histogram(nls, accumulator,
364+
self._num_quantiles_histogram_buckets)
365+
if nls.token_length_histogram.buckets:
366+
result.custom_stats.add(
367+
name='nl_token_length_histogram',
368+
histogram=nls.token_length_histogram)
369+
if self._num_rank_histogram_buckets:
370+
_populate_token_rank_histogram(nls, accumulator,
371+
self._num_rank_histogram_buckets)
372+
if nls.rank_histogram.buckets:
373+
result.custom_stats.add(
374+
name='nl_rank_tokens', rank_histogram=nls.rank_histogram)
290375
my_proto = any_pb2.Any()
291376
result.custom_stats.add(name='nl_statistics', any=my_proto.Pack(nls))
292377
return result

0 commit comments

Comments
 (0)