Skip to content

Commit c8cc00e

Browse files
authored
Fix stable ID for DataFrameScan (rapidsai#22091)
The cuDF-Polars tracing infrastructure relies on a "stable ID" for each IR node. This ID is currently **not** stable for `DataFrameScan`. This PR adds a small `_expand_hashable` helper utility to correct this. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Tom Augspurger (https://github.com/TomAugspurger) URL: rapidsai#22091
1 parent 903ec6f commit c8cc00e

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

python/cudf_polars/cudf_polars/dsl/nodebase.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
T = TypeVar("T", bound="Node[Any]")
1919

2020

21+
def _expand_hashable(obj: Any) -> Any:
    """
    Replace every nested ``Node`` with its hashable form.

    A ``Node`` is swapped for the result of its ``get_hashable()`` call
    (repeatedly, in case that result is itself a ``Node``). Tuples are
    rebuilt element by element so that nodes nested inside them are
    expanded as well. Any other object is returned unchanged.
    """
    # Peel off Node wrappers until a non-Node value remains.
    while isinstance(obj, Node):
        obj = obj.get_hashable()
    if isinstance(obj, tuple):
        return tuple(map(_expand_hashable, obj))
    return obj
28+
29+
2130
class Node(Generic[T]):
2231
"""
2332
An abstract node type.
@@ -103,7 +112,7 @@ def get_stable_id(self) -> int:
103112
try:
104113
return self._stable_hash_value
105114
except AttributeError:
106-
content = repr(self.get_hashable()).encode("utf-8")
115+
content = repr(_expand_hashable(self)).encode("utf-8")
107116
self._stable_hash_value = int(hashlib.md5(content).hexdigest()[:8], 16)
108117
return self._stable_hash_value
109118

python/cudf_polars/tests/experimental/test_dataframescan.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,17 @@
1010
import polars as pl
1111

1212
from cudf_polars import Translator
13+
from cudf_polars.dsl.traversal import traversal
1314
from cudf_polars.experimental.parallel import lower_ir_graph
1415
from cudf_polars.testing.asserts import assert_gpu_result_equal
1516
from cudf_polars.utils.config import ConfigOptions
1617

1718

19+
def _assert_stable_ids_match(orig, loaded) -> None:
    """
    Assert that two IR graphs yield the same stable IDs, node for node.

    Both graphs are walked with ``traversal`` and compared pairwise;
    ``strict=True`` makes ``zip`` raise if the walks differ in length.
    """
    node_pairs = zip(traversal([orig]), traversal([loaded]), strict=True)
    for orig_node, loaded_node in node_pairs:
        assert orig_node.get_stable_id() == loaded_node.get_stable_id()
22+
23+
1824
@pytest.fixture(scope="module")
1925
def df():
2026
return pl.LazyFrame(
@@ -63,6 +69,21 @@ def test_dataframescan_concat(df, engine):
6369
assert_gpu_result_equal(df2, engine=engine)
6470

6571

72+
def test_join_in_memory_lazy_stable_id_pickle():
    """Stable IDs of a lowered in-memory join survive a pickle round trip."""
    engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="streaming",
        executor_options={"max_rows_per_partition": 1_000},
    )

    # collect().lazy() turns each frame into an in-memory lazy source.
    left = pl.LazyFrame({"k": [1, 2, 3], "x": [10, 20, 30]}).collect().lazy()
    right = pl.LazyFrame({"k": [2, 3, 4], "y": [1, 2, 3]}).collect().lazy()
    joined = left.join(right, on="k")

    translated = Translator(joined._ldf.visit(), engine).translate_ir()
    config_options = ConfigOptions.from_polars_engine(engine)
    ir, _, _ = lower_ir_graph(translated, config_options)

    round_tripped = pickle.loads(pickle.dumps(ir))
    _assert_stable_ids_match(ir, round_tripped)
85+
86+
6687
def test_dataframescan_pickle(df):
6788
_engine = pl.GPUEngine(
6889
raise_on_fail=True,
@@ -79,3 +100,4 @@ def test_dataframescan_pickle(df):
79100
# Verify the unpickled IR is equivalent
80101
assert type(unpickled_ir) is type(ir)
81102
assert unpickled_ir.schema == ir.schema
103+
_assert_stable_ids_match(ir, unpickled_ir)

0 commit comments

Comments
 (0)