11import json
2- import logging
32import os
43import time
54import rich
5+
6+ from collections import defaultdict
67from dataclasses import dataclass
78from packaging .version import parse as parse_version
8- from typing import List , Optional , Dict , Tuple
9+ from typing import List , Optional , Dict , Tuple , Set
10+ from .utils import getLogger
11+ from .version import __version__
912from pathlib import Path
1013
1114import requests
1215
16+ logger = getLogger (__name__ )
17+
1318
1419def import_dbt ():
1520 try :
@@ -72,7 +77,6 @@ def dbt_diff(
7277 set_entrypoint_name ("CLI-dbt" )
7378 dbt_parser = DbtParser (profiles_dir_override , project_dir_override , is_cloud )
7479 models = dbt_parser .get_models ()
75- dbt_parser .set_project_dict ()
7680 datadiff_variables = dbt_parser .get_datadiff_variables ()
7781 config_prod_database = datadiff_variables .get ("prod_database" )
7882 config_prod_schema = datadiff_variables .get ("prod_schema" )
@@ -105,7 +109,7 @@ def dbt_diff(
105109 + " <> "
106110 + "." .join (diff_vars .dev_path )
107111 + "[/] \n "
108- + "Skipped due to missing primary- key tag(s) .\n "
112+ + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags .\n "
109113 )
110114
111115 rich .print ("Diffs Complete!" )
@@ -121,7 +125,8 @@ def _get_diff_vars(
121125) -> DiffVars :
122126 dev_database = model .database
123127 dev_schema = model .schema_
124- primary_keys = dbt_parser .get_primary_keys (model )
128+
129+ primary_keys = dbt_parser .get_pk_from_model (model , dbt_parser .unique_columns , "primary-key" )
125130
126131 prod_database = config_prod_database if config_prod_database else dev_database
127132 prod_schema = config_prod_schema if config_prod_schema else dev_schema
@@ -162,7 +167,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
162167 table2_columns = list (table2 .get_schema ())
163168 # Not ideal, but we don't have more specific exceptions yet
164169 except Exception as ex :
165- logging . info (ex )
170+ logger . debug (ex )
166171 rich .print (
167172 "[red]"
168173 + prod_qualified_string
@@ -287,15 +292,16 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
287292
288293class DbtParser :
def __init__(self, profiles_dir_override: str, project_dir_override: str, is_cloud: bool) -> None:
    """Set up the parser: dbt helpers, directories, project file and manifest.

    Args:
        profiles_dir_override: Profiles directory to use; when falsy,
            ``default_profiles_dir()`` is used instead.
        project_dir_override: Project directory to use; when falsy,
            ``default_project_dir()`` is used instead.
        is_cloud: Stored unchanged on ``self.is_cloud``.
    """
    # The dbt helpers must be bound first: get_project_dict() reads self.yaml
    # and get_manifest_obj() calls self.parse_manifest.
    run_results_parser, manifest_parser, profile_renderer, yaml_module = import_dbt()
    self.parse_run_results = run_results_parser
    self.parse_manifest = manifest_parser
    self.ProfileRenderer = profile_renderer
    self.yaml = yaml_module

    self.profiles_dir = Path(profiles_dir_override if profiles_dir_override else default_profiles_dir())
    self.project_dir = Path(project_dir_override if project_dir_override else default_project_dir())

    # Plain state with no dependency on the files loaded below.
    self.is_cloud = is_cloud
    self.connection = None
    self.requires_upper = False
    self.threads = None

    # Both loaders read from self.project_dir, so they run after it is set.
    self.project_dict = self.get_project_dict()
    self.manifest_obj = self.get_manifest_obj()

    # get_unique_columns() reads self.manifest_obj, so it has to come last.
    self.unique_columns = self.get_unique_columns()
299305
def get_datadiff_variables(self) -> dict:
    """Return the ``data_diff`` mapping stored under ``vars`` in the project file.

    Returns:
        The ``vars.data_diff`` mapping from ``self.project_dict``, or an empty
        dict when the project dict, its ``vars`` entry, or the ``data_diff``
        entry is missing or empty. Callers can therefore always call ``.get``
        on the result.
    """
    # Each level may be absent or None (e.g. an empty YAML section), so fall
    # back to an empty mapping instead of chaining .get() on None.
    project_vars = (self.project_dict or {}).get("vars") or {}
    return project_vars.get("data_diff") or {}
@@ -315,24 +321,24 @@ def get_models(self):
315321 f"Found dbt: v{ dbt_version } Expected the dbt project's version to be >= { LOWER_DBT_V } and < { UPPER_DBT_V } "
316322 )
317323
318- with open (self .project_dir / MANIFEST_PATH ) as manifest :
319- manifest_dict = json .load (manifest )
320- manifest_obj = self .parse_manifest (manifest = manifest_dict )
321-
322324 success_models = [x .unique_id for x in run_results_obj .results if x .status .name == "success" ]
323- models = [manifest_obj .nodes .get (x ) for x in success_models ]
325+ models = [self . manifest_obj .nodes .get (x ) for x in success_models ]
324326 if not models :
325327 raise ValueError ("Expected > 0 successful models runs from the last dbt command." )
326328
327- rich . print (f"Found { str ( len ( models )) } successful model runs from the last dbt command. " )
329+ print (f"Running with data-diff= { __version__ } \n " )
328330 return models
329331
330- def get_primary_keys (self , model ):
331- return list ((x .name for x in model .columns .values () if "primary-key" in x .tags ))
def get_manifest_obj(self):
    """Load the manifest JSON file from the project directory and parse it.

    The file at ``self.project_dir / MANIFEST_PATH`` is decoded as JSON and the
    resulting dict is handed to ``self.parse_manifest``.

    Returns:
        Whatever ``self.parse_manifest(manifest=...)`` returns.

    Raises:
        OSError: If the manifest file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    manifest_path = self.project_dir / MANIFEST_PATH
    # JSON text is UTF-8 by specification; decode it explicitly rather than
    # with the platform's locale encoding, which differs between systems.
    with open(manifest_path, encoding="utf-8") as manifest:
        manifest_dict = json.load(manifest)
    return self.parse_manifest(manifest=manifest_dict)
332337
333- def set_project_dict (self ):
def get_project_dict(self):
    """Load the project YAML file from the project directory.

    Returns:
        The result of ``self.yaml.safe_load`` on the file at
        ``self.project_dir / PROJECT_FILE``.

    Raises:
        OSError: If the project file cannot be opened.
    """
    project_path = self.project_dir / PROJECT_FILE
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(project_path, encoding="utf-8") as project:
        return self.yaml.safe_load(project)
336342
337343 def _get_connection_creds (self ) -> Tuple [Dict [str , str ], str ]:
338344 profiles_path = self .profiles_dir / PROFILES_FILE
@@ -437,3 +443,81 @@ def set_connection(self):
437443 raise NotImplementedError (f"Provider { conn_type } is not yet supported for dbt diffs" )
438444
439445 self .connection = conn_info
446+
def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str]:
    """Determine the primary-key column names of a model node.

    The sources below are tried in order; the first one that produces a
    result is returned:

    1. Table-level meta: ``node.meta[pk_tag]`` holds column names; only names
       that are also columns of ``node`` are kept.
    2. Column-level meta: columns whose ``meta`` contains ``pk_tag``.
    3. Column-level tags: columns whose ``tags`` contain ``pk_tag``.
    4. Uniqueness tests: the entry for ``node.unique_id`` in ``unique_columns``.

    Args:
        node: Model node exposing ``columns``, ``meta`` and ``unique_id``.
        unique_columns: Mapping of node unique id to a collection of column
            names, as built by ``get_unique_columns``.
        pk_tag: Key / tag name that marks primary keys (e.g. ``"primary-key"``).

    Returns:
        The primary-key column names, or an empty list when none were found.
    """
    column_names = set(node.columns)

    # 1. Table-level meta.
    if pk_tag in node.meta:
        declared = node.meta[pk_tag]
        # A bare string is treated as one column name, not a sequence of characters.
        if isinstance(declared, str):
            declared = [declared]
        # Keep only declared keys that really are columns of this node.
        pks = [pk for pk in declared if pk in column_names]
        if pks:
            logger.debug("Found PKs via Table META: %s", pks)
            return pks

    # 2. Column-level meta.
    from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta]
    if from_meta:
        logger.debug("Found PKs via META: %s", from_meta)
        return from_meta

    # 3. Column-level tags.
    from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags]
    if from_tags:
        logger.debug("Found PKs via Tags: %s", from_tags)
        return from_tags

    # 4. Uniqueness tests. `.get` is used so that a defaultdict passed in as
    # `unique_columns` does not gain an entry for this node as a side effect.
    from_uniq = unique_columns.get(node.unique_id)
    if from_uniq is not None:
        logger.debug("Found PKs via Uniqueness tests: %s", from_uniq)
        return list(from_uniq)

    logger.debug("Found no PKs")
    return []
481+
def get_unique_columns(self) -> Dict[str, Set[str]]:
    """Collect the columns covered by uniqueness tests, grouped by tested node.

    Walks every node of ``self.manifest_obj`` and looks at test nodes that carry
    ``test_metadata``. Two test names are recognised:

    * ``unique``: ``kwargs["column_name"]`` is split with
      ``_parse_concat_pk_definition``; a part is recorded only when it is a
      column of the tested node.
    * ``unique_combination_of_columns``: every entry of
      ``kwargs["combination_of_columns"]`` is recorded.

    Columns are recorded under the first id in the test's ``depends_on.nodes``.
    A test that raises ``KeyError``, ``IndexError`` or ``TypeError`` while being
    inspected is logged as a warning and skipped.

    Returns:
        A ``defaultdict(set)`` mapping node unique id to column names. Because
        it is a defaultdict, indexing a missing key inserts an empty set; use
        ``in`` or ``.get`` for lookups that must not mutate it.
    """
    manifest = self.manifest_obj
    cols_by_uid = defaultdict(set)
    for node in manifest.nodes.values():
        try:
            if not (node.resource_type.value == "test" and hasattr(node, "test_metadata")):
                continue

            # Skip tests without dependencies up front (None or empty list)
            # instead of letting nodes[0] trip the IndexError handler below.
            if node.depends_on is None or not node.depends_on.nodes:
                continue

            uid = node.depends_on.nodes[0]
            # NOTE(review): indexing raises KeyError for ids missing from
            # manifest.nodes, so the `model_node is None` guard below only
            # matters if the mapping can store None values - confirm intent.
            model_node = manifest.nodes[uid]
            test_name = node.test_metadata.name

            if test_name == "unique":
                column_name: str = node.test_metadata.kwargs["column_name"]
                for col in self._parse_concat_pk_definition(column_name):
                    if model_node is None or col in model_node.columns:
                        # skip anything that is not a column.
                        # for example, string literals used in concat
                        # like "pk1 || '-' || pk2"
                        cols_by_uid[uid].add(col)
            elif test_name == "unique_combination_of_columns":
                for col in node.test_metadata.kwargs["combination_of_columns"]:
                    cols_by_uid[uid].add(col)

        except (KeyError, IndexError, TypeError) as e:
            logger.warning("Failure while finding unique cols: %s", e)

    return cols_by_uid
513+
514+ def _parse_concat_pk_definition (self , definition : str ) -> List [str ]:
515+ definition = definition .strip ()
516+ if definition .lower ().startswith ("concat(" ) and definition .endswith (")" ):
517+ definition = definition [7 :- 1 ] # Removes concat( and )
518+ columns = definition .split ("," )
519+ else :
520+ columns = definition .split ("||" )
521+
522+ stripped_columns = [col .strip ('" ()' ) for col in columns ]
523+ return stripped_columns
0 commit comments