|
4 | 4 |
|
5 | 5 | import sys |
6 | 6 | from copy import deepcopy |
7 | | -from typing import TYPE_CHECKING |
| 7 | +from typing import TYPE_CHECKING, Any |
8 | 8 | from unittest import mock |
9 | 9 |
|
10 | 10 | import pytest |
@@ -345,3 +345,97 @@ def test_duckdb_engine_polars_no_pyarrow( |
345 | 345 | result.collect() if expected_type_name == "LazyFrame" else result |
346 | 346 | ) |
347 | 347 | assert len(materialized) == 4 |
| 348 | + |
| 349 | + |
| 350 | +class _RelationProxy: |
| 351 | + # _duckdb.DuckDBPyRelation is a pybind class and rejects attribute |
| 352 | + # assignment, so we wrap it to intercept .pl(). |
| 353 | + def __init__(self, relation: Any, pl_override: Any) -> None: |
| 354 | + self._relation = relation |
| 355 | + self._pl_override = pl_override |
| 356 | + |
| 357 | + def pl(self, *args: Any, **kwargs: Any) -> Any: |
| 358 | + return self._pl_override(self._relation, *args, **kwargs) |
| 359 | + |
| 360 | + def __getattr__(self, name: str) -> Any: |
| 361 | + return getattr(self._relation, name) |
| 362 | + |
| 363 | + |
| 364 | +def _run_with_pl_spy( |
| 365 | + duckdb_connection: duckdb.DuckDBPyConnection, |
| 366 | + pl_impl: Any, |
| 367 | +) -> tuple[Any, list[dict]]: |
| 368 | + """Execute a lazy-polars query with `pl_impl` wrapping the real `pl()`.""" |
| 369 | + from marimo._sql.engines import duckdb as duckdb_engine_mod |
| 370 | + |
| 371 | + pl_calls: list[dict] = [] |
| 372 | + real_wrapped_sql = duckdb_engine_mod.wrapped_sql |
| 373 | + |
| 374 | + def spy(relation: Any, *args: Any, **kwargs: Any) -> Any: |
| 375 | + pl_calls.append(kwargs) |
| 376 | + return pl_impl(relation, *args, **kwargs) |
| 377 | + |
| 378 | + def spy_wrapped_sql(query: str, connection: Any) -> Any: |
| 379 | + return _RelationProxy(real_wrapped_sql(query, connection), spy) |
| 380 | + |
| 381 | + with ( |
| 382 | + mock.patch.object( |
| 383 | + DuckDBEngine, "sql_output_format", return_value="lazy-polars" |
| 384 | + ), |
| 385 | + mock.patch.object( |
| 386 | + duckdb_engine_mod, "wrapped_sql", side_effect=spy_wrapped_sql |
| 387 | + ), |
| 388 | + ): |
| 389 | + engine = DuckDBEngine( |
| 390 | + duckdb_connection, engine_name=VariableName("test_duckdb") |
| 391 | + ) |
| 392 | + result = engine.execute("SELECT * FROM test ORDER BY id") |
| 393 | + |
| 394 | + return result, pl_calls |
| 395 | + |
| 396 | + |
| 397 | +@pytest.mark.skipif( |
| 398 | + not HAS_DUCKDB or not HAS_POLARS, |
| 399 | + reason="DuckDB and Polars not installed", |
| 400 | +) |
| 401 | +def test_duckdb_engine_lazy_polars_uses_streaming( |
| 402 | + duckdb_connection: duckdb.DuckDBPyConnection, |
| 403 | +) -> None: |
| 404 | + # Regression test for #9639: lazy-polars output must stream via |
| 405 | + # pl(lazy=True), not eagerly materialize then .lazy(). |
| 406 | + import polars as pl |
| 407 | + |
| 408 | + def pl_impl(relation: Any, *args: Any, **kwargs: Any) -> Any: |
| 409 | + return relation.pl(*args, **kwargs) |
| 410 | + |
| 411 | + result, pl_calls = _run_with_pl_spy(duckdb_connection, pl_impl) |
| 412 | + |
| 413 | + assert isinstance(result, pl.LazyFrame) |
| 414 | + assert len(result.collect()) == 4 |
| 415 | + assert pl_calls == [{"batch_size": 100_000, "lazy": True}] |
| 416 | + |
| 417 | + |
| 418 | +@pytest.mark.skipif( |
| 419 | + not HAS_DUCKDB or not HAS_POLARS, |
| 420 | + reason="DuckDB and Polars not installed", |
| 421 | +) |
| 422 | +def test_duckdb_engine_lazy_polars_falls_back_on_older_duckdb( |
| 423 | + duckdb_connection: duckdb.DuckDBPyConnection, |
| 424 | +) -> None: |
| 425 | + # Regression test for #9639: DuckDB <1.4 rejects the `lazy` kwarg, and |
| 426 | + # `pl(lazy=True)` also fails without pyarrow. Both must fall back to the |
| 427 | + # Arrow PyCapsule path. |
| 428 | + import polars as pl |
| 429 | + |
| 430 | + def pl_impl(relation: Any, *args: Any, **kwargs: Any) -> Any: |
| 431 | + if "lazy" in kwargs: |
| 432 | + raise TypeError("pl() got an unexpected keyword argument 'lazy'") |
| 433 | + return relation.pl(*args, **kwargs) |
| 434 | + |
| 435 | + result, pl_calls = _run_with_pl_spy(duckdb_connection, pl_impl) |
| 436 | + |
| 437 | + assert isinstance(result, pl.LazyFrame) |
| 438 | + assert len(result.collect()) == 4 |
| 439 | + # Only the first call (lazy=True, raises) reaches `pl`; the fallback uses |
| 440 | + # `to_polars()` (Arrow PyCapsule) and never touches `relation.pl()`. |
| 441 | + assert pl_calls == [{"batch_size": 100_000, "lazy": True}] |
0 commit comments