Skip to content

Commit bddb9f2

Browse files
committed
fix: preserve timestamp scale and handle non-null list item types
Two bugs found by Codex review: 1. Timestamp scale was discarded: Arrow timestamp[ms] should map to dt.Timestamp(scale=3), not dt.Timestamp(). Add unit-to-scale mapping s=0, ms=3, us=6, ns=9, matching PyArrow convention. 2. Non-nullable list items mis-parsed: PyArrow emits 'list<item: int32 not null>' for non-nullable item fields. Strip the ' not null' suffix and pass nullable=False to the recursive call so the element type is correctly typed instead of falling back to Unknown.
1 parent 55b951e commit bddb9f2

2 files changed

Lines changed: 42 additions & 19 deletions

File tree

src/ibis_hotdata/types.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,23 @@
4545
_DECIMAL_RE = re.compile(r"^decimal128?\((\d+),\s*(\d+)\)$", re.IGNORECASE)
4646
_LIST_RE = re.compile(r"^(?:large_)?list<item:\s*(.+)>$", re.IGNORECASE)
4747

48-
# Map Arrow time-unit strings to Ibis IntervalUnit strings.
48+
# Map Arrow time-unit strings to Ibis IntervalUnit strings and Timestamp scales.
49+
# Scales follow PyArrow's convention: s→0, ms→3, us→6, ns→9.
4950
_ARROW_UNIT_TO_IBIS: dict[str, str] = {
5051
"s": "s",
5152
"ms": "ms",
5253
"us": "us",
5354
"ns": "ns",
5455
}
56+
_ARROW_UNIT_TO_TIMESTAMP_SCALE: dict[str, int] = {
57+
"s": 0,
58+
"ms": 3,
59+
"us": 6,
60+
"ns": 9,
61+
}
62+
63+
# Suffix appended by PyArrow when a list's item field is non-nullable.
64+
_NOT_NULL_SUFFIX_RE = re.compile(r"\s+not\s+null$", re.IGNORECASE)
5565

5666

5767
def _parse_parametric_arrow_type(raw: str, *, nullable: bool) -> dt.DataType | None:
@@ -62,8 +72,10 @@ def _parse_parametric_arrow_type(raw: str, *, nullable: bool) -> dt.DataType | N
6272
"""
6373
m = _TIMESTAMP_RE.match(raw)
6474
if m:
75+
unit = m.group(1).lower()
6576
tz: str | None = m.group(2).strip() if m.group(2) else None
66-
return dt.Timestamp(timezone=tz, nullable=nullable)
77+
scale: int | None = _ARROW_UNIT_TO_TIMESTAMP_SCALE.get(unit)
78+
return dt.Timestamp(timezone=tz, scale=scale, nullable=nullable)
6779

6880
m = _DURATION_RE.match(raw)
6981
if m:
@@ -76,7 +88,12 @@ def _parse_parametric_arrow_type(raw: str, *, nullable: bool) -> dt.DataType | N
7688

7789
m = _LIST_RE.match(raw)
7890
if m:
79-
value_type = dtype_from_hotdata_sql_type(m.group(1).strip(), nullable=True)
91+
item_raw = m.group(1).strip()
92+
# PyArrow appends " not null" for non-nullable item fields; strip it and
93+
# pass nullable=False so the element type is marked non-nullable.
94+
item_not_null = bool(_NOT_NULL_SUFFIX_RE.search(item_raw))
95+
item_str = _NOT_NULL_SUFFIX_RE.sub("", item_raw).strip()
96+
value_type = dtype_from_hotdata_sql_type(item_str, nullable=not item_not_null)
8097
return dt.Array(value_type=value_type, nullable=nullable)
8198

8299
return None

tests/test_hotdata_types.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,22 @@ def test_dtype_from_hotdata_arrow_type_names(sql_type, nullable, expected_cls):
7777

7878

7979
@pytest.mark.parametrize(
80-
("sql_type", "expected_tz"),
80+
("sql_type", "expected_tz", "expected_scale"),
8181
[
82-
("timestamp[s]", None),
83-
("timestamp[ms]", None),
84-
("timestamp[us]", None),
85-
("timestamp[ns]", None),
86-
("timestamp[us, tz=UTC]", "UTC"),
87-
("timestamp[us, tz=America/New_York]", "America/New_York"),
88-
("TIMESTAMP[US]", None),
82+
("timestamp[s]", None, 0),
83+
("timestamp[ms]", None, 3),
84+
("timestamp[us]", None, 6),
85+
("timestamp[ns]", None, 9),
86+
("timestamp[us, tz=UTC]", "UTC", 6),
87+
("timestamp[us, tz=America/New_York]", "America/New_York", 6),
88+
("TIMESTAMP[MS]", None, 3),
8989
],
9090
)
91-
def test_dtype_from_hotdata_arrow_timestamp(sql_type, expected_tz):
91+
def test_dtype_from_hotdata_arrow_timestamp(sql_type, expected_tz, expected_scale):
9292
out = dtype_from_hotdata_sql_type(sql_type, nullable=True)
9393
assert isinstance(out, dt.Timestamp)
9494
assert out.timezone == expected_tz
95+
assert out.scale == expected_scale
9596
assert out.nullable is True
9697

9798

@@ -130,17 +131,22 @@ def test_dtype_from_hotdata_arrow_decimal(sql_type, expected_precision, expected
130131

131132

132133
@pytest.mark.parametrize(
133-
("sql_type", "expected_value_cls"),
134+
("sql_type", "expected_value_cls", "expected_item_nullable"),
134135
[
135-
("list<item: int32>", dt.Int32),
136-
("list<item: utf8>", dt.String),
137-
("list<item: float64>", dt.Float64),
138-
("large_list<item: int64>", dt.Int64),
139-
("LIST<ITEM: UINT8>", dt.UInt8),
136+
("list<item: int32>", dt.Int32, True),
137+
("list<item: utf8>", dt.String, True),
138+
("list<item: float64>", dt.Float64, True),
139+
("large_list<item: int64>", dt.Int64, True),
140+
("LIST<ITEM: UINT8>", dt.UInt8, True),
141+
# Non-nullable item fields — PyArrow appends " not null"
142+
("list<item: int32 not null>", dt.Int32, False),
143+
("list<item: utf8 not null>", dt.String, False),
144+
("large_list<item: float32 not null>", dt.Float32, False),
140145
],
141146
)
142-
def test_dtype_from_hotdata_arrow_list(sql_type, expected_value_cls):
147+
def test_dtype_from_hotdata_arrow_list(sql_type, expected_value_cls, expected_item_nullable):
143148
out = dtype_from_hotdata_sql_type(sql_type, nullable=True)
144149
assert isinstance(out, dt.Array)
145150
assert isinstance(out.value_type, expected_value_cls)
151+
assert out.value_type.nullable is expected_item_nullable
146152
assert out.nullable is True

0 commit comments

Comments
 (0)