Skip to content

Commit 55b951e

Browse files
committed
feat: complete Arrow type mapping for parametric and missing types
Add support for all common Arrow type strings returned by Hotdata's information_schema when tables are loaded from Parquet/Arrow sources. Simple additions to _ARROW_TYPE_MAP: - int8 → dt.Int8 (Postgres parser wrongly returns Int64 for "int8") - halffloat → dt.Float16 (PyArrow's str() name for float16) - large_string → dt.String (PyArrow large-offset string variant) New _parse_parametric_arrow_type() for types with embedded parameters: - timestamp[us] / timestamp[us, tz=UTC] → dt.Timestamp(timezone=...) - duration[ms] → dt.Interval(unit='ms') - decimal128(10, 3) / decimal(5, 2) → dt.Decimal(precision, scale) - list<item: T> / large_list<item: T> → dt.Array(T) (recursive) Adds 27 new test cases covering all new patterns including timezone variants, all 4 time units, case-insensitivity, and recursive list element type resolution.
1 parent 185fe7a commit 55b951e

2 files changed

Lines changed: 142 additions & 5 deletions

File tree

src/ibis_hotdata/types.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import re
6+
57
import ibis.expr.datatypes as dt
68
from ibis.backends.sql.datatypes import PostgresType
79

@@ -12,43 +14,99 @@
1214
# dates
1315
"date32": dt.Date,
1416
"date64": dt.Date,
15-
# floats
17+
# floats — "halffloat" is PyArrow's str() name for float16
1618
"float16": dt.Float16,
1719
"float32": dt.Float32,
1820
"float64": dt.Float64,
21+
"halffloat": dt.Float16,
22+
# signed ints — must override Postgres parser: Postgres "int8" means 8-byte (64-bit),
23+
# but Arrow "int8" means 8-bit. int16/32/64 parse correctly via Postgres.
24+
"int8": dt.Int8,
1925
# unsigned ints
2026
"uint8": dt.UInt8,
2127
"uint16": dt.UInt16,
2228
"uint32": dt.UInt32,
2329
"uint64": dt.UInt64,
24-
# strings
30+
# strings — "large_string" / "largeutf8" are PyArrow large-offset variants
2531
"utf8": dt.String,
2632
"largeutf8": dt.String,
33+
"large_string": dt.String,
2734
# binary
2835
"largebinary": dt.Binary,
2936
# time
3037
"time32": dt.Time,
3138
"time64": dt.Time,
3239
}
3340

41+
# Regex patterns for Arrow parametric types whose string representation includes
# embedded parameters (unit, timezone, precision, value type, …).
# All patterns are anchored and case-insensitive.
#
# timestamp[<unit>] or timestamp[<unit>, tz=<zone>]; group 1 = unit, group 2 = zone.
_TIMESTAMP_RE = re.compile(r"^timestamp\[(\w+)(?:,\s*tz=(.+))?\]$", re.IGNORECASE)
# duration[<unit>]; group 1 = unit.
_DURATION_RE = re.compile(r"^duration\[(\w+)\]$", re.IGNORECASE)
# decimal(p, s) or decimal32/64/128/256(p, s); group 1 = precision, group 2 = scale.
# The width suffix is optional *as a whole*: writing ``decimal128?`` would make
# only the trailing "8" optional, which matches "decimal12(" and rejects the
# bare "decimal(" spelling.
_DECIMAL_RE = re.compile(
    r"^decimal(?:32|64|128|256)?\((\d+),\s*(\d+)\)$", re.IGNORECASE
)
# list<item: T> or large_list<item: T>; group 1 = element type string T.
_LIST_RE = re.compile(r"^(?:large_)?list<item:\s*(.+)>$", re.IGNORECASE)

# Map Arrow time-unit strings to Ibis IntervalUnit strings.
_ARROW_UNIT_TO_IBIS: dict[str, str] = {
    "s": "s",
    "ms": "ms",
    "us": "us",
    "ns": "ns",
}
55+
56+
57+
# Arrow timestamp unit → Ibis ``Timestamp.scale`` (number of fractional-second digits).
_ARROW_UNIT_TO_TIMESTAMP_SCALE: dict[str, int] = {"s": 0, "ms": 3, "us": 6, "ns": 9}


