Skip to content

Commit e45f0c9

Browse files
authored
Merge pull request #407 from posit-dev/fix-avoid-collect-counting-slowdown
fix: avoid inefficient collect/count at interrogation time
2 parents dc6e027 + c1a4e45 commit e45f0c9

3 files changed

Lines changed: 91 additions & 10 deletions

File tree

pointblank/_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,61 @@ def _count_null_values_in_column(
395395
return int(result.item())
396396

397397

398+
def _count_validation_units(
    tbl: IntoFrame,
    column: str,
) -> tuple[int, int, int, int]:
    """
    Compute the row count and pass/fail/null counts for a results table in a single pass.

    Given a results table with a boolean `column` (typically ``pb_is_good_``), this returns
    the total number of rows, the number of `True` values (passing test units), the number of
    `False` values (failing test units), and the number of Null values.

    Computing all four quantities in one aggregation is important for LazyFrames: otherwise each
    separate count would trigger its own `collect()`, re-executing the entire (potentially
    expensive) lazy plan multiple times.

    Parameters
    ----------
    tbl
        A Narwhals-compatible DataFrame or table-like object.
    column
        The boolean column to summarize.

    Returns
    -------
    tuple[int, int, int, int]
        A tuple of ``(n, n_passed, n_failed, n_null)``. Null values are counted only in
        ``n_null``; they contribute to neither ``n_passed`` nor ``n_failed``.
    """

    # Convert the DataFrame to a Narwhals DataFrame (no detrimental effect if
    # already a Narwhals DataFrame)
    tbl_nw = nw.from_native(tbl)

    is_good = nw.col(column)

    # Build a single aggregation that computes all counts at once. Casting booleans to an integer
    # type before summing is required for backends like PySpark (which can't sum booleans), and
    # the sums naturally ignore Null values (so `n_passed`/`n_failed` exclude nulls).
    #
    # Int64 is used rather than Int32 because some backends (e.g., Polars) accumulate a sum in
    # the dtype of the input column, so an Int32 sum could overflow once a count exceeds
    # 2**31 - 1 rows.
    result = tbl_nw.select(
        nw.len().alias("n"),
        is_good.cast(nw.Int64).sum().alias("n_passed"),
        (~is_good).cast(nw.Int64).sum().alias("n_failed"),
        is_good.is_null().cast(nw.Int64).sum().alias("n_null"),
    )

    # A lazy input yields a lazy result; this is the only point at which the plan is executed
    if is_narwhals_lazyframe(result):
        result = result.collect()

    # The aggregation always produces exactly one row
    row = result.rows(named=True)[0]

    # A sum may come back as Null/None (e.g., when every value in the column is Null), so fall
    # back to zero in that case
    n = int(row["n"])
    n_passed = int(row["n_passed"] or 0)
    n_failed = int(row["n_failed"] or 0)
    n_null = int(row["n_null"] or 0)

    return n, n_passed, n_failed, n_null
451+
452+
398453
def _is_numeric_dtype(dtype: str) -> bool:
399454
"""
400455
Check if a given data type string represents a numeric type.

pointblank/validate.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@
9696
_check_invalid_fields,
9797
_column_test_prep,
9898
_copy_dataframe,
99-
_count_null_values_in_column,
100-
_count_true_values_in_column,
99+
_count_validation_units,
101100
_derive_bounds,
102101
_format_to_integer_value,
103102
_get_fn_name,
@@ -15437,22 +15436,23 @@ def interrogate(
1543715436
# called `pb_is_good_` that contains boolean values; we can then use this table to
1543815437
# determine the number of test units that passed and failed
1543915438
if results_tbl is not None:
15440-
# Count the number of passing and failing test units
15441-
validation.n_passed = _count_true_values_in_column(
15439+
# Count passing/failing test units and the total row count in a single pass.
15440+
# Doing this together avoids re-executing the (possibly lazy) results-table plan
15441+
# multiple times, which would otherwise scan the data once per count.
15442+
n_units, n_passed, n_failed, n_null = _count_validation_units(
1544215443
tbl=results_tbl, column="pb_is_good_"
1544315444
)
15444-
validation.n_failed = _count_true_values_in_column(
15445-
tbl=results_tbl, column="pb_is_good_", inverse=True
15446-
)
15445+
15446+
validation.n_passed = n_passed
15447+
validation.n_failed = n_failed
1544715448

1544815449
# Solely for the col_vals_in_set assertion type, any Null values in the
1544915450
# `pb_is_good_` column are counted as failing test units
1545015451
if assertion_type == "col_vals_in_set":
15451-
null_count = _count_null_values_in_column(tbl=results_tbl, column="pb_is_good_")
15452-
validation.n_failed += null_count
15452+
validation.n_failed += n_null
1545315453

1545415454
# For column-value validations, the number of test units is the number of rows
15455-
validation.n = get_row_count(data=results_tbl)
15455+
validation.n = n_units
1545615456

1545715457
# Set the `all_passed` attribute based on whether there are any failing test units
1545815458
validation.all_passed = validation.n_failed == 0

tests/test__utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_copy_dataframe,
2525
_count_null_values_in_column,
2626
_count_true_values_in_column,
27+
_count_validation_units,
2728
_derive_bounds,
2829
_derive_single_bound,
2930
_format_to_float_value,
@@ -364,6 +365,31 @@ def test_count_null_values_in_column(tbl_type):
364365
assert _count_null_values_in_column(tbl=data, column="c") == 2
365366

366367

368+
@pytest.mark.parametrize("tbl_type", ["polars", "duckdb"])
def test_count_validation_units(tbl_type):
    """Check the single-pass counts against column `e` of the `small_table` dataset."""
    small_table = load_dataset(dataset="small_table", tbl_type=tbl_type)

    counts = _count_validation_units(tbl=small_table, column="e")

    # Column `e`: 13 rows in total, 8 True, 5 False, and no nulls
    assert counts == (13, 8, 5, 0)
379+
380+
381+
def test_count_validation_units_with_nulls():
    """Null values are reported separately, identically for eager and lazy Polars frames."""
    import polars as pl

    eager = pl.DataFrame({"pb_is_good_": [True, False, True, None, None]})

    # Order is (n, n_passed, n_failed, n_null); the two nulls count toward neither the pass
    # total nor the fail total
    expected = (5, 2, 1, 2)

    for frame in (eager, eager.lazy()):
        assert _count_validation_units(tbl=frame, column="pb_is_good_") == expected
391+
392+
367393
def test_format_to_integer_value():
368394
assert _format_to_integer_value(0) == "0"
369395
assert _format_to_integer_value(0.3) == "0"

0 commit comments

Comments
 (0)