Skip to content

Commit 76dee1c

Browse files
committed
Add unit tests and simplify python wrapper for literal
1 parent 87a0e30 commit 76dee1c

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

python/datafusion/expr.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,6 @@ def literal(value: Any) -> Expr:
562562
"""
563563
if isinstance(value, str):
564564
value = pa.scalar(value, type=pa.string_view())
565-
if not isinstance(value, pa.Scalar):
566-
value = pa.scalar(value)
567565
return Expr(expr_internal.RawExpr.literal(value))
568566

569567
@staticmethod
@@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
576574
"""
577575
if isinstance(value, str):
578576
value = pa.scalar(value, type=pa.string_view())
579-
value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
580577

581578
return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))
582579

python/tests/test_expr.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from datetime import date, datetime, time, timezone
2121
from decimal import Decimal
2222

23+
import arro3.core
24+
import nanoarrow
2325
import pyarrow as pa
2426
import pytest
2527
from datafusion import (
@@ -980,6 +982,34 @@ def test_literal_metadata(ctx):
980982
assert expected_field.metadata == actual_field.metadata
981983

982984

985+
def test_scalar_conversion() -> None:
986+
expected_value = lit(1)
987+
assert str(expected_value) == "Expr(Int64(1))"
988+
989+
# Test pyarrow imports
990+
assert expected_value == lit(pa.scalar(1))
991+
assert expected_value == lit(pa.scalar(1, type=pa.int32()))
992+
993+
# Test nanoarrow
994+
na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
995+
assert expected_value == lit(na_scalar)
996+
997+
# Test pyo3
998+
arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
999+
assert expected_value == lit(arro3_scalar)
1000+
1001+
expected_value = lit([1, 2, 3])
1002+
assert str(expected_value) == "Expr(List([1, 2, 3]))"
1003+
1004+
assert expected_value == lit(pa.scalar([1, 2, 3]))
1005+
1006+
na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
1007+
assert expected_value == lit(na_array)
1008+
1009+
arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
1010+
assert expected_value == lit(arro3_array)
1011+
1012+
9831013
def test_ensure_expr():
9841014
e = col("a")
9851015
assert ensure_expr(e) is e.expr

0 commit comments

Comments
 (0)