|
1 | | -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. |
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 | """Multi-partition base classes.""" |
4 | 4 |
|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | | -import dataclasses |
| 7 | +import abc |
8 | 8 | import enum |
9 | | -from collections import defaultdict |
10 | 9 | from enum import IntEnum |
11 | | -from functools import cached_property |
12 | | -from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar |
| 10 | +from typing import TYPE_CHECKING, Any |
13 | 11 |
|
14 | 12 | if TYPE_CHECKING: |
15 | | - from collections.abc import Generator, Iterator, MutableMapping |
| 13 | + from collections.abc import Generator, Iterator |
16 | 14 |
|
17 | 15 | from cudf_polars.dsl.expr import NamedExpr |
18 | 16 | from cudf_polars.dsl.ir import IR |
@@ -57,353 +55,35 @@ def get_key_name(node: Node) -> str: |
57 | 55 | return f"{type(node).__name__.lower()}-{hash(node)}" |
58 | 56 |
|
59 | 57 |
|
60 | | -T = TypeVar("T") |
61 | | - |
62 | | - |
63 | | -@dataclasses.dataclass |
64 | | -class ColumnStat(Generic[T]): |
65 | | - """ |
66 | | - Generic column-statistic. |
67 | | -
|
68 | | - Parameters |
69 | | - ---------- |
70 | | - value |
71 | | - Statistics value. Value will be None |
72 | | - if the statistics is unknown. |
73 | | - exact |
74 | | - Whether the statistics is known exactly. |
75 | | - """ |
76 | | - |
77 | | - value: T | None = None |
78 | | - exact: bool = False |
79 | | - |
80 | | - |
81 | | -@dataclasses.dataclass |
82 | | -class UniqueStats: |
83 | | - """ |
84 | | - Sampled unique-value statistics. |
85 | | -
|
86 | | - Parameters |
87 | | - ---------- |
88 | | - count |
89 | | - Unique-value count. |
90 | | - fraction |
91 | | - Unique-value fraction. This corresponds to the total |
92 | | - number of unique values (count) divided by the total |
93 | | - number of rows. |
94 | | -
|
95 | | - Notes |
96 | | - ----- |
97 | | - This class is used to track unique-value column statistics |
98 | | - that have been sampled from a data source. |
99 | | - """ |
100 | | - |
101 | | - count: ColumnStat[int] = dataclasses.field(default_factory=ColumnStat[int]) |
102 | | - fraction: ColumnStat[float] = dataclasses.field(default_factory=ColumnStat[float]) |
103 | | - |
104 | | - |
105 | | -class DataSourceInfo: |
| 58 | +class DataSourceInfo(abc.ABC): |
106 | 59 | """ |
107 | 60 | Table data source information. |
108 | 61 |
|
109 | 62 | Notes |
110 | 63 | ----- |
111 | | - This class should be sub-classed for specific |
112 | | - data source types (e.g. Parquet, DataFrame, etc.). |
113 | | - The required properties/methods enable lazy |
114 | | - sampling of the underlying datasource. |
115 | | - """ |
116 | | - |
117 | | - _unique_stats_columns: set[str] |
118 | | - _read_columns: set[str] |
119 | | - |
120 | | - @property |
121 | | - def row_count(self) -> ColumnStat[int]: # pragma: no cover |
122 | | - """Data source row-count estimate.""" |
123 | | - raise NotImplementedError("Sub-class must implement row_count.") |
124 | | - |
125 | | - def unique_stats( |
126 | | - self, |
127 | | - column: str, |
128 | | - ) -> UniqueStats: # pragma: no cover |
129 | | - """Return unique-value statistics for a column.""" |
130 | | - raise NotImplementedError("Sub-class must implement unique_stats.") |
131 | | - |
132 | | - def storage_size(self, column: str) -> ColumnStat[int]: |
133 | | - """Return the average column size for a single file.""" |
134 | | - return ColumnStat[int]() |
135 | | - |
136 | | - @property |
137 | | - def unique_stats_columns(self) -> set[str]: |
138 | | - """Return the set of columns needing unique-value information.""" |
139 | | - return self._unique_stats_columns |
140 | | - |
141 | | - def add_unique_stats_column(self, column: str) -> None: |
142 | | - """Add a column needing unique-value information.""" |
143 | | - self._unique_stats_columns.add(column) |
144 | | - |
145 | | - def add_read_column(self, column: str) -> None: |
146 | | - """Add a column needing to be read.""" |
147 | | - self._read_columns.add(column) |
148 | | - |
149 | | - |
150 | | -class DataSourcePair(NamedTuple): |
151 | | - """Pair of table-source and column-name information.""" |
152 | | - |
153 | | - table_source: DataSourceInfo |
154 | | - column_name: str |
155 | | - |
156 | | - |
157 | | -class ColumnSourceInfo: |
158 | | - """ |
159 | | - Source column information. |
160 | | -
|
161 | | - Parameters |
162 | | - ---------- |
163 | | - table_source_pairs |
164 | | - Sequence of DataSourcePair objects. |
165 | | - Union operations will result in multiple elements. |
166 | | -
|
167 | | - Notes |
168 | | - ----- |
169 | | - This is a thin wrapper around DataSourceInfo that provides |
170 | | - direct access to column-specific information. |
| 64 | + Sub-class for specific data source types (e.g. Parquet, DataFrame). |
171 | 65 | """ |
172 | 66 |
|
173 | | - __slots__ = ( |
174 | | - "implied_unique_count", |
175 | | - "table_source_pairs", |
176 | | - ) |
177 | | - table_source_pairs: list[DataSourcePair] |
178 | | - implied_unique_count: ColumnStat[int] |
179 | | - """Unique-value count implied by join heuristics.""" |
180 | | - |
181 | | - def __init__(self, *table_source_pairs: DataSourcePair) -> None: |
182 | | - self.table_source_pairs = list(table_source_pairs) |
183 | | - self.implied_unique_count = ColumnStat[int](None) |
184 | | - |
185 | 67 | @property |
186 | | - def is_unique_stats_column(self) -> bool: |
187 | | - """Return whether this column requires unique-value information.""" |
188 | | - return any( |
189 | | - pair.column_name in pair.table_source.unique_stats_columns |
190 | | - for pair in self.table_source_pairs |
191 | | - ) |
192 | | - |
193 | | - @property |
194 | | - def row_count(self) -> ColumnStat[int]: |
| 68 | + @abc.abstractmethod |
| 69 | + def row_count(self) -> int | None: |
195 | 70 | """Data source row-count estimate.""" |
196 | | - return ColumnStat[int]( |
197 | | - # Use sum of table-source row-count estimates. |
198 | | - value=sum( |
199 | | - value |
200 | | - for pair in self.table_source_pairs |
201 | | - if (value := pair.table_source.row_count.value) is not None |
202 | | - ) |
203 | | - or None, |
204 | | - # Row-count may be exact if there is only one table source. |
205 | | - exact=len(self.table_source_pairs) == 1 |
206 | | - and self.table_source_pairs[0].table_source.row_count.exact, |
207 | | - ) |
208 | | - |
209 | | - def unique_stats(self, *, force: bool = False) -> UniqueStats: |
210 | | - """ |
211 | | - Return unique-value statistics for a column. |
212 | 71 |
|
213 | | - Parameters |
214 | | - ---------- |
215 | | - force |
216 | | - If True, return unique-value statistics even if the column |
217 | | - wasn't marked as needing unique-value information. |
218 | | - """ |
219 | | - if (force or self.is_unique_stats_column) and len(self.table_source_pairs) == 1: |
220 | | - # Single table source. |
221 | | - # TODO: Handle multiple tables sources if/when necessary. |
222 | | - # We may never need to do this if the source unique-value |
223 | | - # statistics are only "used" by the Scan/DataFrameScan nodes. |
224 | | - table_source, column_name = self.table_source_pairs[0] |
225 | | - return table_source.unique_stats(column_name) |
226 | | - else: |
227 | | - # Avoid sampling unique-stats if this column |
228 | | - # wasn't marked as "needing" unique-stats. |
229 | | - return UniqueStats() |
230 | | - |
231 | | - @property |
232 | | - def storage_size(self) -> ColumnStat[int]: |
233 | | - """Return the average column size for a single file.""" |
234 | | - # We don't need to handle concatenated statistics for ``storage_size``. |
235 | | - # Just return the storage size of the first table source. |
236 | | - if self.table_source_pairs: |
237 | | - table_source, column_name = self.table_source_pairs[0] |
238 | | - return table_source.storage_size(column_name) |
239 | | - else: # pragma: no cover; We never call this for empty table sources. |
240 | | - return ColumnStat[int]() |
241 | | - |
242 | | - def add_unique_stats_column(self, column: str | None = None) -> None: |
243 | | - """Add a column needing unique-value information.""" |
244 | | - # We must call add_unique_stats_column for ALL table sources. |
245 | | - for table_source, column_name in self.table_source_pairs: |
246 | | - table_source.add_unique_stats_column(column or column_name) |
247 | | - |
248 | | - def add_read_column(self, column: str | None = None) -> None: |
249 | | - """Add a column needing to be read.""" |
250 | | - for table_source, column_name in self.table_source_pairs: |
251 | | - table_source.add_read_column(column or column_name) |
252 | | - |
253 | | - |
254 | | -class ColumnStats: |
255 | | - """ |
256 | | - Column statistics. |
257 | | -
|
258 | | - Parameters |
259 | | - ---------- |
260 | | - name |
261 | | - Column name. |
262 | | - children |
263 | | - Child ColumnStats objects. |
264 | | - source_info |
265 | | - Column source information. |
266 | | - unique_count |
267 | | - Unique-value count. |
268 | | - """ |
269 | | - |
270 | | - __slots__ = ("children", "name", "source_info", "unique_count") |
271 | | - |
272 | | - name: str |
273 | | - children: tuple[ColumnStats, ...] |
274 | | - source_info: ColumnSourceInfo |
275 | | - unique_count: ColumnStat[int] |
276 | | - |
277 | | - def __init__( |
278 | | - self, |
279 | | - name: str, |
280 | | - *, |
281 | | - children: tuple[ColumnStats, ...] = (), |
282 | | - source_info: ColumnSourceInfo | None = None, |
283 | | - unique_count: ColumnStat[int] | None = None, |
284 | | - ) -> None: |
285 | | - self.name = name |
286 | | - self.children = children |
287 | | - self.source_info = source_info or ColumnSourceInfo() |
288 | | - self.unique_count = unique_count or ColumnStat[int](None) |
289 | | - |
290 | | - def new_parent( |
291 | | - self, |
292 | | - *, |
293 | | - name: str | None = None, |
294 | | - ) -> ColumnStats: |
295 | | - """ |
296 | | - Initialize a new parent ColumnStats object. |
297 | | -
|
298 | | - Parameters |
299 | | - ---------- |
300 | | - name |
301 | | - The new column name. |
302 | | -
|
303 | | - Returns |
304 | | - ------- |
305 | | - A new ColumnStats object. |
306 | | -
|
307 | | - Notes |
308 | | - ----- |
309 | | - This API preserves the original DataSourceInfo reference. |
310 | | - """ |
311 | | - return ColumnStats( |
312 | | - name=name or self.name, |
313 | | - children=(self,), |
314 | | - # Want to reference the same DataSourceInfo |
315 | | - source_info=self.source_info, |
316 | | - ) |
317 | | - |
318 | | - |
319 | | -class JoinKey: |
320 | | - """ |
321 | | - Join-key information. |
322 | | -
|
323 | | - Parameters |
324 | | - ---------- |
325 | | - column_stats |
326 | | - Column statistics for the join key. |
327 | | -
|
328 | | - Notes |
329 | | - ----- |
330 | | - This class is used to track join-key information. |
331 | | - It is used to track the columns being joined on |
332 | | - and the estimated unique-value count for the join key. |
333 | | - """ |
334 | | - |
335 | | - column_stats: tuple[ColumnStats, ...] |
336 | | - implied_unique_count: int | None |
337 | | - """Estimated unique-value count from join heuristics.""" |
338 | | - |
339 | | - def __init__(self, *column_stats: ColumnStats) -> None: |
340 | | - self.column_stats = column_stats |
341 | | - self.implied_unique_count = None |
342 | | - |
343 | | - @cached_property |
344 | | - def source_row_count(self) -> int | None: |
345 | | - """ |
346 | | - Return the estimated row-count of the source columns. |
347 | | -
|
348 | | - Notes |
349 | | - ----- |
350 | | - This is the maximum row-count estimate of the source columns. |
351 | | - """ |
352 | | - return max( |
353 | | - ( |
354 | | - cs.source_info.row_count.value |
355 | | - for cs in self.column_stats |
356 | | - if cs.source_info.row_count.value is not None |
357 | | - ), |
358 | | - default=None, |
359 | | - ) |
360 | | - |
361 | | - |
362 | | -class JoinInfo: |
363 | | - """ |
364 | | - Join information. |
365 | | -
|
366 | | - Notes |
367 | | - ----- |
368 | | - This class is used to track mappings between joined-on |
369 | | - columns and joined-on keys (groups of columns). We need |
370 | | - these mappings to calculate equivalence sets and make |
371 | | - join-based unique-count and row-count estimates. |
372 | | - """ |
373 | | - |
374 | | - __slots__ = ("column_map", "join_map", "key_map") |
375 | | - |
376 | | - column_map: MutableMapping[ColumnStats, set[ColumnStats]] |
377 | | - """Mapping between joined columns.""" |
378 | | - key_map: MutableMapping[JoinKey, set[JoinKey]] |
379 | | - """Mapping between joined keys (groups of columns).""" |
380 | | - join_map: dict[IR, list[JoinKey]] |
381 | | - """Mapping between IR nodes and associated join keys.""" |
382 | | - |
383 | | - def __init__(self) -> None: |
384 | | - self.column_map: MutableMapping[ColumnStats, set[ColumnStats]] = defaultdict( |
385 | | - set[ColumnStats] |
386 | | - ) |
387 | | - self.key_map: MutableMapping[JoinKey, set[JoinKey]] = defaultdict(set[JoinKey]) |
388 | | - self.join_map: dict[IR, list[JoinKey]] = {} |
| 72 | + def column_storage_size(self, column: str) -> int | None: |
| 73 | + """Return the average storage size for a single column in one file.""" |
| 74 | + return None |
389 | 75 |
|
390 | 76 |
|
391 | 77 | class StatsCollector: |
392 | | - """Column statistics collector.""" |
| 78 | + """Scan statistics collector.""" |
393 | 79 |
|
394 | | - __slots__ = ("column_stats", "join_info", "row_count") |
| 80 | + __slots__ = ("scan_stats",) |
395 | 81 |
|
396 | | - row_count: dict[IR, ColumnStat[int]] |
397 | | - """Estimated row count for each IR node.""" |
398 | | - column_stats: dict[IR, dict[str, ColumnStats]] |
399 | | - """Column statistics for each IR node.""" |
400 | | - join_info: JoinInfo |
401 | | - """Join information.""" |
| 82 | + scan_stats: dict[IR, DataSourceInfo] |
| 83 | + """DataSourceInfo for each leaf Scan/DataFrameScan node.""" |
402 | 84 |
|
403 | 85 | def __init__(self) -> None: |
404 | | - self.row_count: dict[IR, ColumnStat[int]] = {} |
405 | | - self.column_stats: dict[IR, dict[str, ColumnStats]] = {} |
406 | | - self.join_info = JoinInfo() |
| 86 | + self.scan_stats: dict[IR, DataSourceInfo] = {} |
407 | 87 |
|
408 | 88 |
|
409 | 89 | class IOPartitionFlavor(IntEnum): |
|
0 commit comments