diff --git a/README.md b/README.md index 8b836e0..bc5760a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Use [Ibis](https://ibis-project.org/) to create on-demand databases, upload data, and query with Python expressions — get pandas or Arrow results back without writing SQL. -**Requirements:** Python 3.10+, **ibis-framework** 10.x, **hotdata** ≥0.2.3. +**Requirements:** Python 3.10+, **ibis-framework** ≥10,<11, **hotdata** ≥0.2.3. ## Install @@ -25,7 +25,7 @@ con = ibis.hotdata.connect( ) # 1. Create a database and declare the tables you'll load -con.create_database("sales", schema="public", tables=["orders"]) +con.create_database("sales", tables=["orders"]) # 2. Upload a pandas DataFrame (or PyArrow table) df = pd.DataFrame({ @@ -33,14 +33,14 @@ df = pd.DataFrame({ "amount": [9.99, 49.99, 5.00], "region": ["west", "east", "west"], }) -con.create_table("orders", df, database=("sales", "public"), overwrite=True) +con.create_table("orders", df, database=("sales", "main"), overwrite=True) # 3. Uploads are async — wait briefly before querying time.sleep(2) # 4. Query with Ibis expressions # Managed tables are always accessed with catalog "default" -t = con.table("orders", database=("default", "public")) +t = con.table("orders", database=("default", "main")) result = ( t.group_by("region") .agg(total=t.amount.sum()) @@ -49,7 +49,7 @@ result = ( ) # 5. 
Clean up -con.drop_table("orders", database=("sales", "public")) +con.drop_table("orders", database=("sales", "main")) con.drop_database("sales") ``` @@ -60,13 +60,26 @@ con = ibis.hotdata.connect( api_url="https://api.hotdata.dev", token="YOUR_API_KEY", workspace_id="ws_...", + # optional + session_id=None, # sandbox id (X-Session-Id header) + timeout=120.0, # per-request HTTP timeout in seconds + verify_ssl=True, # False to skip TLS verification, or path to CA bundle + default_connection=None, # default catalog (connection id); auto-detected if only one exists + default_schema=None, # default schema; auto-detected if only one exists + database_id=None, # bind an existing managed database id at connect time + poll_interval_s=0.25, # polling interval for async queries + poll_timeout_s=600.0, # max time to wait for a query result ) ``` -URL-style also works: +URL-style also works, with the same parameters as query string keys: ```python -con = ibis.connect("hotdata://api.hotdata.dev/?token=...&workspace_id=ws_...") +con = ibis.connect( + "hotdata://api.hotdata.dev/" + "?token=...&workspace_id=ws_..." + "&default_connection=my_conn&default_schema=public" +) ``` ## Managed databases @@ -77,25 +90,33 @@ Managed databases are the primary way to bring data into Hotdata with Ibis. 
Decl ```python # Declare the database and all table names up front -con.create_database("analytics", schema="public", tables=["events", "users"]) +con.create_database("analytics", tables=["events", "users"]) # Upload from a pandas DataFrame -con.create_table("events", events_df, database=("analytics", "public"), overwrite=True) +con.create_table("events", events_df, database=("analytics", "main"), overwrite=True) # PyArrow tables also work import pyarrow as pa table = pa.table({"id": [1, 2], "name": ["alice", "bob"]}) -con.create_table("users", table, database=("analytics", "public"), overwrite=True) +con.create_table("users", table, database=("analytics", "main"), overwrite=True) + +# Schema-only (no data): creates an empty table with the declared schema +import ibis.expr.schema as sch +con.create_table( + "staging", + schema=sch.Schema({"id": "int64", "ts": "timestamp"}), + database=("analytics", "main"), +) ``` -Table names must be declared when the database is created — you cannot add new table names later without recreating the database. +Table names must be declared when the database is created — you cannot upload to a table name that was not listed in `tables=`. 
### Query When querying, use `"default"` as the catalog: ```python -t = con.table("events", database=("default", "public")) +t = con.table("events", database=("default", "main")) result = ( t.filter(t.event_type == "click") @@ -110,7 +131,7 @@ Or with raw SQL: ```python result = con.sql( 'SELECT user_id, COUNT(*) AS n ' - 'FROM "default"."public"."events" ' + 'FROM "default"."main"."events" ' 'WHERE event_type = \'click\' ' 'GROUP BY user_id' ).execute() @@ -118,9 +139,14 @@ result = con.sql( ### Delete +Pass `force=True` to silently skip errors when the database or table does not exist: + ```python -con.drop_table("events", database=("analytics", "public")) +con.drop_table("events", database=("analytics", "main")) +con.drop_table("events", database=("analytics", "main"), force=True) # no-op if missing + con.drop_database("analytics") +con.drop_database("analytics", force=True) # no-op if missing ``` ### Addressing summary @@ -135,7 +161,7 @@ con.drop_database("analytics") ### Ibis expressions ```python -t = con.table("orders", database=("default", "public")) +t = con.table("orders", database=("default", "main")) summary = ( t.filter(t.amount > 10) @@ -146,13 +172,13 @@ summary = ( ) ``` -`.execute()` returns a **pandas DataFrame**. Use `.to_pyarrow()` for an Arrow table or `.to_pyarrow_batches()` to stream batches without materializing the full result. +`.execute()` returns a **pandas DataFrame**. `.to_pyarrow()` returns an Arrow table. `.to_pyarrow_batches()` returns a `RecordBatchReader` — note that Hotdata returns a single Arrow IPC payload per query, so this method downloads the full result first and then splits it into local batches. 
### Raw SQL ```python base = con.sql( - 'SELECT * FROM "default"."public"."orders"', + 'SELECT * FROM "default"."main"."orders"', dialect="postgres", ) result = base.filter(base.amount > 10).execute() @@ -189,17 +215,20 @@ con.list_tables(database=("my_postgres", "public")) # tables | Feature | Status | |---------|--------| | `create_database` / `drop_database` (managed) | ✅ | -| `create_table` / `drop_table` (DataFrame or Arrow upload) | ✅ | +| `create_table` from pandas / PyArrow / schema-only | ✅ | +| `drop_table` | ✅ | | `con.table(...)` with full schema metadata | ✅ | | Ibis expressions: filter, select, join, group\_by, agg, order\_by, limit | ✅ | | `con.sql(...)` raw SQL | ✅ | | `.execute()` → pandas, `.to_pyarrow()`, `.to_pyarrow_batches()` | ✅ | | `list_catalogs`, `list_databases`, `list_tables` | ✅ | +| Arrow / Parquet column types (timestamp, decimal, list, duration, …) | ✅ | | Temporary tables | ❌ | +| In-memory tables (`ibis.memtable(...)`) | ❌ | | Python UDFs | ❌ | | INSERT / UPDATE / DELETE on external connections | ❌ | -SQL compilation uses Ibis's Postgres dialect. Use `con.sql(...)` as a fallback for expressions that don't compile cleanly. +SQL compilation uses Ibis's Postgres dialect. Column types returned by Hotdata's information schema are resolved via PyArrow's type system, so Parquet-loaded tables with Arrow-native types (timestamps with time zones, decimals, lists, durations) are mapped correctly to Ibis types. 
## Development diff --git a/src/ibis_hotdata/backend.py b/src/ibis_hotdata/backend.py index cccc6cf..cfdb1d4 100644 --- a/src/ibis_hotdata/backend.py +++ b/src/ibis_hotdata/backend.py @@ -43,7 +43,6 @@ from ibis.backends.sql import SQLBackend from ibis_hotdata.http import HotdataAPIError, HotdataClient -from ibis_hotdata.managed import DEFAULT_SCHEMA from ibis_hotdata.types import dtype_from_hotdata_sql_type _INFORMATION_SCHEMA_PAGE_SIZE = 500 @@ -203,7 +202,7 @@ def do_connect( ) def disconnect(self) -> None: - if getattr(self, "_http", None) is not None: + if hasattr(self, "_http"): self._http.close() # --- hierarchy --------------------------------------------------------- @@ -253,10 +252,12 @@ def _to_catalog_db_tuple(self, table_loc: sge.Table): """Use the compiler SQL dialect when stringifying qualifiers (backend name is not a dialect).""" dialect = self.dialect - if (sg_cat := table_loc.args["catalog"]) is not None: + sg_cat = table_loc.args["catalog"] + if sg_cat is not None: sg_cat.args["quoted"] = False sg_cat = sg_cat.sql(dialect=dialect) - if (sg_db := table_loc.args["db"]) is not None: + sg_db = table_loc.args["db"] + if sg_db is not None: sg_db.args["quoted"] = False sg_db = sg_db.sql(dialect=dialect) @@ -429,7 +430,7 @@ def _resolve_database_connection_id(self) -> str | None: db = self._http.get_database(self._database_id) self._database_connection_id = db.get("default_connection_id") except HotdataAPIError: - pass + pass # best-effort: if the lookup fails, callers fall back to the catalog name return self._database_connection_id # --- schema / sql execution -------------------------------------------- @@ -575,7 +576,7 @@ def create_database( /, *, catalog: str | None = None, - schema: str = DEFAULT_SCHEMA, + schema: str = "main", tables: Sequence[str] | None = None, force: bool = False, ) -> None: @@ -722,7 +723,7 @@ def drop_table( raise _ibis_err_from_hotdata(exc) from exc def _register_in_memory_table(self, _op: ops.InMemoryTable) -> None: - 
return + pass # Hotdata has no local in-memory table concept; Ibis calls this hook before execute @cached_property def version(self) -> str: diff --git a/src/ibis_hotdata/http.py b/src/ibis_hotdata/http.py index 3ffec5a..bd51cd9 100644 --- a/src/ibis_hotdata/http.py +++ b/src/ibis_hotdata/http.py @@ -2,6 +2,7 @@ from __future__ import annotations +import http import io import json import time @@ -30,15 +31,14 @@ from hotdata.models.database_default_table_decl import DatabaseDefaultTableDecl from hotdata.models.load_managed_table_request import LoadManagedTableRequest -from ibis_hotdata.managed import DEFAULT_SCHEMA - T = TypeVar("T") # Matches Hotdata / runtimedb ``GET /v1/results/{{id}}`` Arrow responses. APPLICATION_ARROW_STREAM = "application/vnd.apache.arrow.stream" # Statuses that mean the query run is still in progress. -_IN_FLIGHT = {"running", "queued", "pending"} +# runtimedb QueryRunStatus only emits "running", "succeeded", "failed". +_IN_FLIGHT = {"running"} def _sleep_until(deadline: float, interval: float) -> None: @@ -197,7 +197,7 @@ def create_managed_database( self, description: str | None = None, *, - schema: str = DEFAULT_SCHEMA, + schema: str = "public", tables: Sequence[str] = (), ) -> dict[str, Any]: """POST ``/v1/databases`` — creates a managed database with an auto-provisioned default catalog.""" @@ -264,27 +264,27 @@ def _poll_result_arrow( status = raw.status ctype = (raw.headers.get("Content-Type") or "").split(";")[0].strip().lower() - if status == 200 and ctype == APPLICATION_ARROW_STREAM.lower(): + if status == http.HTTPStatus.OK and ctype == APPLICATION_ARROW_STREAM.lower(): table = _ipc_stream_bytes_to_table(body) return self._arrow_payload_from_table(table, result_id=result_id) - if status == 202: + if status == http.HTTPStatus.ACCEPTED: _sleep_until(deadline, poll_interval_s) continue - if status == 409: + if status == http.HTTPStatus.CONFLICT: d = _json_utf8(body) if body else {} raise HotdataAPIError( d.get("error_message") or 
"Result failed", - status_code=409, + status_code=http.HTTPStatus.CONFLICT, body=d, ) - if status == 404: + if status == http.HTTPStatus.NOT_FOUND: d = _json_utf8(body) if body else {} raise HotdataAPIError( d.get("detail") or f"Result {result_id!r} not found", - status_code=404, + status_code=http.HTTPStatus.NOT_FOUND, body=d, ) @@ -304,7 +304,7 @@ def _arrow_payload_from_table( ) -> dict[str, Any]: sch = table.schema columns = sch.names - nullable = [sch.field(i).nullable for i in range(len(columns))] + nullable = [field.nullable for field in sch] return { "format": "arrow", "pa_table": table, diff --git a/src/ibis_hotdata/managed.py b/src/ibis_hotdata/managed.py deleted file mode 100644 index 972c366..0000000 --- a/src/ibis_hotdata/managed.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Helpers for Hotdata managed databases.""" - -from __future__ import annotations - -DEFAULT_SCHEMA = "public" diff --git a/src/ibis_hotdata/types.py b/src/ibis_hotdata/types.py index cb7920d..7ddaa49 100644 --- a/src/ibis_hotdata/types.py +++ b/src/ibis_hotdata/types.py @@ -2,53 +2,146 @@ from __future__ import annotations +import re + +import pyarrow as pa import ibis.expr.datatypes as dt from ibis.backends.sql.datatypes import PostgresType +from ibis.formats.pyarrow import PyArrowType -# Arrow-style type names returned by Hotdata's information_schema when tables are -# loaded from Parquet/Arrow sources. PostgresType.from_string() treats these as -# USERDEFINED unknowns, so we resolve them explicitly before falling through. -_ARROW_TYPE_MAP: dict[str, type[dt.DataType]] = { +# Simple Arrow type strings → PyArrow instances. Covers non-parametric types +# that the Postgres dialect parser does not know (Arrow-specific names, unsigned +# ints) or mis-maps (Arrow "int8" = 8-bit; Postgres "int8" = 8-byte / int64). +# All scalar types that can appear as list/map element types must be listed here +# because element type strings are resolved via this map, not the Postgres parser. 
+_PA_TYPE_MAP: dict[str, pa.DataType] = { # dates - "date32": dt.Date, - "date64": dt.Date, - # floats - "float16": dt.Float16, - "float32": dt.Float32, - "float64": dt.Float64, - # unsigned ints - "uint8": dt.UInt8, - "uint16": dt.UInt16, - "uint32": dt.UInt32, - "uint64": dt.UInt64, - # strings - "utf8": dt.String, - "largeutf8": dt.String, + "date32": pa.date32(), + "date64": pa.date64(), + # floats — "halffloat" is PyArrow's str() name for float16 + "float16": pa.float16(), + "float32": pa.float32(), + "float64": pa.float64(), + "halffloat": pa.float16(), + # signed ints — Arrow "int8" ≠ Postgres "int8" (8-byte/int64); all four + # listed here so they resolve correctly when used as list element types + "int8": pa.int8(), + "int16": pa.int16(), + "int32": pa.int32(), + "int64": pa.int64(), + # unsigned ints (Postgres parser returns Unknown for all of these) + "uint8": pa.uint8(), + "uint16": pa.uint16(), + "uint32": pa.uint32(), + "uint64": pa.uint64(), + # strings — large-offset variants not known to the Postgres parser + "utf8": pa.utf8(), + "largeutf8": pa.large_utf8(), + "large_string": pa.large_utf8(), + "string": pa.string(), # binary - "largebinary": dt.Binary, - # time - "time32": dt.Time, - "time64": dt.Time, + "binary": pa.binary(), + "largebinary": pa.large_binary(), + # boolean / null + "bool": pa.bool_(), + "boolean": pa.bool_(), + "null": pa.null(), + # time — unit is absent from these bare string forms; the unit does not + # affect the Ibis type (both time32 and time64 map to dt.Time) + "time32": pa.time32("ms"), + "time64": pa.time64("us"), } +# Regex patterns for parametric Arrow types that embed parameters in the string. 
+_TIMESTAMP_RE = re.compile(r"^timestamp\[(\w+)(?:,\s*tz=(.+))?\]$", re.IGNORECASE) +_DURATION_RE = re.compile(r"^duration\[(\w+)\]$", re.IGNORECASE) +_DECIMAL_RE = re.compile(r"^decimal(?:128|256)?\((\d+),\s*(\d+)\)$", re.IGNORECASE) +_LIST_RE = re.compile(r"^(large_)?list<(?:\w+:\s*)?(.+)>$", re.IGNORECASE) +# PyArrow appends " not null" when a list's item field is non-nullable. +_NOT_NULL_SUFFIX_RE = re.compile(r"\s+not\s+null$", re.IGNORECASE) + + +def _pa_type_from_arrow_str(raw: str) -> pa.DataType | None: + """Best-effort: Arrow type string → PyArrow DataType, or ``None`` if not recognised. + + Handles simple names (via ``_PA_TYPE_MAP``) and parametric forms + (timestamp, duration, decimal, list/large_list). Returns ``None`` if the + string is not a known Arrow type, allowing the caller to fall through to the + Postgres dialect parser or String fallback. + """ + s = raw.strip() + + # Simple non-parametric types. + pa_type = _PA_TYPE_MAP.get(s.lower()) + if pa_type is not None: + return pa_type + + # timestamp[unit] or timestamp[unit, tz=…] + m = _TIMESTAMP_RE.match(s) + if m: + unit = m.group(1).lower() + tz_raw = m.group(2) + tz: str | None = tz_raw.strip() if tz_raw else None + try: + return pa.timestamp(unit, tz=tz) + except Exception: + return None + + # duration[unit] — unknown units return None so the caller falls through + m = _DURATION_RE.match(s) + if m: + try: + return pa.duration(m.group(1).lower()) + except Exception: + return None + + # decimal / decimal128 / decimal256 + m = _DECIMAL_RE.match(s) + if m: + precision, scale = int(m.group(1)), int(m.group(2)) + try: + # decimal128 supports precision 1–38; fall back to decimal256 for wider values + return pa.decimal128(precision, scale) if precision <= 38 else pa.decimal256(precision, scale) + except Exception: + return None + + # list or large_list (recursive for nested types) + m = _LIST_RE.match(s) + if m: + is_large = bool(m.group(1)) + item_raw = m.group(2).strip() + item_not_null = 
bool(_NOT_NULL_SUFFIX_RE.search(item_raw)) + item_str = _NOT_NULL_SUFFIX_RE.sub("", item_raw).strip() + item_pa_type = _pa_type_from_arrow_str(item_str) + if item_pa_type is None: + return None + item_field = pa.field("item", item_pa_type, nullable=not item_not_null) + return pa.large_list(item_field) if is_large else pa.list_(item_field) + + return None + def dtype_from_hotdata_sql_type(sql_type: str | None, *, nullable: bool) -> dt.DataType: - """Best-effort mapping from Hotdata `/information_schema` column `data_type` strings. + """Best-effort mapping from Hotdata ``/information_schema`` column ``data_type`` strings. Hotdata may return either SQL-style names (``INTEGER``, ``VARCHAR``, ``DOUBLE PRECISION``, …) or Arrow-style names (``Date32``, ``Float64``, ``Utf8``, …). - SQL-style names are delegated to the Postgres dialect parser; Arrow-style names - are resolved via an explicit lookup table before falling back to the parser. + Arrow-style names are resolved via PyArrow's type system and converted to Ibis + types using the Ibis–PyArrow bridge; SQL-style names fall through to the Postgres + dialect parser. """ if not sql_type: return dt.String(nullable=nullable) - # Arrow-style names (case-insensitive lookup). - arrow_cls = _ARROW_TYPE_MAP.get(sql_type.strip().lower()) - if arrow_cls is not None: - return arrow_cls(nullable=nullable) + raw = sql_type.strip() + + # Try to parse as an Arrow type string (simple or parametric). + pa_type = _pa_type_from_arrow_str(raw) + if pa_type is not None: + return PyArrowType.to_ibis(pa_type).copy(nullable=nullable) + # Fall through to Postgres dialect parser for SQL-style type names. 
try: - return PostgresType.from_string(sql_type.strip(), nullable=nullable) + return PostgresType.from_string(raw, nullable=nullable) except Exception: # ibis/sqlglot raise a variety of parse errors; fall back to String return dt.String(nullable=nullable) diff --git a/tests/test_hotdata_types.py b/tests/test_hotdata_types.py index 0e743c2..93e9cc8 100644 --- a/tests/test_hotdata_types.py +++ b/tests/test_hotdata_types.py @@ -55,6 +55,15 @@ def test_dtype_from_hotdata_malformed_fallback_string(): ("LargeBinary", True, dt.Binary), ("Time32", True, dt.Time), ("Time64", False, dt.Time), + # Previously missing: signed int8 (Postgres "int8" means int64, not int8) + ("int8", True, dt.Int8), + ("Int8", False, dt.Int8), + # Previously missing: halffloat (PyArrow's str() name for float16) + ("halffloat", True, dt.Float16), + ("HALFFLOAT", False, dt.Float16), + # Previously missing: large_string (PyArrow large-offset string variant) + ("large_string", True, dt.String), + ("Large_String", False, dt.String), # Case-insensitive ("date32", True, dt.Date), ("FLOAT64", True, dt.Float64), @@ -65,3 +74,95 @@ def test_dtype_from_hotdata_arrow_type_names(sql_type, nullable, expected_cls): out = dtype_from_hotdata_sql_type(sql_type, nullable=nullable) assert out.nullable is nullable assert isinstance(out, expected_cls) + + +@pytest.mark.parametrize( + ("sql_type", "expected_tz", "expected_scale"), + [ + ("timestamp[s]", None, 0), + ("timestamp[ms]", None, 3), + ("timestamp[us]", None, 6), + ("timestamp[ns]", None, 9), + ("timestamp[us, tz=UTC]", "UTC", 6), + ("timestamp[us, tz=America/New_York]", "America/New_York", 6), + ("TIMESTAMP[MS]", None, 3), + ], +) +def test_dtype_from_hotdata_arrow_timestamp(sql_type, expected_tz, expected_scale): + out = dtype_from_hotdata_sql_type(sql_type, nullable=True) + assert isinstance(out, dt.Timestamp) + assert out.timezone == expected_tz + assert out.scale == expected_scale + assert out.nullable is True + + +@pytest.mark.parametrize( + ("sql_type", 
"expected_unit"), + [ + ("duration[s]", "s"), + ("duration[ms]", "ms"), + ("duration[us]", "us"), + ("duration[ns]", "ns"), + ("DURATION[MS]", "ms"), + ], +) +def test_dtype_from_hotdata_arrow_duration(sql_type, expected_unit): + out = dtype_from_hotdata_sql_type(sql_type, nullable=False) + assert isinstance(out, dt.Interval) + assert out.unit.value == expected_unit + assert out.nullable is False + + +def test_dtype_from_hotdata_arrow_duration_unknown_unit_falls_back(): + # An unrecognised duration unit should not silently map to seconds; + # it falls through to the Postgres parser (which returns Unknown) or String fallback. + out = dtype_from_hotdata_sql_type("duration[foo]", nullable=True) + assert not isinstance(out, dt.Interval) # must not produce a valid Interval + + +@pytest.mark.parametrize( + ("sql_type", "expected_precision", "expected_scale"), + [ + ("decimal128(10, 3)", 10, 3), + ("decimal128(38, 0)", 38, 0), + ("decimal256(76, 38)", 76, 38), + ("decimal(5, 2)", 5, 2), + ("DECIMAL128(18, 6)", 18, 6), + # decimal12 is NOT a valid form — should not be matched by the decimal regex + ], +) +def test_dtype_from_hotdata_arrow_decimal(sql_type, expected_precision, expected_scale): + out = dtype_from_hotdata_sql_type(sql_type, nullable=True) + assert isinstance(out, dt.Decimal) + assert out.precision == expected_precision + assert out.scale == expected_scale + assert out.nullable is True + + +@pytest.mark.parametrize( + ("sql_type", "expected_value_cls", "expected_item_nullable"), + [ + ("list<item: int32>", dt.Int32, True), + ("list<item: string>", dt.String, True), + ("list<item: float64>", dt.Float64, True), + ("large_list<item: int64>", dt.Int64, True), + ("LIST<item: uint8>", dt.UInt8, True), + # Non-nullable item fields — PyArrow appends " not null" + ("list<item: int32 not null>", dt.Int32, False), + ("list<item: string not null>", dt.String, False), + ("large_list<item: float32 not null>", dt.Float32, False), + ], +) +def test_dtype_from_hotdata_arrow_list(sql_type, expected_value_cls, expected_item_nullable): + out = dtype_from_hotdata_sql_type(sql_type, nullable=True) + assert 
isinstance(out, dt.Array) + assert isinstance(out.value_type, expected_value_cls) + assert out.value_type.nullable is expected_item_nullable + assert out.nullable is True + + +def test_dtype_from_hotdata_arrow_decimal12_not_matched(): + # "decimal12" (only the trailing 8 made optional) must NOT match the decimal regex. + # The Postgres parser handles bare "decimal" forms; decimal12 is not a real type. + out = dtype_from_hotdata_sql_type("decimal12(10, 3)", nullable=True) + assert not isinstance(out, dt.Decimal) # falls through to Unknown or String