5656
5757from tensorflow_data_validation import types
5858from tensorflow_data_validation .statistics .generators import stats_generator
59+ from tensorflow_data_validation .utils import quantiles_util
5960from tensorflow_data_validation .utils import schema_util
6061from tensorflow_data_validation .utils import stats_util
6162from tensorflow_data_validation .utils import vocab_util
6263
64+ from tfx_bsl import sketches
65+
6366from google .protobuf import any_pb2
6467from tensorflow_metadata .proto .v0 import schema_pb2
6568from tensorflow_metadata .proto .v0 import statistics_pb2
6669
_NL_DOMAIN = 'natural_language_domain'
# Bucket count handed to sketches.MisraGriesSketch for the token occurrence
# counts; NLStatsGenerator also uses it as the upper bound for
# num_rank_histogram_buckets.
_NUM_MISRAGRIES_SKETCH_BUCKETS = 16384
# The three values below are passed positionally, in this order, to
# sketches.QuantilesSketch for the in-vocab token length distribution.
_QUANTILES_SKETCH_ERROR = 0.01
# NOTE: `^` is bitwise XOR in Python, so the former `2 ^ 32` evaluated to 34
# rather than the intended power of two. `**` is the exponentiation operator.
_QUANTILES_SKETCH_NUM_ELEMENTS = 2 ** 32
_QUANTILES_SKETCH_NUM_STREAMS = 1
6875
6976
77+ # TODO(b/175875824): Determine if we should remove NL features from the default
78+ # Top-K computation which is largely redundant.
7079class _PartialNLStats (object ):
7180 """Partial feature stats for natural language."""
7281
def __init__(
    self,
    invalidate=False,
    num_in_vocab_tokens: int = 0,
    total_num_tokens: int = 0,
    sum_in_vocab_token_lengths: int = 0,
) -> None:
  """Creates a partial-stats holder seeded with the given counters.

  Args:
    invalidate: True only if this feature should never be considered, e.g.
      some value_lists have inconsistent types or the feature doesn't have
      an NL domain.
    num_in_vocab_tokens: Initial value of the in-vocab token counter.
    total_num_tokens: Initial value of the total token counter.
    sum_in_vocab_token_lengths: Initial value of the summed in-vocab token
      lengths.
  """
  # Every instance gets its own pair of sketches, configured from the
  # module-level constants.
  length_sketch = sketches.QuantilesSketch(
      _QUANTILES_SKETCH_ERROR,
      _QUANTILES_SKETCH_NUM_ELEMENTS,
      _QUANTILES_SKETCH_NUM_STREAMS,
  )
  occurrence_sketch = sketches.MisraGriesSketch(_NUM_MISRAGRIES_SKETCH_BUCKETS)

  self.invalidate = invalidate
  self.num_in_vocab_tokens = num_in_vocab_tokens
  self.total_num_tokens = total_num_tokens
  self.sum_in_vocab_token_lengths = sum_in_vocab_token_lengths
  self.vocab_token_length_quantiles = length_sketch
  self.token_occurrence_counts = occurrence_sketch
86101
87102 def __iadd__ (self , other : '_PartialNLStats' ) -> '_PartialNLStats' :
88103 """Merge two partial natural language stats."""
@@ -91,15 +106,28 @@ def __iadd__(self, other: '_PartialNLStats') -> '_PartialNLStats':
91106 self .num_in_vocab_tokens += other .num_in_vocab_tokens
92107 self .total_num_tokens += other .total_num_tokens
93108 self .sum_in_vocab_token_lengths += other .sum_in_vocab_token_lengths
109+ self .vocab_token_length_quantiles .Merge (other .vocab_token_length_quantiles )
110+ self .token_occurrence_counts .Merge (other .token_occurrence_counts )
94111 return self
95112
96113
def _update_accumulator_with_in_vocab_string_tokens(
    accumulator: '_PartialNLStats', token_list: List[Text]) -> None:
  """Folds a batch of in-vocab string tokens into the accumulator.

  Updates, in place, the in-vocab token count, the token occurrence sketch,
  the summed token lengths and the token length quantiles sketch.

  Args:
    accumulator: The partial NL stats to update.
    token_list: In-vocab string tokens to add. An empty list is a no-op.
  """
  if not token_list:
    # Nothing to record. Returning early also avoids handing the sketches an
    # array built from an empty list, for which pa.array() infers the null
    # type instead of a string/integer type.
    return

  accumulator.num_in_vocab_tokens += len(token_list)
  accumulator.token_occurrence_counts.AddValues(pa.array(token_list))

  token_len_list = [len(t) for t in token_list]
  accumulator.sum_in_vocab_token_lengths += sum(token_len_list)
  accumulator.vocab_token_length_quantiles.AddValues(pa.array(token_len_list))