def _parse_parametric_arrow_type(raw: str, *, nullable: bool) -> dt.DataType | None:
    """Try to parse an Arrow parametric type string into an Ibis DataType.

    Recognised shapes are ``timestamp[unit]`` / ``timestamp[unit, tz=ZONE]``,
    ``duration[unit]``, decimal ``(precision, scale)`` strings and
    ``list<item: T>`` / ``large_list<item: T>``.

    Parameters
    ----------
    raw
        Type string, already stripped of surrounding whitespace by the caller.
    nullable
        Nullability applied to the returned (outermost) type.

    Returns
    -------
    dt.DataType | None
        The parsed type, or ``None`` if ``raw`` does not match any known
        parametric pattern (or names an unknown duration unit), allowing the
        caller to fall through to the Postgres dialect parser.
    """
    m = _TIMESTAMP_RE.match(raw)
    if m:
        unit, tz_raw = m.groups()
        # A whitespace-only zone collapses to None instead of timezone="".
        tz: str | None = (tz_raw or "").strip() or None
        # Keep the Arrow unit as the timestamp scale rather than discarding it;
        # an unrecognised unit leaves scale unset (generic Timestamp).
        scale = _ARROW_UNIT_TO_TIMESTAMP_SCALE.get(unit.lower())
        return dt.Timestamp(timezone=tz, scale=scale, nullable=nullable)

    m = _DURATION_RE.match(raw)
    if m:
        unit = _ARROW_UNIT_TO_IBIS.get(m.group(1).lower())
        if unit is None:
            # Unknown unit: do not silently reinterpret the value as seconds;
            # let the caller apply its generic fallback instead.
            return None
        return dt.Interval(unit=unit, nullable=nullable)

    m = _DECIMAL_RE.match(raw)
    if m:
        precision, scale = (int(part) for part in m.groups())
        return dt.Decimal(precision=precision, scale=scale, nullable=nullable)

    m = _LIST_RE.match(raw)
    if m:
        # Element types are resolved recursively through the public entry point,
        # so nested lists and parametric element types work too. Elements are
        # always treated as nullable.
        value_type = dtype_from_hotdata_sql_type(m.group(1).strip(), nullable=True)
        return dt.Array(value_type=value_type, nullable=nullable)

    return None
83+
3484

3585
def dtype_from_hotdata_sql_type(sql_type: str | None, *, nullable: bool) -> dt.DataType:
    """Best-effort mapping from Hotdata `/information_schema` column `data_type` strings.

    Hotdata may return either SQL-style names (``INTEGER``, ``VARCHAR``, ``DOUBLE
    PRECISION``, …) or Arrow-style names (``Date32``, ``Float64``, ``Utf8``, …).
    Resolution order:

    1. missing/empty input → ``dt.String``;
    2. case-insensitive lookup in the explicit Arrow name table;
    3. Arrow parametric patterns (``timestamp[us]``, ``duration[ms]``,
       ``decimal128(p, s)``, ``list<…>``);
    4. the Postgres dialect parser;
    5. ``dt.String`` if that parser raises.

    The requested ``nullable`` flag is applied to whichever type is returned.
    """
    if not sql_type:
        return dt.String(nullable=nullable)

    cleaned = sql_type.strip()

    # Plain Arrow names take precedence over the Postgres parser.
    simple_cls = _ARROW_TYPE_MAP.get(cleaned.lower())
    if simple_cls is not None:
        return simple_cls(nullable=nullable)

    # Arrow types that embed parameters in their string form.
    resolved = _parse_parametric_arrow_type(cleaned, nullable=nullable)
    if resolved is not None:
        return resolved

    try:
        return PostgresType.from_string(cleaned, nullable=nullable)
    except Exception:  # ibis/sqlglot raise a variety of parse errors; fall back to String
        return dt.String(nullable=nullable)

