Skip to content

Commit 65f1d05

Browse files
zwestrick authored and tfx-copybara committed
Feature skew: adds visualization helpers for match stats (very simple) and confusion counts.
PiperOrigin-RevId: 470787863
1 parent baede31 commit 65f1d05

3 files changed

Lines changed: 132 additions & 1 deletion

File tree

tensorflow_data_validation/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@
4848
from tensorflow_data_validation.utils.display_util import compare_slices
4949
from tensorflow_data_validation.utils.display_util import display_anomalies
5050
from tensorflow_data_validation.utils.display_util import display_schema
51+
from tensorflow_data_validation.utils.display_util import get_confusion_count_dataframes
52+
from tensorflow_data_validation.utils.display_util import get_match_stats_dataframe
5153
from tensorflow_data_validation.utils.display_util import get_skew_result_dataframe
5254
from tensorflow_data_validation.utils.display_util import get_statistics_html
5355
from tensorflow_data_validation.utils.display_util import visualize_statistics
5456

57+
5558
# Import schema utilities.
5659
from tensorflow_data_validation.utils.schema_util import get_domain
5760
from tensorflow_data_validation.utils.schema_util import get_feature

tensorflow_data_validation/utils/display_util.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from __future__ import print_function
2121

2222
import base64
23+
import collections
2324
import sys
2425
from typing import Dict, Iterable, List, Optional, Text, Tuple, Union
2526

@@ -607,3 +608,56 @@ def get_skew_result_dataframe(
607608
return pd.DataFrame(
608609
result,
609610
columns=columns).sort_values('feature_name').reset_index(drop=True)
611+
612+
613+
def get_match_stats_dataframe(
    match_stats: feature_skew_results_pb2.MatchStats) -> pd.DataFrame:
  """Builds a single-row pandas dataframe from a MatchStats proto.

  Args:
    match_stats: The MatchStats message whose count fields are displayed.

  Returns:
    A dataframe with exactly one row and one column per count field read from
    `match_stats`; each column is named after the field it holds.
  """
  # Column order of the resulting dataframe follows this tuple.
  count_fields = (
      'base_with_id_count',
      'test_with_id_count',
      'identifiers_count',
      'ids_missing_in_base_count',
      'ids_missing_in_test_count',
      'matching_pairs_count',
      'base_missing_id_count',
      'test_missing_id_count',
      'duplicate_id_count',
  )
  # Every value is wrapped in a one-element list so the frame has one row.
  single_row = {name: [getattr(match_stats, name)] for name in count_fields}
  return pd.DataFrame.from_dict(single_row)
627+
628+
629+
def get_confusion_count_dataframes(
    confusion: Iterable[feature_skew_results_pb2.ConfusionCount]
) -> Dict[str, pd.DataFrame]:
  """Returns per-feature pandas dataframes for a sequence of ConfusionCount.

  Only the `bytes_value` of each proto's `base` and `test` value is read.

  Args:
    confusion: An iterable over ConfusionCount protos. It is consumed in a
      single pass.

  Returns:
    A map from feature name to a pandas dataframe. Each dataframe has one row
    per input (base value, test value) pair whose two values differ, with the
    columns 'Base value', 'Test value', 'Pair count', 'Base count' and
    'Test count'. 'Base count' / 'Test count' are the summed counts of every
    input pair (matching pairs included) sharing that base / test value. Rows
    are ordered by base value, then by the pair's fraction of the base count.
    An empty input yields an empty dict.
  """
  confusion_per_feature = collections.defaultdict(list)
  for c in confusion:
    confusion_per_feature[c.feature_name].append(c)

  def _build_df(feature_confusion):
    """Formats the ConfusionCounts of a single feature as a dataframe."""
    base_count_per_value = collections.defaultdict(int)
    test_count_per_value = collections.defaultdict(int)
    value_counts = []
    for c in feature_confusion:
      base_value = c.base.bytes_value
      test_value = c.test.bytes_value
      # Totals are accumulated over all pairs, including the matching ones
      # that are filtered out of the final dataframe below.
      base_count_per_value[base_value] += c.count
      test_count_per_value[test_value] += c.count
      value_counts.append((base_value, test_value, c.count))
    df = pd.DataFrame(
        value_counts, columns=('Base value', 'Test value', 'Pair count'))
    df['Base count'] = df['Base value'].apply(lambda x: base_count_per_value[x])
    df['Test count'] = df['Test value'].apply(lambda x: test_count_per_value[x])
    # 'Fraction of base' is only a sort key; it is not part of the output.
    df['Fraction of base'] = df['Pair count'] / df['Base count']
    df = df[df['Base value'] != df['Test value']].sort_values(
        ['Base value', 'Fraction of base']).reset_index(drop=True)
    return df[[
        'Base value', 'Test value', 'Pair count', 'Base count', 'Test count'
    ]]

  return {k: _build_df(v) for k, v in confusion_per_feature.items()}

tensorflow_data_validation/utils/display_util_test.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def test_formats_skew_results(self):
884884
])
885885
self.assertTrue(df.equals(expected))
886886

887-
def test_formats_empty_results(self):
887+
def test_formats_empty_skew_results(self):
888888
skew_results = []
889889
df = display_util.get_skew_result_dataframe(skew_results)
890890
expected = pd.DataFrame([],
@@ -895,5 +895,79 @@ def test_formats_empty_results(self):
895895
])
896896
self.assertTrue(df.equals(expected))
897897

898+
def test_formats_confusion_counts(self):
  """Checks that confusion counts are grouped per feature and formatted."""
  # Text-format ConfusionCount protos: four for feature "foo" (two of which
  # are matching pairs) and one for feature "bar".
  textprotos = [
      """
        feature_name: "foo"
        base {
          bytes_value: "val1"
        }
        test {
          bytes_value: "val1"
        }
        count: 99
    """,
      """
        feature_name: "foo"
        base {
          bytes_value: "val1"
        }
        test {
          bytes_value: "val2"
        }
        count: 1
    """,
      """
        feature_name: "foo"
        base {
          bytes_value: "val2"
        }
        test {
          bytes_value: "val3"
        }
        count: 1
    """,
      """
        feature_name: "foo"
        base {
          bytes_value: "val3"
        }
        test {
          bytes_value: "val3"
        }
        count: 100
    """,
      """
        feature_name: "bar"
        base {
          bytes_value: "val1"
        }
        test {
          bytes_value: "val2"
        }
        count: 1
    """,
  ]
  confusion = [
      text_format.Parse(textproto, feature_skew_results_pb2.ConfusionCount())
      for textproto in textprotos
  ]
  expected_columns = [
      'Base value', 'Test value', 'Pair count', 'Base count', 'Test count'
  ]
  # Matching pairs contribute to the base/test totals but get no row.
  expected_foo = pd.DataFrame(
      [[b'val1', b'val2', 1, 100, 1], [b'val2', b'val3', 1, 1, 101]],
      columns=expected_columns)
  expected_bar = pd.DataFrame([[b'val1', b'val2', 1, 1, 1]],
                              columns=expected_columns)

  dfs = display_util.get_confusion_count_dataframes(confusion)

  self.assertSameElements(dfs.keys(), ['foo', 'bar'])
  self.assertTrue(dfs['foo'].equals(expected_foo))
  self.assertTrue(dfs['bar'].equals(expected_bar))
971+
898972
if __name__ == '__main__':
899973
absltest.main()

0 commit comments

Comments
 (0)