122+
123+
97124def _compute_int_listscalar_statistics (
98125 row : List [int ], accumulator : _PartialNLStats ,
99126 excluded_string_tokens : Set [Text ], excluded_int_tokens : Set [int ],
100127 oov_string_tokens : Set [Text ], unused_vocab : Optional [Dict [Text , int ]],
101128 rvocab : Optional [Dict [int , Text ]]):
102129 """Compute statistics for an integer listscalar."""
130+ filtered_entry_str_list = []
103131 for entry in row :
104132 if entry in excluded_int_tokens :
105133 continue
@@ -110,9 +138,11 @@ def _compute_int_listscalar_statistics(
110138 if entry_str in excluded_string_tokens :
111139 continue
112140 if entry_str not in oov_string_tokens :
113- accumulator .num_in_vocab_tokens += 1
114- accumulator .sum_in_vocab_token_lengths += len (entry_str )
141+ filtered_entry_str_list .append (entry_str )
115142 accumulator .total_num_tokens += 1
143+ if filtered_entry_str_list :
144+ _update_accumulator_with_in_vocab_string_tokens (accumulator ,
145+ filtered_entry_str_list )
116146
117147
118148def _compute_str_listscalar_statistics (
@@ -121,16 +151,44 @@ def _compute_str_listscalar_statistics(
121151 oov_string_tokens : Set [Text ], vocab : Optional [Dict [Text , int ]],
122152 unused_rvocab : Optional [Dict [int , Text ]]):
123153 """Compute statistics for a string listscalar."""
154+ filtered_entry_list = []
124155 for entry in row :
125156 if entry in excluded_string_tokens :
126157 continue
127158 if (vocab is not None and entry in vocab and
128159 vocab [entry ] in excluded_int_tokens ):
129160 continue
130161 if entry not in oov_string_tokens :
131- accumulator .num_in_vocab_tokens += 1
132- accumulator .sum_in_vocab_token_lengths += len (entry )
162+ filtered_entry_list .append (entry )
133163 accumulator .total_num_tokens += 1
164+ if filtered_entry_list :
165+ _update_accumulator_with_in_vocab_string_tokens (accumulator ,
166+ filtered_entry_list )
167+
168+
def _populate_token_length_histogram(
    nls: statistics_pb2.NaturalLanguageStatistics, accumulator: _PartialNLStats,
    num_quantiles_histogram_buckets: int):
  """Fills nls.token_length_histogram from the token length quantiles sketch.

  Args:
    nls: Proto whose token_length_histogram field is overwritten when the
      sketch yields any quantiles.
    accumulator: Supplies the token length quantiles sketch and the in-vocab
      token count.
    num_quantiles_histogram_buckets: Number of quantiles requested from the
      sketch; also forwarded to the histogram builder.
  """
  length_sketch = accumulator.vocab_token_length_quantiles
  boundaries = length_sketch.GetQuantiles(
      num_quantiles_histogram_buckets).flatten().to_pylist()
  if not boundaries:
    # No quantiles came back, so the proto field is left untouched.
    return

  nls.token_length_histogram.CopyFrom(
      quantiles_util.generate_quantiles_histogram(
          boundaries, accumulator.num_in_vocab_tokens,
          num_quantiles_histogram_buckets))
182+
183+
def _populate_token_rank_histogram(
    nls: statistics_pb2.NaturalLanguageStatistics, accumulator: _PartialNLStats,
    num_rank_histogram_buckets: int):
  """Appends rank histogram buckets built from the token occurrence sketch.

  One bucket is added to nls.rank_histogram for each of the first
  num_rank_histogram_buckets entries returned by the sketch's Estimate().
  The bucket's low and high rank are both the entry's position in that list.

  Args:
    nls: Proto whose rank_histogram receives the new buckets.
    accumulator: Supplies the token occurrence sketch.
    num_rank_histogram_buckets: Maximum number of buckets to add.
  """
  # NOTE(review): this relies on Estimate() returning entries in rank order
  # (presumably by descending count) -- confirm against the sketch docs.
  estimates = accumulator.token_occurrence_counts.Estimate().to_pylist()
  top_estimates = estimates[:num_rank_histogram_buckets]
  for rank, estimate in enumerate(top_estimates):
    nls.rank_histogram.buckets.add(
        low_rank=rank,
        high_rank=rank,
        label=estimate['values'],
        sample_count=estimate['counts'])
134192
135193
136194class NLStatsGenerator (stats_generator .CombinerFeatureStatsGenerator ):
@@ -140,17 +198,27 @@ class NLStatsGenerator(stats_generator.CombinerFeatureStatsGenerator):
140198 natural_language_domain.
141199 """
142200
143- def __init__ (self ,
144- schema : Optional [schema_pb2 .Schema ] = None ,
145- vocab_paths : Dict [Text , Text ] = None ) -> None :
201+ def __init__ (self , schema : Optional [schema_pb2 .Schema ],
202+ vocab_paths : Optional [Dict [Text , Text ]],
203+ num_quantiles_histogram_buckets : int ,
204+ num_rank_histogram_buckets : int ) -> None :
146205 """Initializes a NLStatsGenerator.
147206
148207 Args:
149208 schema: An optional schema for the dataset.
150209 vocab_paths: A dictionary mapping vocab names to vocab paths.
210+ num_quantiles_histogram_buckets: Number of quantiles to use for
211+ histograms.
212+ num_rank_histogram_buckets: Number of buckets to allow for rank
213+ histograms.
151214 """
152215 self ._schema = schema
153216 self ._vocab_paths = vocab_paths
217+ self ._num_quantiles_histogram_buckets = num_quantiles_histogram_buckets
218+ assert num_rank_histogram_buckets <= _NUM_MISRAGRIES_SKETCH_BUCKETS , (
219+ 'num_rank_histogram_buckets cannot be greater than %d' %
220+ _NUM_MISRAGRIES_SKETCH_BUCKETS )
221+ self ._num_rank_histogram_buckets = num_rank_histogram_buckets
154222 self ._nld_vocabularies = {}
155223 self ._nld_excluded_string_tokens = {}
156224 self ._nld_excluded_int_tokens = {}
@@ -260,6 +328,10 @@ def merge_accumulators(
260328 result += accumulator
261329 return result
262330
def compact(self, accumulator: _PartialNLStats) -> _PartialNLStats:
  """Compacts the accumulator's token length quantiles sketch in place.

  Args:
    accumulator: Partial NL stats whose quantiles sketch is compacted.

  Returns:
    The same accumulator object that was passed in.
  """
  length_sketch = accumulator.vocab_token_length_quantiles
  length_sketch.Compact()
  return accumulator
334+
263335 def extract_output (
264336 self ,
265337 accumulator : _PartialNLStats ) -> statistics_pb2 .FeatureNameStatistics :
@@ -287,6 +359,19 @@ def extract_output(
287359 accumulator .num_in_vocab_tokens )
288360 result .custom_stats .add (
289361 name = 'nl_avg_token_length' , num = nls .avg_token_length )
362+ if self ._num_quantiles_histogram_buckets :
363+ _populate_token_length_histogram (nls , accumulator ,
364+ self ._num_quantiles_histogram_buckets )
365+ if nls .token_length_histogram .buckets :
366+ result .custom_stats .add (
367+ name = 'nl_token_length_histogram' ,
368+ histogram = nls .token_length_histogram )
369+ if self ._num_rank_histogram_buckets :
370+ _populate_token_rank_histogram (nls , accumulator ,
371+ self ._num_rank_histogram_buckets )
372+ if nls .rank_histogram .buckets :
373+ result .custom_stats .add (
374+ name = 'nl_rank_tokens' , rank_histogram = nls .rank_histogram )
290375 my_proto = any_pb2 .Any ()
291376 result .custom_stats .add (name = 'nl_statistics' , any = my_proto .Pack (nls ))
292377 return result
0 commit comments