# tableDiff.py
# Copyright 2024 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-module-docstring
import logging
import random
import string
import traceback
from decimal import Decimal
from functools import reduce
from itertools import islice
from typing import Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import urlparse
import data_diff
import sqlalchemy.types
from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from pydantic import BaseModel, Field
from sqlalchemy import Column as SAColumn
from sqlalchemy import literal, select
from sqlalchemy.engine import make_url
from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
)
from metadata.data_quality.validations.models import (
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.generated.schema.entity.data.table import Column
from metadata.generated.schema.entity.services.connections.database.sapHanaConnection import (
SapHanaScheme,
)
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
TestResultValue,
)
from metadata.generated.schema.type.basic import ProfileSampleType
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.converter.base import build_orm_col
from metadata.profiler.orm.functions.md5 import MD5
from metadata.profiler.orm.functions.substr import Substr
from metadata.profiler.orm.registry import Dialects, PythonDialects
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.credentials import normalize_pem_string
from metadata.utils.logger import test_suite_logger
# Module-level logger used by every function and method in this file.
logger = test_suite_logger()

# Dialects accepted for the table diff test. `TableDiffValidator._validate_dialects`
# checks the URL scheme (or the "driver" entry when the service URL is a dict) of both
# tables against this list and raises `UnsupportedDialectError` when it is not present.
SUPPORTED_DIALECTS = [
    Dialects.Snowflake,
    Dialects.BigQuery,
    Dialects.Athena,
    Dialects.Redshift,
    Dialects.Postgres,
    Dialects.MySQL,
    Dialects.MSSQL,
    Dialects.Oracle,
    Dialects.Trino,
    SapHanaScheme.hana.value,
    Dialects.Databricks,
    Dialects.UnityCatalog,
]
class SchemaDiffResult(BaseModel):
    # Snapshot of one table's schema as reported in the table diff result:
    # the service type, the table's fully qualified name and, per column name,
    # a small dict of string attributes (populated with "type" and "constraints").

    class Config:
        arbitrary_types_allowed = True
        populate_by_name = True

    serviceType: str
    fullyQualifiedTableName: str
    # Exposed under the alias "schema"; the attribute itself is named `schema_`.
    schema_: Dict[str, Dict[str, str]] = Field(alias="schema")

    def __str__(self):
        # Render every dumped field (using aliases) as space separated `key=repr(value)` pairs.
        dumped = self.model_dump(by_alias=True)
        rendered = [f"{key}={value!r}" for key, value in dumped.items()]
        return " ".join(rendered)
class ColumnDiffResult(BaseModel):
    # Column-level differences between the two compared tables, as assembled by
    # `TableDiffValidator.get_column_diff`.

    class Config:
        arbitrary_types_allowed = True

    # Non-key columns of table 1 that have no counterpart in table 2.
    removed: List[str]
    # Non-key columns of table 2 that have no counterpart in table 1.
    added: List[str]
    # Columns present in both tables whose types are not comparable.
    changed: List[str]
    # Schema snapshots of table 1 and table 2 respectively.
    schemaTable1: SchemaDiffResult
    schemaTable2: SchemaDiffResult
def build_sample_where_clause(table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str) -> str:
    """Render the sampling predicate for one table as a SQL string.

    The key columns of ``table`` are concatenated with ``salt``, hashed with MD5 and
    the first 8 characters of the hash are compared against ``hex_nounce``.

    Args:
        table: table whose columns and service type drive the generated SQL
        key_columns: names of the columns taking part in the hash
        salt: literal appended to the concatenated key columns before hashing
        hex_nounce: upper bound (exclusive) for the hash prefix

    Returns:
        The where clause compiled for the table's dialect with literal values inlined
    """
    hashed_columns = []
    for position, column in enumerate(table.columns):
        if column.name.root in key_columns:
            hashed_columns.append(build_orm_col(position, column, table.database_service_type))

    # Fold the key columns and the salt into one concatenation expression.
    salted_key = reduce(lambda left, right: left.concat(right), [*hashed_columns, literal(salt)])

    dialect_name = PythonDialects[table.database_service_type.name].value
    dialect_cls = make_url(f"{dialect_name}://").get_dialect()

    sample_filter = Substr(MD5(salted_key), 1, 8) < hex_nounce
    compiled_clause = (
        select()
        .filter(sample_filter)
        .whereclause.compile(
            dialect=dialect_cls(),
            compile_kwargs={"literal_binds": True},
        )
    )
    return str(compiled_clause)
