|
20 | 20 | from datetime import date, datetime, time, timezone |
21 | 21 | from decimal import Decimal |
22 | 22 |
|
| 23 | +import arro3.core |
| 24 | +import nanoarrow |
23 | 25 | import pyarrow as pa |
24 | 26 | import pytest |
25 | 27 | from datafusion import ( |
@@ -980,6 +982,34 @@ def test_literal_metadata(ctx): |
980 | 982 | assert expected_field.metadata == actual_field.metadata |
981 | 983 |
|
982 | 984 |
|
| 985 | +def test_scalar_conversion() -> None: |
| 986 | + expected_value = lit(1) |
| 987 | + assert str(expected_value) == "Expr(Int64(1))" |
| 988 | + |
| 989 | + # Test pyarrow imports |
| 990 | + assert expected_value == lit(pa.scalar(1)) |
| 991 | + assert expected_value == lit(pa.scalar(1, type=pa.int32())) |
| 992 | + |
| 993 | + # Test nanoarrow |
| 994 | + na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0] |
| 995 | + assert expected_value == lit(na_scalar) |
| 996 | + |
| 997 | + # Test pyo3 |
| 998 | + arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32()) |
| 999 | + assert expected_value == lit(arro3_scalar) |
| 1000 | + |
| 1001 | + expected_value = lit([1, 2, 3]) |
| 1002 | + assert str(expected_value) == "Expr(List([1, 2, 3]))" |
| 1003 | + |
| 1004 | + assert expected_value == lit(pa.scalar([1, 2, 3])) |
| 1005 | + |
| 1006 | + na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32()) |
| 1007 | + assert expected_value == lit(na_array) |
| 1008 | + |
| 1009 | + arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32()) |
| 1010 | + assert expected_value == lit(arro3_array) |
| 1011 | + |
| 1012 | + |
983 | 1013 | def test_ensure_expr(): |
984 | 1014 | e = col("a") |
985 | 1015 | assert ensure_expr(e) is e.expr |
|
0 commit comments