Skip to content

Commit abb9332

Browse files
committed
Fix validity masks in Arrow UDF
1 parent 724671f commit abb9332

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

src/duckdb_py/python_udf.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce
203203
}
204204
}
205205
if (any_null) {
206-
FlatVector::ValidityMutable(result).SetInvalid(i);
207206
continue;
208207
}
209208
selvec.set_index(index++, i);
@@ -261,22 +260,31 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce
261260
}
262261
if (count) {
263262
SelectionVector inverted(input_size);
264-
// Create a SelVec that inverts the filtering
265-
// example: count: 6, null_indices: 1,3
266-
// input selvec: [0, 2, 4, 5]
267-
// inverted selvec: [0, 0, 1, 1, 2, 3]
263+
// Map each target row back to a source row in temp. Non-null target rows map to
264+
// their UDF output; null target rows point at the next non-null source row (their
265+
// data is later masked out by SetNull). Trailing null rows have no next source row, so they reuse the last one.
266+
// example: input_size: 6, null_indices: 1,3
267+
// selvec (non-null indices): [0, 2, 4, 5]
268+
// inverted selvec: [0, 1, 1, 2, 2, 3]
268269
idx_t src_index = 0;
269270
for (idx_t i = 0; i < input_size; i++) {
270-
// Fill the gaps with the previous index
271271
inverted.set_index(i, src_index);
272272
if (src_index + 1 < count && selvec.get_index(src_index) == i) {
273273
src_index++;
274274
}
275275
}
276276
VectorOperations::Copy(temp, result, inverted, count, 0, 0, input_size);
277277
}
278+
// Apply the null mask: any position not present in selvec was a null input row.
279+
// VectorOperations::Copy unconditionally overwrites the result's validity from
280+
// the source's, so we must do this after the Copy.
281+
idx_t sel_idx = 0;
278282
for (idx_t i = 0; i < input_size; i++) {
279-
FlatVector::SetNull(result, i, !FlatVector::Validity(result).RowIsValid(i));
283+
if (sel_idx < count && selvec.get_index(sel_idx) == i) {
284+
sel_idx++;
285+
} else {
286+
FlatVector::SetNull(result, i, true);
287+
}
280288
}
281289
result.Verify();
282290
} else {

0 commit comments

Comments
 (0)