def compile_and_clauses(elements) -> str:
    """Join clauses with 'and', parenthesising nested lists.

    Args:
        elements: A string, or a (possibly nested) list of strings

    Returns:
        The clauses combined into a single string. A one-element list collapses
        to its only element without adding parentheses.

    Raises:
        ValueError: If the input (or any nested element) is neither a string nor a list

    Examples:
        >>> compile_and_clauses("a")
        'a'
        >>> compile_and_clauses(["a", "b"])
        'a and b'
        >>> compile_and_clauses([["a", "b"], "c"])
        '(a and b) and c'
    """
    if isinstance(elements, str):
        return elements
    if not isinstance(elements, list):
        raise ValueError("Input must be a string or a list")
    if len(elements) == 1:
        return compile_and_clauses(elements[0])

    parts = []
    for element in elements:
        compiled = compile_and_clauses(element)
        # Only nested lists are wrapped so that their grouping is preserved.
        parts.append(f"({compiled})" if isinstance(element, list) else compiled)
    return " and ".join(parts)
class UnsupportedDialectError(Exception):
    """Raised when a table parameter points at a dialect the table diff cannot handle."""

    def __init__(self, param: str, dialect: str):
        message = f"Unsupported dialect in param {param}: {dialect}"
        super().__init__(message)
def masked(s: str, mask: bool = True) -> str:
    """Hide a value behind a fixed placeholder unless masking is switched off.

    Only for development purposes, do not use in production.
    Pass ``mask=False`` if you want to see the data in the logs.

    Args:
        s: string to mask
        mask: whether the string must be hidden

    Returns:
        ``"***"`` when ``mask`` is truthy, otherwise ``s`` unchanged
    """
    if not mask:
        return s
    return "***"
def is_numeric(t: type) -> bool:
    """Tell whether a python type is one of the numeric types handled by the diff.

    Args:
        t: type to check

    Returns:
        True when ``t`` is ``int``, ``float`` or ``Decimal``, False otherwise
    """
    numeric_types = (int, float, Decimal)
    return t in numeric_types
