-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpandas.py
More file actions
208 lines (163 loc) · 7.48 KB
/
pandas.py
File metadata and controls
208 lines (163 loc) · 7.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""Extension module to process files with Pandas."""
# LICENSE HEADER MANAGED BY add-license-header
# Copyright (c) 2023-2025 Blue Brain Project, EPFL.
#
# This file is part of dir-content-diff.
# See https://github.com/BlueBrain/dir-content-diff for further info.
#
# SPDX-License-Identifier: Apache-2.0
# LICENSE HEADER MANAGED BY add-license-header
from dir_content_diff import register_comparator
from dir_content_diff.base_comparators import BaseComparator
from dir_content_diff.util import import_error_message
try:
import pandas as pd
except ImportError: # pragma: no cover
import_error_message(__name__)
class DataframeComparator(BaseComparator):
    """Comparator for :class:`pandas.DataFrame` objects."""

    def format_data(self, data, ref=None, replace_pattern=None):
        """Format the compared :class:`pandas.DataFrame`.

        Args:
            data (pandas.DataFrame): The DataFrame to format. The columns in which a pattern is
                replaced are updated in place and this same object is returned.
            ref (pandas.DataFrame): (Optional) The reference DataFrame.
            replace_pattern (dict): (Optional) The columns in which a given pattern must be
                replaced.
                The dictionary must have the following format:

                .. code-block:: python

                    {
                        (<pattern>, <new_value>, <optional regex flag>): [col1, col2]
                    }

        .. note::
            The formatting errors are stored in `self.current_state["format_errors"]`.
            It contains a dict in which the keys are the columns with detected issues and the
            values are the actual descriptions of these issues. This dict is reset at each call.

        Returns:
            pandas.DataFrame: The formatted compared data.
        """
        self.current_state["format_errors"] = errors = {}

        if replace_pattern is None:
            return data

        for pat, cols in replace_pattern.items():
            pattern, new_value = pat[0], pat[1]
            # The third element of the key (the regex flags) is optional.
            flags = pat[2] if len(pat) > 2 else 0

            for col in cols:
                if ref is not None and col not in ref.columns:
                    errors[col] = (
                        "The column is missing in the reference DataFrame, please fix the "
                        "'replace_pattern' argument."
                    )
                elif col not in data.columns:
                    errors[col] = (
                        "The column is missing in the compared DataFrame, please fix the "
                        "'replace_pattern' argument."
                    )
                elif hasattr(data[col], "str"):
                    # If all values are NaN, Pandas casts the column dtype to float, so the str
                    # attribute is not available.
                    data[col] = data[col].str.replace(
                        pattern,
                        new_value,
                        flags=flags,
                        regex=True,
                    )

        return data

    def diff(self, ref, comp, *args, ignore_columns=None, **kwargs):
        """Compare two :class:`pandas.DataFrame` objects.

        This function calls :func:`pandas.testing.assert_series_equal` on each column, read the
        doc of this function for details on args and kwargs.

        The columns for which a formatting error is stored in
        `self.current_state["format_errors"]` are not compared, the stored error is reported
        instead.

        Args:
            ref (pandas.DataFrame): The reference DataFrame.
            comp (pandas.DataFrame): The compared DataFrame.
            ignore_columns (list(str)): (Optional) The columns that should not be checked.

        Returns:
            bool or dict: ``False`` if the DataFrames are considered as equal or a dict in which
            the keys are the columns that are not considered equal and the values explain why.
        """
        # Work on a copy so that the comparison results are not written into the stored
        # formatting errors (which would make a later call skip the already compared columns).
        errors = dict(self.current_state.get("format_errors", {}))

        if ignore_columns is not None:
            # Drop into new objects so that the DataFrames of the caller are left untouched.
            ref = ref.drop(columns=ignore_columns, errors="ignore")
            comp = comp.drop(columns=ignore_columns, errors="ignore")

        for col in ref.columns:
            if col in errors:
                continue
            if col not in comp.columns:
                errors[col] = "The column is missing in the compared DataFrame."
                continue
            try:
                pd.testing.assert_series_equal(ref[col], comp[col], *args, **kwargs)
            except AssertionError as e:
                errors[col] = e.args[0]
            else:
                # ``True`` marks a column that was compared and found equal.
                errors[col] = True

        for col in comp.columns:
            if col not in errors and col not in ref.columns:
                errors[col] = "The column is missing in the reference DataFrame."

        not_equals = {k: v for k, v in errors.items() if v is not True}
        if not not_equals:
            return False
        return not_equals

    def format_diff(self, difference):
        """Format one element difference.

        Args:
            difference (tuple): A ``(column, description)`` pair.

        Returns:
            str: The formatted message for this column.
        """
        k, v = difference
        return f"\nColumn '{k}': {v}"

    def sort(self, differences):
        """Do not sort the differences to keep the column order."""
        return differences
class CsvComparator(DataframeComparator):
    """Comparator for CSV files."""

    def load(self, path, **kwargs):
        """Read the CSV file located at ``path`` with :func:`pandas.read_csv`.

        The keyword arguments are forwarded to :func:`pandas.read_csv`.
        """
        frame = pd.read_csv(path, **kwargs)
        return frame

    def save(self, data, path, **kwargs):
        """Write ``data`` to the CSV file located at ``path``.

        The keyword arguments are forwarded to :meth:`pandas.DataFrame.to_csv`. The index is not
        written unless ``index`` is explicitly given.
        """
        kwargs.setdefault("index", False)
        data.to_csv(path, **kwargs)
class HdfComparator(DataframeComparator):
    """Comparator for HDF files."""

    def load(self, path, **kwargs):
        """Read the HDF file located at ``path`` with :func:`pandas.read_hdf`.

        The keyword arguments are forwarded to :func:`pandas.read_hdf`.
        """
        frame = pd.read_hdf(path, **kwargs)
        return frame

    def save(self, data, path, **kwargs):
        """Write ``data`` to the HDF file located at ``path``.

        The keyword arguments are forwarded to :meth:`pandas.DataFrame.to_hdf`. Unless explicitly
        given, ``index`` is set to ``False`` and ``key`` is set to ``"data"``.
        """
        options = {"index": False, "key": "data", **kwargs}
        data.to_hdf(path, **options)
class FeatherComparator(DataframeComparator):
    """Comparator for Feather files."""

    def load(self, path, **kwargs):
        """Read the Feather file located at ``path`` with :func:`pandas.read_feather`.

        The keyword arguments are forwarded to :func:`pandas.read_feather`.
        """
        frame = pd.read_feather(path, **kwargs)
        return frame

    def save(self, data, path, **kwargs):
        """Write ``data`` to the Feather file located at ``path``.

        The keyword arguments are forwarded to :meth:`pandas.DataFrame.to_feather`.
        """
        data.to_feather(path, **kwargs)
class ParquetComparator(DataframeComparator):
    """Comparator for Parquet files."""

    def load(self, path, **kwargs):
        """Read the Parquet file located at ``path`` with :func:`pandas.read_parquet`.

        The keyword arguments are forwarded to :func:`pandas.read_parquet`.
        """
        frame = pd.read_parquet(path, **kwargs)
        return frame

    def save(self, data, path, **kwargs):
        """Write ``data`` to the Parquet file located at ``path``.

        The keyword arguments are forwarded to :meth:`pandas.DataFrame.to_parquet`.
        """
        data.to_parquet(path, **kwargs)
class StataComparator(DataframeComparator):
    """Comparator for Stata files."""

    def load(self, path, **kwargs):
        """Read the Stata file located at ``path`` with :func:`pandas.read_stata`.

        The keyword arguments are forwarded to :func:`pandas.read_stata`.
        """
        frame = pd.read_stata(path, **kwargs)
        return frame

    def save(self, data, path, **kwargs):
        """Write ``data`` to the Stata file located at ``path``.

        The keyword arguments are forwarded to :meth:`pandas.DataFrame.to_stata`.
        """
        data.to_stata(path, **kwargs)
def register():
    """Register Pandas extensions.

    A new comparator instance is created and registered for each supported file extension.
    """
    # NOTE(review): ``.tsv`` is handled by the CSV comparator, whose ``load`` only forwards the
    # given kwargs to ``pandas.read_csv`` -- confirm that a tab separator is provided elsewhere.
    comparators_by_suffix = (
        (".csv", CsvComparator),
        (".tsv", CsvComparator),
        (".h4", HdfComparator),
        (".h5", HdfComparator),
        (".hdf", HdfComparator),
        (".hdf4", HdfComparator),
        (".hdf5", HdfComparator),
        (".feather", FeatherComparator),
        (".parquet", ParquetComparator),
        (".dta", StataComparator),
    )
    for suffix, comparator_cls in comparators_by_suffix:
        register_comparator(suffix, comparator_cls())