tests/test_hotdata_types.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ def test_dtype_from_hotdata_malformed_fallback_string():
5555
("LargeBinary", True, dt.Binary),
5656
("Time32", True, dt.Time),
5757
("Time64", False, dt.Time),
58+
# Previously missing: signed int8 (Postgres "int8" means int64, not int8)
59+
("int8", True, dt.Int8),
60+
("Int8", False, dt.Int8),
61+
# Previously missing: halffloat (PyArrow's str() name for float16)
62+
("halffloat", True, dt.Float16),
63+
("HALFFLOAT", False, dt.Float16),
64+
# Previously missing: large_string (PyArrow large-offset string variant)
65+
("large_string", True, dt.String),
66+
("Large_String", False, dt.String),
5867
# Case-insensitive
5968
("date32", True, dt.Date),
6069
("FLOAT64", True, dt.Float64),
@@ -65,3 +74,73 @@ def test_dtype_from_hotdata_arrow_type_names(sql_type, nullable, expected_cls):
6574
out = dtype_from_hotdata_sql_type(sql_type, nullable=nullable)
6675
assert out.nullable is nullable
6776
assert isinstance(out, expected_cls)
77+
78+
79+
@pytest.mark.parametrize(
    ("sql_type", "expected_tz"),
    [
        ("timestamp[s]", None),
        ("timestamp[ms]", None),
        ("timestamp[us]", None),
        ("timestamp[ns]", None),
        ("timestamp[us, tz=UTC]", "UTC"),
        ("timestamp[us, tz=America/New_York]", "America/New_York"),
        ("TIMESTAMP[US]", None),
    ],
)
def test_dtype_from_hotdata_arrow_timestamp(sql_type, expected_tz):
    """Arrow ``timestamp[...]`` strings map to ``dt.Timestamp`` with the parsed timezone."""
    resolved = dtype_from_hotdata_sql_type(sql_type, nullable=True)

    assert isinstance(resolved, dt.Timestamp)
    assert resolved.timezone == expected_tz
    assert resolved.nullable is True
96+
97+
98+
@pytest.mark.parametrize(
    ("sql_type", "expected_unit"),
    [
        ("duration[s]", "s"),
        ("duration[ms]", "ms"),
        ("duration[us]", "us"),
        ("duration[ns]", "ns"),
        ("DURATION[MS]", "ms"),
    ],
)
def test_dtype_from_hotdata_arrow_duration(sql_type, expected_unit):
    """Arrow ``duration[unit]`` strings map to ``dt.Interval`` carrying the same unit."""
    resolved = dtype_from_hotdata_sql_type(sql_type, nullable=False)

    assert isinstance(resolved, dt.Interval)
    assert resolved.unit.value == expected_unit
    assert resolved.nullable is False
113+
114+
115+
@pytest.mark.parametrize(
    ("sql_type", "expected_precision", "expected_scale"),
    [
        ("decimal128(10, 3)", 10, 3),
        ("decimal128(38, 0)", 38, 0),
        ("decimal(5, 2)", 5, 2),
        ("DECIMAL128(18, 6)", 18, 6),
    ],
)
def test_dtype_from_hotdata_arrow_decimal(sql_type, expected_precision, expected_scale):
    """Decimal type strings map to ``dt.Decimal`` with the stated precision and scale."""
    resolved = dtype_from_hotdata_sql_type(sql_type, nullable=True)

    assert isinstance(resolved, dt.Decimal)
    assert resolved.precision == expected_precision
    assert resolved.scale == expected_scale
    assert resolved.nullable is True
130+
131+
132+
@pytest.mark.parametrize(
    ("sql_type", "expected_value_cls"),
    [
        ("list<item: int32>", dt.Int32),
        ("list<item: utf8>", dt.String),
        ("list<item: float64>", dt.Float64),
        ("large_list<item: int64>", dt.Int64),
        ("LIST<ITEM: UINT8>", dt.UInt8),
    ],
)
def test_dtype_from_hotdata_arrow_list(sql_type, expected_value_cls):
    """Arrow list strings map to ``dt.Array`` whose element type is resolved from ``T``."""
    resolved = dtype_from_hotdata_sql_type(sql_type, nullable=True)

    assert isinstance(resolved, dt.Array)
    assert isinstance(resolved.value_type, expected_value_cls)
    assert resolved.nullable is True

0 commit comments

Comments
 (0)