class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
    """Compare two tables and fail if the number of differences exceeds a threshold.

    The row comparison is delegated to the ``data_diff`` library. Column level
    differences (removed, added or type-incompatible columns) are computed first
    and reported alongside the row differences.
    """

    runtime_params: TableDiffRuntimeParameters

    def _run_validation(self):
        """Run validation for the table diff test.

        Returns:
            TestCaseResult: the diff result, or a Failed/Aborted result describing
            the error when the comparison could not be carried out
        """
        self.runtime_params = self.get_runtime_parameters(TableDiffRuntimeParameters)
        try:
            self._validate_dialects()
            return self._run()
        except DataDiffMismatchingKeyTypesError as e:
            return TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Failed,
                result=str(e),
            )
        except UnsupportedDialectError as e:
            logger.error(f"[Data Diff]: Unsupported dialect: {e}")
            return TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=str(e),
            )
        except Exception as e:  # pylint: disable=broad-except
            # Top-level boundary: any other failure aborts the test instead of crashing the run.
            logger.error(f"Unexpected error while running the table diff test: {str(e)}\n{traceback.format_exc()}")
            result = TestCaseResult(
                timestamp=self.execution_date,  # type: ignore
                testCaseStatus=TestCaseStatus.Aborted,
                result=f"ERROR: Unexpected error while running the table diff test: {str(e)}",
            )
            logger.debug(result.result)
            return result

    def _run_dimensional_validation(self):
        """Execute dimensional validation for table diff test

        Table diff tests don't currently support dimensional validation.
        This method returns an empty list to indicate no dimensional results.

        Returns:
            List: Empty list for now (placeholder for future implementation)
        """
        # TODO: Implement dimensional validation for table diff tests if needed
        # This would involve grouping by dimension columns and checking diffs per group
        return []

    def _run(self) -> TestCaseResult:
        """Compute the column diff and the row diff and build the test case result.

        When no threshold is set, or passed/failed row counts were requested, the full
        diff statistics are computed. Otherwise differing keys are only counted until
        the threshold is exceeded.
        """
        column_diff: Optional[ColumnDiffResult] = self.get_column_diff()
        threshold = self.get_test_case_param_value(self.test_case.parameterValues, "threshold", int, default=0)
        if column_diff:
            # If there are column differences, we set extra_columns to the common columns for the diff
            # Exclude incomparable columns (different data types) from the comparison
            # Also exclude key columns since they are handled separately and should not be in extra_columns
            common_columns = list(
                (set(column_diff.schemaTable1.schema_.keys()) & set(column_diff.schemaTable2.schema_.keys()))
                - set(column_diff.changed)
                - set(self.runtime_params.table1.key_columns or [])
                - set(self.runtime_params.table2.key_columns or [])
            )
            self.runtime_params.extraColumns = common_columns
            self.runtime_params.table1.extra_columns = common_columns
            self.runtime_params.table2.extra_columns = common_columns
        table_diff_iter = self.get_table_diff()
        if not threshold or self.test_case.computePassedFailedRowCount:
            stats = table_diff_iter.get_stats_dict()
            # depending on the data, this require scanning a lot of data
            # so we only log the sample in debug mode. data can be sensitive
            # so it is masked by default.
            # isEnabledFor looks at the effective level: a logger that inherits its level
            # reports `level == NOTSET (0)`, which would wrongly compare as "debug enabled".
            if stats["total"] > 0 and logger.isEnabledFor(logging.DEBUG):
                logger.debug("Sample of failed rows:")
                for sample_row in islice(self.safe_table_diff_iterator(), 10):
                    logger.debug("%s", str([sample_row[0]] + [masked(st) for st in sample_row[1]]))
            test_case_result = self.get_row_diff_test_case_result(
                threshold,
                stats["total"],
                stats["updated"],
                stats["exclusive_A"],
                stats["exclusive_B"],
                column_diff,
            )
            count = self._compute_row_count(self.runner, None)  # type: ignore
            test_case_result.passedRows = stats["unchanged"]
            if count:
                # Percentages are undefined for an empty table or an unknown row count.
                test_case_result.passedRowsPercentage = test_case_result.passedRows / count * 100
                test_case_result.failedRowsPercentage = test_case_result.failedRows / count * 100
            return test_case_result
        # column_diff must be passed by keyword: the third positional parameter is `changed`.
        return self.get_row_diff_test_case_result(
            threshold,
            self.calculate_diffs_with_limit(table_diff_iter, threshold),
            column_diff=column_diff,
        )

    def _connect_to_table(self, table: TableParameter, extra_columns, **kwargs):
        """Open a data_diff table segment for one of the compared tables.

        Args:
            table: table parameters (service URL, path, key columns, credentials)
            extra_columns: non-key columns to include in the comparison
            **kwargs: forwarded as-is to ``data_diff.connect_to_table`` (e.g. ``where``)

        Returns:
            The object returned by ``data_diff.connect_to_table``
        """
        private_key = table.privateKey
        pass_phrase = table.passPhrase
        return data_diff.connect_to_table(
            table.serviceUrl,
            table.path,
            table.key_columns,  # type: ignore
            extra_columns=extra_columns,
            case_sensitive=self.get_case_sensitive(),
            # The PEM key is normalized for every connection so that all of them authenticate alike.
            key_content=normalize_pem_string(private_key.get_secret_value()) if private_key else None,
            private_key_passphrase=pass_phrase.get_secret_value() if pass_phrase else None,
            **kwargs,
        )

    def get_incomparable_columns(self) -> List[str]:
        """Get the columns that have types that are not comparable between the two tables. For example
        a column that is a string in one table and an integer in the other.

        Two numeric columns are always considered comparable, whatever their exact types.

        Returns:
            List[str]: A list of column names that have incomparable types
        """
        table1 = self._connect_to_table(self.runtime_params.table1, self.runtime_params.extraColumns).with_schema()
        table2 = self._connect_to_table(self.runtime_params.table2, self.runtime_params.extraColumns).with_schema()
        result = []
        for column in table1.key_columns + table1.extra_columns:
            col1 = table1._schema.get(column)  # pylint: disable=protected-access
            if col1 is None:
                # Skip columns that are not in the first table. We cover this case in get_changed_added_columns.
                continue
            col2 = table2._schema.get(column)  # pylint: disable=protected-access
            if col2 is None:
                # Skip columns that are not in the second table. We cover this case in get_changed_added_columns.
                continue
            col1_type = self._get_column_python_type(col1)
            col2_type = self._get_column_python_type(col2)
            if is_numeric(col1_type) and is_numeric(col2_type):
                continue
            if col1_type != col2_type:
                result.append(column)
        return result

    @staticmethod
    def _get_column_python_type(column: SAColumn):
        """Try to resolve the python_type of a column through different SQLAlchemy types.

        Three lookups are attempted in order: the column's own ``python_type``, the
        ``sqlalchemy.types`` class named like the column's class, and the one named like
        the upper-cased class name. Every lookup that succeeds overwrites the previous
        result; lookups raising ``AttributeError`` are skipped.

        ``ArithAlphanumeric`` is reported as ``str`` and ``bool`` as ``int``. When no lookup
        succeeds, ``NoneType`` is returned.

        Args:
            column: the column type object to inspect
        """
        result = None
        try:
            result = column.python_type
        except AttributeError:
            pass
        try:
            result = getattr(sqlalchemy.types, type(column).__name__)().python_type
        except AttributeError:
            pass
        try:
            result = getattr(sqlalchemy.types, type(column).__name__.upper())().python_type
        except AttributeError:
            pass
        if result == ArithAlphanumeric:
            result = str
        elif result == bool:
            result = int
        elif result is None:
            return type(result)
        return result

    def get_table_diff(self) -> DiffResultWrapper:
        """Calls data_diff.diff_tables with the parameters from the test case."""
        left_where, right_where = self.sample_where_clause()
        table1 = self._connect_to_table(
            self.runtime_params.table1,
            self.runtime_params.table1.extra_columns,
            where=left_where,
        )
        table2 = self._connect_to_table(
            self.runtime_params.table2,
            self.runtime_params.table2.extra_columns,
            where=right_where,
        )
        data_diff_kwargs = {
            "where": self.get_where(),
        }
        logger.debug(
            "Calling table diff with parameters: table1=%s, table2=%s, kwargs=%s",
            table1.table_path,
            table2.table_path,
            ",".join(f"{k}={v}" for k, v in data_diff_kwargs.items()),
        )
        return data_diff.diff_tables(table1, table2, **data_diff_kwargs)  # type: ignore

    def get_where(self) -> Optional[str]:
        """Returns the where clause from the test case parameters or None if it is a blank string."""
        return self.runtime_params.whereClause or None

    def sample_where_clause(self) -> Tuple[Optional[str], Optional[str]]:
        """We use a where clause to sample the data for the diff. This is useful because with data diff
        we do not have access to the underlying 'SELECT' statement. This method generates a where clause
        that selects a random sample of the data based on the profile sample configuration.
        The method uses the md5 hash of the key columns and a random salt to select a random sample of the data.
        This ensures that the same data is selected from the two tables of the comparison.

        Example:
            -- Table 1 --   |  -- Table 2 --
            id | name       |  id | name
            1  | Alice      |  1  | Alice
            2  | Bob        |  2  | Bob
            3  | Charlie    |  3  | Charlie
            4  | David      |  4  | Edward
            5  | Edward     |  6  | Frank

        If we want a sample of 20% of the data, the where clause will intend to select one of the rows
        on Table 1 and the hash will ensure that the same row is selected on Table 2. We want to avoid selecting rows
        with different ids because the comparison will not be sensible.

        Returns:
            The where clauses for table 1 and table 2, or ``(None, None)`` when no sampling applies
        """
        config = self.runtime_params.table_profile_config
        if config is None:
            return None, None
        profile_sample_config = config.profileSampleConfig
        sample_config = profile_sample_config.root if profile_sample_config else None
        static = sample_config.config if sample_config else None
        profile_sample = getattr(static, "profileSample", None) if static else None
        profile_sample_type = getattr(static, "profileSampleType", None) if static else None
        if profile_sample is None or (profile_sample_type == ProfileSampleType.PERCENTAGE and profile_sample == 100):
            return None, None
        if DatabaseServiceType.Mssql in [
            self.runtime_params.table1.database_service_type,
            self.runtime_params.table2.database_service_type,
        ]:
            logger.warning("Sampling not supported in MSSQL. Skipping sampling.")
            return None, None
        nounce = self.calculate_nounce()
        # SQL MD5 returns a 32 character hex string even with leading zeros so we need to
        # pad the nounce to 8 characters to preserve lexical order.
        # example: SELECT md5('j(R1wzR*y[^GxWJ5B>L{-HLETRD');
        hex_nounce = hex(nounce)[2:].rjust(8, "0")
        # TODO: using strings for this is sub-optimal. But using bytes buffers requires a by-database
        # implementation. We can use this as default and add database specific implementations as the
        # need arises.
        salt = "".join(
            random.choices(string.ascii_letters + string.digits, k=5)
        )  # 1 / ~62^5 should be enough entropy. Use letters and digits to avoid messing with SQL syntax
        return (
            build_sample_where_clause(
                self.runtime_params.table1,
                self.maybe_case_sensitive(self.runtime_params.table1.key_columns),
                salt,
                hex_nounce,
            ),
            build_sample_where_clause(
                self.runtime_params.table2,
                self.maybe_case_sensitive(self.runtime_params.table2.key_columns),
                salt,
                hex_nounce,
            ),
        )

    def maybe_case_sensitive(self, iterable: Iterable[str]) -> list[str]:
        """Return the names as a plain list, or as a case insensitive list when the
        test is configured with case insensitive columns."""
        return CaseInsensitiveList(iterable) if not self.get_case_sensitive() else list(iterable)

    def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
        """Calculate the nounce based on the profile sample configuration. The nounce is
        the sample fraction projected to a number on a scale of 0 to max_nounce

        Args:
            max_nounce: upper end of the scale; the result never exceeds it

        Returns:
            int: the projected sample fraction

        Raises:
            ValueError: if the sample type is neither PERCENTAGE nor ROWS, or the row
                count needed for a ROWS sample cannot be computed
        """
        config = self.runtime_params.table_profile_config
        profile_sample_config = config.profileSampleConfig if config else None
        sample_config = profile_sample_config.root if profile_sample_config else None
        static = sample_config.config if sample_config else None
        profile_sample = getattr(static, "profileSample", 100)
        profile_sample_type = getattr(static, "profileSampleType", None)
        if profile_sample_type == ProfileSampleType.PERCENTAGE:
            nounce = int(max_nounce * profile_sample / 100)
        elif profile_sample_type == ProfileSampleType.ROWS:
            row_count = self.get_total_row_count()
            if row_count is None:
                raise ValueError("Row count is required for ROWS profile sample type")
            if row_count <= 0:
                # Nothing to sample from: keep every row instead of dividing by zero.
                return max_nounce
            nounce = int(max_nounce * (profile_sample / row_count))
        else:
            raise ValueError("Invalid profile sample type")
        # A sample larger than the table would need more than 8 hex digits and break the
        # lexical comparison against the hash prefix.
        return min(max_nounce, nounce)

    def get_row_diff_test_case_result(
        self,
        threshold: int,
        total_diffs: int,
        changed: Optional[int] = None,
        removed: Optional[int] = None,
        added: Optional[int] = None,
        column_diff: Optional[ColumnDiffResult] = None,
    ) -> TestCaseResult:
        """Build a test case result for a row diff test. If the number of differences is less than the threshold,
        the test will pass, otherwise it will fail. The result will contain the number of added, removed, and changed
        rows, as well as the total number of differences. Any column difference prevents the passing status.

        Args:
            threshold: The maximum number of differences allowed before the test fails
            total_diffs: The total number of differences between the tables
            changed: The number of rows that have been changed
            removed: The number of rows that have been removed
            added: The number of rows that have been added
            column_diff: The column differences between the tables, if any

        Returns:
            TestCaseResult: The result of the row diff test
        """
        test_case_results = [
            TestResultValue(name="removedRows", value=str(removed)),
            TestResultValue(name="addedRows", value=str(added)),
            TestResultValue(name="changedRows", value=str(changed)),
            TestResultValue(name="diffCount", value=str(total_diffs)),
        ]
        if column_diff:
            test_case_results.extend(
                [
                    TestResultValue(name="removedColumns", value=str(len(column_diff.removed))),
                    TestResultValue(name="addedColumns", value=str(len(column_diff.added))),
                    TestResultValue(name="changedColumns", value=str(len(column_diff.changed))),
                    TestResultValue(name="schemaTable1", value=str(column_diff.schemaTable1)),
                    TestResultValue(name="schemaTable2", value=str(column_diff.schemaTable2)),
                ]
            )
        has_column_diff = column_diff is not None and bool(
            column_diff.removed or column_diff.added or column_diff.changed
        )
        if has_column_diff:
            result_message = (
                f"Schema mismatch detected: "
                f"{len(column_diff.removed)} removed, "
                f"{len(column_diff.added)} added, "
                f"{len(column_diff.changed)} changed columns. "
                f"Found {total_diffs} different rows."
            )
        else:
            result_message = f"Found {total_diffs} different rows which is more than the threshold of {threshold}"
        return TestCaseResult(
            timestamp=self.execution_date,  # type: ignore
            testCaseStatus=self.get_test_case_status(
                not has_column_diff and ((threshold or total_diffs) == 0 or total_diffs < threshold)
            ),
            result=result_message,
            failedRows=total_diffs,
            validateColumns=False,
            testResultValue=test_case_results,
        )

    def _validate_dialects(self):
        """Check that both tables use a dialect listed in SUPPORTED_DIALECTS.

        Raises:
            UnsupportedDialectError: if the driver (dict service URL) or URL scheme
                (string service URL) of a table is not supported
        """
        for name, param in [
            ("table1.serviceUrl", self.runtime_params.table1.serviceUrl),
            ("table2.serviceUrl", self.runtime_params.table2.serviceUrl),
        ]:
            if isinstance(param, dict):
                dialect = param.get("driver")
            else:
                dialect = urlparse(param).scheme
            if dialect not in SUPPORTED_DIALECTS:
                raise UnsupportedDialectError(name, dialect)

    @staticmethod
    def _build_schema_diff_result(table: TableParameter) -> SchemaDiffResult:
        """Describe the schema of one table (column types and constraints) for the result."""
        return SchemaDiffResult(
            serviceType=table.database_service_type.name,
            fullyQualifiedTableName=table.fullyQualifiedName or table.path,
            schema={
                c.name.root: {
                    "type": c.dataTypeDisplay,
                    "constraints": c.constraint.value if c.constraint else "",
                }
                for c in table.columns
            },
        )

    def get_column_diff(self) -> Optional[ColumnDiffResult]:
        """Get the column diff between the two tables. If there are no differences, return None."""
        table1 = self.runtime_params.table1
        table2 = self.runtime_params.table2
        # key_columns may be unset; treat that as "no key columns" like _run does.
        table1_keys = table1.key_columns or []
        table2_keys = table2.key_columns or []
        removed, added = self.get_changed_added_columns(
            [c for c in table1.columns if c.name.root not in table1_keys],
            [c for c in table2.columns if c.name.root not in table2_keys],
            self.get_case_sensitive(),
        )
        changed = self.get_incomparable_columns()
        if not (removed or added or changed):
            return None
        return ColumnDiffResult(
            removed=removed,
            added=added,
            changed=changed,
            schemaTable1=self._build_schema_diff_result(table1),
            schemaTable2=self._build_schema_diff_result(table2),
        )

    @staticmethod
    def get_changed_added_columns(
        left: List[Column], right: List[Column], case_sensitive: bool
    ) -> Tuple[List[str], List[str]]:
        """Given a list of columns from two tables, return the columns that are removed and added.

        Args:
            left: List of columns from the first table
            right: List of columns from the second table
            case_sensitive: Whether column names are matched case sensitively

        Returns:
            A tuple of lists containing the removed and added columns. Both lists are
            empty when there are no differences.
        """
        removed: List[str] = []
        added: List[str] = []
        right_columns_dict: Dict[str, Column] = {c.name.root: c for c in right}
        if not case_sensitive:
            right_columns_dict = cast(Dict[str, Column], CaseInsensitiveDict(right_columns_dict))
        for column in left:
            table2_column = right_columns_dict.get(column.name.root)
            if table2_column is None:
                removed.append(column.name.root)
                continue
            # Matched columns are dropped so that only unmatched right columns remain.
            del right_columns_dict[column.name.root]
        added.extend(right_columns_dict.keys())
        return removed, added

    def column_validation_result(
        self,
        removed: List[str],
        added: List[str],
        changed: List[str],
    ) -> TestCaseResult:
        """Build the result for a column validation result. Messages will only be added
        for non-empty categories. Values will be reported for all categories.

        Args:
            removed: List of removed columns
            added: List of added columns
            changed: List of changed columns

        Returns:
            TestCaseResult: The result of the column validation with a meaningful message
        """
        message = f"Tables have {sum(map(len, [removed, added, changed]))} different columns:"
        if removed:
            message += f"\n Removed columns: {', '.join(removed)}\n"
        if added:
            message += f"\n Added columns: {', '.join(added)}\n"
        if changed:
            message += "\n Changed columns:"
            table1_columns = {c.name.root: c for c in self.runtime_params.table1.columns}
            table2_columns = {c.name.root: c for c in self.runtime_params.table2.columns}
            if not self.get_case_sensitive():
                table1_columns = CaseInsensitiveDict(table1_columns)
                table2_columns = CaseInsensitiveDict(table2_columns)
            for col in changed:
                col1 = table1_columns[col]
                col2 = table2_columns[col]
                message += f"\n {col}: {col1.dataType.value} -> {col2.dataType.value}"
        return TestCaseResult(
            timestamp=self.execution_date,  # type: ignore
            testCaseStatus=TestCaseStatus.Failed,
            result=message,
            testResultValue=[
                TestResultValue(name="removedColumns", value=str(len(removed))),
                TestResultValue(name="addedColumns", value=str(len(added))),
                TestResultValue(name="changedColumns", value=str(len(changed))),
            ],
        )

    def calculate_diffs_with_limit(self, diff_iter: Iterable[Tuple[str, Tuple[str, ...]]], limit: int) -> int:
        """Given an iterator of diffs like
        - ('+', (...))
        - ('-', (...))
        ...
        Calculate the total diffs by combining diffs for the same key. This gives an accurate count of the
        total diffs as opposed to just counting the number of entries yielded by the iterator.

        Args:
            diff_iter: iterator returned from the data_diff algorithm
            limit: number of distinct differing keys after which counting stops

        Returns:
            int: accurate count of the total diffs up to the limit; ``limit + 1`` when the
            limit was exceeded
        """
        len_key_columns = len(self.runtime_params.keyColumns)
        key_set = set()
        # combine diffs on same key to "!"
        for _, values in diff_iter:
            k = values[:len_key_columns]
            if k in key_set:
                continue
            key_set.add(k)
            if len(key_set) > limit:
                # The limit is already exceeded: stop scanning instead of reading the full diff.
                break
        return len(key_set)

    def safe_table_diff_iterator(self) -> Iterable[Tuple[str, Tuple[str, ...]]]:
        """A safe iterator object which properly closes the diff object when the generator is exhausted.
        Otherwise the data_diff library will continue to hold the connection open and eventually
        raise a KeyError.
        """
        gen = self.get_table_diff()
        try:
            yield from gen
        finally:
            try:
                gen.diff.close()
            except KeyError as ex:
                # KeyError "2" is a known issue in data_diff when the diff object is closed.
                # Other KeyErrors are swallowed as well (closing is best effort) but are logged.
                if str(ex) != "2":
                    logger.debug("Ignoring KeyError while closing the table diff: %s", ex)

    def get_case_sensitive(self):
        """Read the `caseSensitiveColumns` boolean parameter of the test case."""
        return utils.get_bool_test_case_param(self.test_case.parameterValues, "caseSensitiveColumns")

    def get_row_count(self) -> Optional[int]:
        """Row count as computed by `_compute_row_count` on the runner."""
        return self._compute_row_count(self.runner, None)

    def get_total_row_count(self) -> Optional[int]:
        """Query the row count metric on the runner's table.

        Returns:
            The row count, or None when the query fails (the error is logged)
        """
        row_count = Metrics.rowCount()
        try:
            row = self.runner.select_first_from_table(row_count.fn())
            return row._asdict().get(Metrics.rowCount.name)
        except Exception as e:  # pylint: disable=broad-except
            logger.error(f"Error getting row count: {e}")
            return None