2525from sqlglot .helper import ensure_list
2626from sqlglot .optimizer .qualify_columns import quote_identifiers
2727
28- from sqlmesh .core .dialect import select_from_values_for_batch_range
28+ from sqlmesh .core .dialect import add_table , select_from_values_for_batch_range
2929from sqlmesh .core .engine_adapter .shared import DataObject , TransactionType
3030from sqlmesh .core .model .kind import TimeColumn
3131from sqlmesh .core .schema_diff import SchemaDiffer
@@ -911,7 +911,7 @@ def scd_type_2(
911911 self ,
912912 target_table : TableName ,
913913 source_table : QueryOrDF ,
914- unique_key : t .Sequence [str ],
914+ unique_key : t .Sequence [exp . Expression ],
915915 valid_from_name : str ,
916916 valid_to_name : str ,
917917 updated_at_name : str ,
@@ -937,7 +937,7 @@ def scd_type_2(
937937 exp .Select () # type: ignore
938938 .with_ (
939939 "source" ,
940- exp .select (* unmanaged_columns )
940+ exp .select (exp . true (). as_ ( "_exists" ), * unmanaged_columns )
941941 .distinct (* unique_key )
942942 .from_ (source_query .subquery ("raw_source" )), # type: ignore
943943 )
@@ -964,8 +964,8 @@ def scd_type_2(
964964 "latest" ,
965965 on = exp .and_ (
966966 * [
967- exp . column ( col , table = "static" ).eq (exp . column ( col , table = "latest" ))
968- for col in unique_key
967+ add_table ( key , "static" ).eq (add_table ( key , "latest" ))
968+ for key in unique_key
969969 ]
970970 ),
971971 join_type = "left" ,
@@ -976,7 +976,8 @@ def scd_type_2(
976976 .with_ (
977977 "latest_deleted" ,
978978 exp .select (
979- * unique_key ,
979+ exp .true ().as_ ("_exists" ),
980+ * (part .as_ (f"_key{ i } " ) for i , part in enumerate (unique_key )),
980981 f"MAX({ valid_to_name } ) AS { valid_to_name } " ,
981982 )
982983 .from_ ("deleted" )
@@ -987,34 +988,43 @@ def scd_type_2(
987988 .with_ (
988989 "joined" ,
989990 exp .select (
990- * (f"latest.{ col } AS t_{ col } " for col in columns_to_types ),
991- * (f"source.{ col } AS s_{ col } " for col in unmanaged_columns ),
991+ exp .column ("_exists" , table = "source" ),
992+ * (
993+ exp .column (col , table = "latest" ).as_ (f"t_{ col } " )
994+ for col in columns_to_types
995+ ),
996+ * (exp .column (col , table = "source" ).as_ (col ) for col in unmanaged_columns ),
992997 )
993998 .from_ ("latest" )
994999 .join (
9951000 "source" ,
9961001 on = exp .and_ (
9971002 * [
998- exp . column ( col , table = "latest" ).eq (exp . column ( col , table = "source" ))
999- for col in unique_key
1003+ add_table ( key , "latest" ).eq (add_table ( key , "source" ))
1004+ for key in unique_key
10001005 ]
10011006 ),
10021007 join_type = "left" ,
10031008 )
10041009 .union (
10051010 exp .select (
1006- * (f"latest.{ col } AS t_{ col } " for col in columns_to_types ),
1007- * (f"source.{ col } AS s_{ col } " for col in unmanaged_columns ),
1011+ exp .column ("_exists" , table = "source" ),
1012+ * (
1013+ exp .column (col , table = "latest" ).as_ (f"t_{ col } " )
1014+ for col in columns_to_types
1015+ ),
1016+ * (
1017+ exp .column (col , table = "source" ).as_ (col )
1018+ for col in unmanaged_columns
1019+ ),
10081020 )
10091021 .from_ ("latest" )
10101022 .join (
10111023 "source" ,
10121024 on = exp .and_ (
10131025 * [
1014- exp .column (col , table = "latest" ).eq (
1015- exp .column (col , table = "source" )
1016- )
1017- for col in unique_key
1026+ add_table (key , "latest" ).eq (add_table (key , "source" ))
1027+ for key in unique_key
10181028 ]
10191029 ),
10201030 join_type = "right" ,
@@ -1025,25 +1035,32 @@ def scd_type_2(
10251035 .with_ (
10261036 "updated_rows" ,
10271037 exp .select (
1028- * (f"COALESCE(t_{ col } , s_{ col } ) as { col } " for col in unmanaged_columns ),
1038+ * (
1039+ exp .func (
1040+ "COALESCE" ,
1041+ exp .column (f"t_{ col } " , table = "joined" ),
1042+ exp .column (col , table = "joined" ),
1043+ ).as_ (col )
1044+ for col in unmanaged_columns
1045+ ),
10291046 f"""
10301047 CASE
10311048 WHEN t_{ valid_from_name } IS NULL
1032- AND latest_deleted.{ unique_key [ 0 ] } IS NOT NULL
1049+ AND latest_deleted._exists IS NOT NULL
10331050 THEN CASE
1034- WHEN latest_deleted.{ valid_to_name } > s_ { updated_at_name }
1051+ WHEN latest_deleted.{ valid_to_name } > { updated_at_name }
10351052 THEN latest_deleted.{ valid_to_name }
1036- ELSE s_ { updated_at_name }
1053+ ELSE { updated_at_name }
10371054 END
10381055 WHEN t_{ valid_from_name } IS NULL
10391056 THEN { self ._to_utc_timestamp ('1970-01-01 00:00:00+00:00' )}
10401057 ELSE t_{ valid_from_name }
10411058 END AS { valid_from_name } """ ,
10421059 f"""
10431060 CASE
1044- WHEN s_ { updated_at_name } > t_{ updated_at_name }
1045- THEN s_ { updated_at_name }
1046- WHEN s_ { unique_key [ 0 ] } IS NULL
1061+ WHEN { updated_at_name } > t_{ updated_at_name }
1062+ THEN { updated_at_name }
1063+ WHEN joined._exists IS NULL
10471064 THEN { self ._to_utc_timestamp (to_ts (execution_time ))}
10481065 ELSE t_{ valid_to_name }
10491066 END AS { valid_to_name } """ ,
@@ -1053,10 +1070,10 @@ def scd_type_2(
10531070 "latest_deleted" ,
10541071 on = exp .and_ (
10551072 * [
1056- exp . column ( f"s_ { col } " , table = "joined" ).eq (
1057- exp .column (col , table = "latest_deleted" )
1073+ add_table ( part , "joined" ).eq (
1074+ exp .column (f"_key { i } " , "latest_deleted" )
10581075 )
1059- for col in unique_key
1076+ for i , part in enumerate ( unique_key )
10601077 ]
10611078 ),
10621079 join_type = "left" ,
@@ -1066,14 +1083,12 @@ def scd_type_2(
10661083 .with_ (
10671084 "inserted_rows" ,
10681085 exp .select (
1069- * ( f"s_ { col } as { col } " for col in unmanaged_columns ) ,
1070- f"s_ { updated_at_name } as { valid_from_name } " ,
1086+ * unmanaged_columns ,
1087+ f"{ updated_at_name } as { valid_from_name } " ,
10711088 f"{ self ._to_utc_timestamp (exp .null ())} as { valid_to_name } " ,
10721089 )
10731090 .from_ ("joined" )
1074- .where (
1075- f"t_{ unique_key [0 ]} IS NOT NULL AND s_{ unique_key [0 ]} IS NOT NULL AND s_{ updated_at_name } > t_{ updated_at_name } "
1076- ),
1091+ .where (f"{ updated_at_name } > t_{ updated_at_name } " ),
10771092 )
10781093 .select ("*" )
10791094 .from_ ("static" )
@@ -1097,18 +1112,15 @@ def merge(
10971112 target_table : TableName ,
10981113 source_table : QueryOrDF ,
10991114 columns_to_types : t .Optional [t .Dict [str , exp .DataType ]],
1100- unique_key : t .Sequence [str ],
1115+ unique_key : t .Sequence [exp . Expression ],
11011116 ) -> None :
11021117 source_queries , columns_to_types = self ._get_source_queries_and_columns_to_types (
11031118 source_table , columns_to_types , target_table = target_table
11041119 )
11051120 columns_to_types = columns_to_types or self .columns (target_table )
11061121 on = exp .and_ (
11071122 * (
1108- exp .EQ (
1109- this = exp .column (part , MERGE_TARGET_ALIAS ),
1110- expression = exp .column (part , MERGE_SOURCE_ALIAS ),
1111- )
1123+ add_table (part , MERGE_TARGET_ALIAS ).eq (add_table (part , MERGE_SOURCE_ALIAS ))
11121124 for part in unique_key
11131125 )
11141126 )
0 commit comments