@@ -121,15 +121,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
121121 target_index = target_table .select (join_cols_set ).append_column (TARGET_INDEX_COLUMN_NAME , pa .array (range (len (target_table ))))
122122
123123 # Step 3: Perform an inner join to find which rows from source exist in target
124- matching_indices = source_index .join (target_index , keys = list (join_cols_set ), join_type = "inner" )
124+ # PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
125+ # Apart from treating null keys as equal, this mirrors:
126+ # matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
127+ source_indices = {tuple (row [col ] for col in join_cols ): row [SOURCE_INDEX_COLUMN_NAME ] for row in source_index .to_pylist ()}
128+ target_indices = {tuple (row [col ] for col in join_cols ): row [TARGET_INDEX_COLUMN_NAME ] for row in target_index .to_pylist ()}
129+ matching_indices = [(s , t ) for key , s in source_indices .items () if (t := target_indices .get (key )) is not None ]
125130
126131 # Step 4: Compare all rows using Python
127132 to_update_indices = []
128- for source_idx , target_idx in zip (
129- matching_indices [SOURCE_INDEX_COLUMN_NAME ].to_pylist (),
130- matching_indices [TARGET_INDEX_COLUMN_NAME ].to_pylist (),
131- strict = True ,
132- ):
133+ for source_idx , target_idx in matching_indices :
133134 source_row = source_table .slice (source_idx , 1 )
134135 target_row = target_table .slice (target_idx , 1 )
135136
0 commit comments