Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c180cb1

Browse files
authored
Align PK support with Datafold SaaS (#446)
* squash align pk support * black -l 120 * unused import * set -> Set for python 3.7, 3.8 * Running with data-diff={version}
1 parent d21c140 commit c180cb1

2 files changed

Lines changed: 185 additions & 41 deletions

File tree

data_diff/dbt.py

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import json
2-
import logging
32
import os
43
import time
54
import rich
5+
6+
from collections import defaultdict
67
from dataclasses import dataclass
78
from packaging.version import parse as parse_version
8-
from typing import List, Optional, Dict, Tuple
9+
from typing import List, Optional, Dict, Tuple, Set
10+
from .utils import getLogger
11+
from .version import __version__
912
from pathlib import Path
1013

1114
import requests
1215

16+
logger = getLogger(__name__)
17+
1318

1419
def import_dbt():
1520
try:
@@ -72,7 +77,6 @@ def dbt_diff(
7277
set_entrypoint_name("CLI-dbt")
7378
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
7479
models = dbt_parser.get_models()
75-
dbt_parser.set_project_dict()
7680
datadiff_variables = dbt_parser.get_datadiff_variables()
7781
config_prod_database = datadiff_variables.get("prod_database")
7882
config_prod_schema = datadiff_variables.get("prod_schema")
@@ -105,7 +109,7 @@ def dbt_diff(
105109
+ " <> "
106110
+ ".".join(diff_vars.dev_path)
107111
+ "[/] \n"
108-
+ "Skipped due to missing primary-key tag(s).\n"
112+
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
109113
)
110114

111115
rich.print("Diffs Complete!")
@@ -121,7 +125,8 @@ def _get_diff_vars(
121125
) -> DiffVars:
122126
dev_database = model.database
123127
dev_schema = model.schema_
124-
primary_keys = dbt_parser.get_primary_keys(model)
128+
129+
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
125130

126131
prod_database = config_prod_database if config_prod_database else dev_database
127132
prod_schema = config_prod_schema if config_prod_schema else dev_schema
@@ -162,7 +167,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
162167
table2_columns = list(table2.get_schema())
163168
# Not ideal, but we don't have more specific exceptions yet
164169
except Exception as ex:
165-
logging.info(ex)
170+
logger.debug(ex)
166171
rich.print(
167172
"[red]"
168173
+ prod_qualified_string
@@ -287,15 +292,16 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
287292

288293
class DbtParser:
289294
def __init__(self, profiles_dir_override: str, project_dir_override: str, is_cloud: bool) -> None:
295+
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
290296
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
291297
self.project_dir = Path(project_dir_override or default_project_dir())
292298
self.is_cloud = is_cloud
293299
self.connection = None
294-
self.project_dict = None
300+
self.project_dict = self.get_project_dict()
301+
self.manifest_obj = self.get_manifest_obj()
295302
self.requires_upper = False
296303
self.threads = None
297-
298-
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
304+
self.unique_columns = self.get_unique_columns()
299305

300306
def get_datadiff_variables(self) -> dict:
301307
return self.project_dict.get("vars").get("data_diff")
@@ -315,24 +321,24 @@ def get_models(self):
315321
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}"
316322
)
317323

318-
with open(self.project_dir / MANIFEST_PATH) as manifest:
319-
manifest_dict = json.load(manifest)
320-
manifest_obj = self.parse_manifest(manifest=manifest_dict)
321-
322324
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
323-
models = [manifest_obj.nodes.get(x) for x in success_models]
325+
models = [self.manifest_obj.nodes.get(x) for x in success_models]
324326
if not models:
325327
raise ValueError("Expected > 0 successful models runs from the last dbt command.")
326328

327-
rich.print(f"Found {str(len(models))} successful model runs from the last dbt command.")
329+
print(f"Running with data-diff={__version__}\n")
328330
return models
329331

330-
def get_primary_keys(self, model):
331-
return list((x.name for x in model.columns.values() if "primary-key" in x.tags))
332+
def get_manifest_obj(self):
    """Load the dbt manifest file and hand it to the manifest parser.

    Reads the JSON document at ``self.project_dir / MANIFEST_PATH`` and
    returns whatever ``self.parse_manifest`` produces for it.
    """
    manifest_path = self.project_dir / MANIFEST_PATH
    with open(manifest_path) as manifest_file:
        raw_manifest = json.load(manifest_file)
    return self.parse_manifest(manifest=raw_manifest)
332337

333-
def set_project_dict(self):
338+
def get_project_dict(self):
334339
with open(self.project_dir / PROJECT_FILE) as project:
335-
self.project_dict = self.yaml.safe_load(project)
340+
project_dict = self.yaml.safe_load(project)
341+
return project_dict
336342

337343
def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
338344
profiles_path = self.profiles_dir / PROFILES_FILE
@@ -437,3 +443,81 @@ def set_connection(self):
437443
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
438444

439445
self.connection = conn_info
446+
447+
def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str]:
    """Resolve the primary-key column names of a model node.

    Candidate sources are consulted in priority order; the first one that
    yields a result wins:

    1. Table-level meta: the entries of ``node.meta[pk_tag]`` that are also
       keys of ``node.columns``.
    2. Column-level meta: columns whose ``meta`` contains ``pk_tag``.
    3. Column-level tags: columns whose ``tags`` contain ``pk_tag``.
    4. Uniqueness tests: ``unique_columns[node.unique_id]``.

    Args:
        node: Object exposing ``columns`` (name -> column params with
            ``meta`` and ``tags``), ``meta`` and ``unique_id``.
        unique_columns: Mapping of node unique id to an iterable of column
            names.
        pk_tag: Key looked up in meta and tags, e.g. ``"primary-key"``.

    Returns:
        The primary-key column names, or an empty list when no source
        provides any.
    """
    # Names of every column declared on the node.
    column_names = set(node.columns)

    # 1. Table-level meta.
    if pk_tag in node.meta:
        # Iterate the declared keys themselves. The previous expression,
        # `for pk in pk_tag in node.meta[pk_tag]`, iterated the boolean result
        # of a membership test and therefore always raised TypeError.
        pks = [pk for pk in node.meta[pk_tag] if pk in column_names]
        if pks:
            logger.debug("Found PKs via Table META: %s", pks)
            return pks

    # 2. Column-level meta.
    from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta]
    if from_meta:
        logger.debug("Found PKs via META: %s", from_meta)
        return from_meta

    # 3. Column-level tags.
    from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags]
    if from_tags:
        logger.debug("Found PKs via Tags: %s", from_tags)
        return from_tags

    # 4. Uniqueness tests. `.get` does not insert a key, even on a defaultdict.
    from_uniq = unique_columns.get(node.unique_id)
    if from_uniq is not None:
        logger.debug("Found PKs via Uniqueness tests: %s", from_uniq)
        return list(from_uniq)

    logger.debug("Found no PKs")
    return []
