Skip to content

Commit f87d6d9

Browse files
Allow self-joining of Polars lazyframes (#466)
Fixes #452 This makes sure that Polars' CSE wraps the lazyframe in a `CACHE[...]` in its plan. This way, the lazyframe will only be consumed once. We've seen this reported more often. The issue with joining a streaming result against itself is that a result can only be consumed once. One workaround is using multiple connections, but this only works if duckdb comes up with a query plan where a result is not referenced against itself. As it turns out, Polars allows I/O plugins to register with [`is_pure`](https://docs.pola.rs/api/python/dev/reference/api/polars.io.plugins.register_io_source.html), and Polars will de-duplicate them. Afaict, there's not really any downside. I don't see why anybody would want to treat a streaming result as not-pure, or what that would mean. I guess we'll see if that assumption will pass the test of time.
2 parents 559f6af + c88229d commit f87d6d9

3 files changed

Lines changed: 71 additions & 4 deletions

File tree

duckdb/polars_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,4 +308,4 @@ def source_generator(
308308
else:
309309
yield pl.from_arrow(record_batch) # type: ignore[misc,unused-ignore]
310310

311-
return register_io_source(source_generator, schema=schema)
311+
return register_io_source(source_generator, schema=schema, is_pure=True)

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ stubdeps = [ # dependencies used for typehints in the stubs
229229
"mypy",
230230
"fsspec",
231231
"pandas",
232-
"polars",
232+
"polars>=1.33.0",
233233
"pyarrow; sys_platform != 'win32' or platform_machine != 'ARM64'",
234234
"typing-extensions",
235235
]
@@ -244,7 +244,7 @@ test = [ # dependencies used for running tests
244244
"gcovr; sys_platform != 'win32' or platform_machine != 'ARM64'",
245245
"gcsfs; sys_platform != 'win32' or platform_machine != 'ARM64'",
246246
"packaging",
247-
"polars",
247+
"polars>=1.33.0",
248248
"psutil",
249249
"py4j",
250250
"pyotp",
@@ -276,7 +276,7 @@ scripts = [ # dependencies used for running scripts
276276
"numpy",
277277
"pandas",
278278
"pcpp",
279-
"polars",
279+
"polars>=1.33.0",
280280
"pyarrow; sys_platform != 'win32' or platform_machine != 'ARM64'",
281281
"pytz"
282282
]
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Regression test for duckdb-python issue #452.
2+
3+
Silent row drop when two `db.sql(query).pl(lazy=True)` LazyFrames are joined,
4+
the result is self-rejoined to derive grouping keys, a window expression is
5+
applied downstream, and the plan is collected via `engine="streaming"`.
6+
7+
The streaming output is clamped to ~10.3M rows regardless of input size — at
8+
20K / 30K / 50K variable-length groups (20M / 30M / 50M input rows) the
9+
streaming output is ~10.30M / ~10.30M / ~10.31M.
10+
11+
This test is intentionally heavy: it must cross the bug threshold (>10M rows)
12+
to trigger the failure. Runs in ~30 seconds at N_GROUPS=20_000.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
import pytest
20+
21+
import duckdb
22+
23+
if TYPE_CHECKING:
24+
from pathlib import Path
25+
26+
pl = pytest.importorskip("polars")
27+
np = pytest.importorskip("numpy")
28+
pytest.importorskip("pyarrow")
29+
30+
31+
def test_452_polars_streaming_self_rejoin_does_not_drop_rows(tmp_path: Path) -> None:
32+
n_groups = 20_000
33+
rng = np.random.default_rng(42)
34+
group_lens = np.clip(rng.lognormal(mean=6.8, sigma=0.5, size=n_groups).astype(int), 30, 2900)
35+
g = np.repeat(np.arange(n_groups, dtype=np.int32), group_lens)
36+
t = np.concatenate([np.arange(n, dtype=np.int32) for n in group_lens])
37+
x = rng.uniform(-1.0, 1.0, int(group_lens.sum())).astype(np.float32)
38+
39+
left_path = tmp_path / "left.parquet"
40+
right_path = tmp_path / "right.parquet"
41+
pl.DataFrame({"g": g, "t": t, "x": x}).write_parquet(left_path, row_group_size=200_000)
42+
pl.DataFrame({"g": g, "t": t}).write_parquet(right_path, row_group_size=200_000)
43+
del g, t, x
44+
45+
def build(left_lf: pl.LazyFrame, right_lf: pl.LazyFrame) -> pl.LazyFrame:
46+
joined = left_lf.join(right_lf, on=["g", "t"], how="inner")
47+
keys = joined.select("g").unique()
48+
plan = joined.join(keys, on="g")
49+
return plan.sort(["g", "t"]).select(
50+
"g",
51+
"t",
52+
pl.col("x").rolling_sum(window_size=100).over("g").alias("y"),
53+
)
54+
55+
ref = build(pl.scan_parquet(left_path), pl.scan_parquet(right_path)).collect()
56+
57+
db_l = duckdb.connect(":memory:")
58+
db_r = duckdb.connect(":memory:")
59+
try:
60+
left_lf = db_l.sql(f"select * from read_parquet('{left_path}')").pl(lazy=True)
61+
right_lf = db_r.sql(f"select * from read_parquet('{right_path}')").pl(lazy=True)
62+
out = build(left_lf, right_lf).collect(engine="streaming")
63+
finally:
64+
db_l.close()
65+
db_r.close()
66+
67+
assert out.shape == ref.shape, f"streaming output dropped rows: got {out.shape}, expected {ref.shape}"

0 commit comments

Comments
 (0)