@@ -395,6 +395,61 @@ def _count_null_values_in_column(
395395 return int (result .item ())
396396
397397
398+ def _count_validation_units (
399+ tbl : IntoFrame ,
400+ column : str ,
401+ ) -> tuple [int , int , int , int ]:
402+ """
403+ Compute the row count and pass/fail/null counts for a results table in a single pass.
404+
405+ Given a results table with a boolean `column` (typically ``pb_is_good_``), this returns
406+ the total number of rows, the number of `True` values (passing test units), the number of
407+ `False` values (failing test units), and the number of Null values.
408+
409+ Computing all four quantities in one aggregation is important for LazyFrames: otherwise each
410+ separate count would trigger its own `collect()`, re-executing the entire (potentially
411+ expensive) lazy plan multiple times.
412+
413+ Parameters
414+ ----------
415+ tbl
416+ A Narwhals-compatible DataFrame or table-like object.
417+ column
418+ The boolean column to summarize.
419+
420+ Returns
421+ -------
422+ tuple[int, int, int, int]
423+ A tuple of ``(n, n_passed, n_failed, n_null)``.
424+ """
425+
426+ # Convert the DataFrame to a Narwhals DataFrame (no detrimental effect if
427+ # already a Narwhals DataFrame)
428+ tbl_nw = nw .from_native (tbl )
429+
430+ # Build a single aggregation that computes all counts at once. Casting booleans to Int32
431+ # before summing is required for backends like PySpark (which can't sum booleans), and the
432+ # sums naturally ignore Null values (so `n_passed`/`n_failed` exclude nulls).
433+ result = tbl_nw .select (
434+ nw .len ().alias ("n" ),
435+ nw .col (column ).cast (nw .Int32 ).sum ().alias ("n_passed" ),
436+ (~ nw .col (column )).cast (nw .Int32 ).sum ().alias ("n_failed" ),
437+ nw .col (column ).is_null ().cast (nw .Int32 ).sum ().alias ("n_null" ),
438+ )
439+
440+ if is_narwhals_lazyframe (result ):
441+ result = result .collect ()
442+
443+ row = result .rows (named = True )[0 ]
444+
445+ n = int (row ["n" ])
446+ n_passed = int (row ["n_passed" ] or 0 )
447+ n_failed = int (row ["n_failed" ] or 0 )
448+ n_null = int (row ["n_null" ] or 0 )
449+
450+ return n , n_passed , n_failed , n_null
451+
452+
398453def _is_numeric_dtype (dtype : str ) -> bool :
399454 """
400455 Check if a given data type string represents a numeric type.
0 commit comments