481+
482+
def get_unique_columns(self) -> Dict[str, Set[str]]:
    """Collect the columns covered by uniqueness tests, keyed by tested node.

    Walks every node of ``self.manifest_obj`` and considers only nodes whose
    ``resource_type.value`` is ``"test"`` and that carry ``test_metadata``.
    The first entry of the test's ``depends_on.nodes`` is taken as the node
    under test. Two test names are recognised:

    * ``unique``: ``kwargs["column_name"]`` is split with
      ``_parse_concat_pk_definition`` and each part that is a column of the
      tested node is recorded.
    * ``unique_combination_of_columns``: every entry of
      ``kwargs["combination_of_columns"]`` is recorded.

    A ``KeyError``, ``IndexError`` or ``TypeError`` raised while handling one
    test is logged as a warning and does not stop the scan.

    Returns:
        A ``defaultdict(set)`` mapping node unique id to its column names.
    """
    manifest = self.manifest_obj
    cols_by_uid = defaultdict(set)
    for node in manifest.nodes.values():
        try:
            if not (node.resource_type.value == "test" and hasattr(node, "test_metadata")):
                continue

            # Truthiness check: the previous `nodes is []` compared identity
            # with a fresh list, was always False, and let dependency-less
            # tests fall through to `nodes[0]` and log a spurious IndexError.
            if node.depends_on is None or not node.depends_on.nodes:
                continue

            uid = node.depends_on.nodes[0]
            model_node = manifest.nodes[uid]
            test_name = node.test_metadata.name

            if test_name == "unique":
                column_name: str = node.test_metadata.kwargs["column_name"]
                for col in self._parse_concat_pk_definition(column_name):
                    # Skip anything that is not a column, for example string
                    # literals used in a concat like "pk1 || '-' || pk2".
                    if model_node is None or col in model_node.columns:
                        cols_by_uid[uid].add(col)
            elif test_name == "unique_combination_of_columns":
                for col in node.test_metadata.kwargs["combination_of_columns"]:
                    cols_by_uid[uid].add(col)

        except (KeyError, IndexError, TypeError) as e:
            logger.warning("Failure while finding unique cols: %s", e)

    return cols_by_uid
513+
514+
def _parse_concat_pk_definition(self, definition: str) -> List[str]:
515+
definition = definition.strip()
516+
if definition.lower().startswith("concat(") and definition.endswith(")"):
517+
definition = definition[7:-1] # Removes concat( and )
518+
columns = definition.split(",")
519+
else:
520+
columns = definition.split("||")
521+
522+
stripped_columns = [col.strip('" ()') for col in columns]
523+
return stripped_columns

0 commit comments

Comments
 (0)