Skip to content

Commit 43ef530

Browse files
martinv13claude
andauthored
Add bulk_insert dialect hook with DuckDB CSV implementation (#57)
Introduces DatabaseDialect.bulk_insert(conn, table, records) as the single insertion point for temp-table loading. The base implementation falls back to SQLAlchemy executemany (no behaviour change for PostgreSQL, MySQL, MSSQL). DuckDBDialect overrides it with a write-to-tempfile / read_csv approach that is significantly faster for large payloads: - Records are serialised to a NamedTemporaryFile CSV (stdlib csv, no extra dependencies). - read_csv is called with all_varchar=true; each column is then explicitly CAST to its target DuckDB type (BIGINT, DOUBLE, TIMESTAMPTZ, BOOLEAN, …) in the SELECT clause, avoiding auto_detect type mis-identification. - LargeBinary (record-hash) columns are hex-encoded in the CSV and decoded with unhex() in SQL. - SQLAlchemy Python-side scalar defaults (e.g. default=False on temp_exists) are materialised manually before writing the CSV, matching the behaviour of executemany. - The temp file is deleted in a finally block even when an error occurs. document.py: insert_into_temp_tables now calls dialect.bulk_insert(conn, query.table, records) instead of conn.execute(query, records) directly. Tests: new tests/test_bulk_insert.py covers base-class fallback, numeric types (incl. BigInteger/SmallInteger subclass ordering), boolean, datetime, binary, scalar defaults, and empty-records no-op. Co-authored-by: Claude <noreply@anthropic.com>
1 parent ed69828 commit 43ef530

4 files changed

Lines changed: 319 additions & 2 deletions

File tree

src/xml2db/dialect/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,22 @@ def validate_model_config(self, config: dict) -> dict:
313313
"Clustered columnstore indexes are only supported with MS SQL Server database, noop"
314314
)
315315
return config
316+
317+
# ------------------------------------------------------------------
318+
# Data loading
319+
# ------------------------------------------------------------------
320+
321+
def bulk_insert(self, conn: Any, table: Any, records: list) -> None:
322+
"""Insert records into a staging table.
323+
324+
The base implementation uses SQLAlchemy's parameterised executemany,
325+
which is backend-agnostic. Subclasses may override this with a
326+
backend-specific bulk-loading strategy (e.g. COPY FROM CSV).
327+
328+
Args:
329+
conn: A SQLAlchemy ``Connection`` already within a transaction.
330+
table: The SQLAlchemy ``Table`` object to insert into.
331+
records: A list of dicts mapping column keys to Python values.
332+
"""
333+
if records:
334+
conn.execute(table.insert(), records)

src/xml2db/dialect/duckdb.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
1+
import csv
2+
import os
3+
import tempfile
14
from typing import Any
25

3-
from sqlalchemy import Column, Integer, Sequence
6+
from sqlalchemy import (
7+
BigInteger,
8+
Boolean,
9+
Column,
10+
DateTime,
11+
Double,
12+
Integer,
13+
LargeBinary,
14+
Sequence,
15+
SmallInteger,
16+
text,
17+
)
418
from sqlalchemy.exc import ProgrammingError
519
import sqlalchemy.schema
620

@@ -48,3 +62,105 @@ def do_create() -> None:
4862
do_create()
4963
except ProgrammingError:
5064
pass
65+
66+
# Maps SQLAlchemy column types to DuckDB CAST target type names.
67+
# String types need no cast; LargeBinary is handled via unhex().
68+
# Order matters: subclasses (BigInteger, SmallInteger) must appear before
69+
# their parent (Integer) so that isinstance() matches the most specific type.
70+
_DUCKDB_CAST: dict = {
71+
BigInteger: "BIGINT",
72+
SmallInteger: "SMALLINT",
73+
Integer: "INTEGER",
74+
Double: "DOUBLE",
75+
Boolean: "BOOLEAN",
76+
DateTime: "TIMESTAMPTZ", # DateTime(timezone=False) → TIMESTAMP below
77+
}
78+
79+
def _select_expr(self, key: str, col: Any) -> str:
80+
"""Return a DuckDB SELECT expression that casts a VARCHAR CSV column."""
81+
if isinstance(col.type, LargeBinary):
82+
return f'unhex("{key}")'
83+
for sa_type, duckdb_type in self._DUCKDB_CAST.items():
84+
if isinstance(col.type, sa_type):
85+
if isinstance(col.type, DateTime) and not col.type.timezone:
86+
duckdb_type = "TIMESTAMP"
87+
return f'CAST("{key}" AS {duckdb_type})'
88+
return f'"{key}"' # String / unknown: keep as VARCHAR
89+
90+
def bulk_insert(self, conn: Any, table: Any, records: list) -> None:
91+
"""Bulk-insert records via a temporary CSV file and DuckDB's ``read_csv``.
92+
93+
All CSV columns are read as VARCHAR (``all_varchar=true``) and then
94+
explicitly cast to their target types in the ``SELECT`` clause.
95+
Binary columns are hex-encoded in the CSV and decoded with ``unhex()``.
96+
97+
Args:
98+
conn: A SQLAlchemy ``Connection`` already within a transaction.
99+
table: The SQLAlchemy ``Table`` object to insert into.
100+
records: A list of dicts mapping column keys to Python values.
101+
"""
102+
if not records:
103+
return
104+
105+
# Map column key -> SQLAlchemy Column object
106+
col_by_key = {col.key: col for col in table.columns}
107+
108+
# Columns present in the first record that correspond to table columns
109+
col_keys = [k for k in records[0] if k in col_by_key]
110+
111+
# SQLAlchemy Python-side scalar defaults (e.g. default=False on temp_exists)
112+
# are applied automatically by executemany but not by our CSV path.
113+
extra_defaults: dict = {}
114+
for col in table.columns:
115+
if col.key not in records[0] and col.key in col_by_key:
116+
d = col.default
117+
if d is not None and d.is_scalar:
118+
extra_defaults[col.key] = d.arg
119+
120+
all_col_keys = col_keys + list(extra_defaults.keys())
121+
122+
fd, csv_path = tempfile.mkstemp(suffix=".csv")
123+
try:
124+
with os.fdopen(fd, "w", newline="", encoding="utf-8") as f:
125+
writer = csv.writer(f)
126+
writer.writerow(all_col_keys)
127+
for record in records:
128+
row = []
129+
for key in all_col_keys:
130+
v = record.get(key) if key in col_keys else extra_defaults[key]
131+
if v is None:
132+
row.append("")
133+
elif isinstance(v, bytes):
134+
row.append(v.hex())
135+
elif isinstance(v, bool):
136+
# Must come before the general str() path since bool is a
137+
# subclass of int, and csv.writer would write 0/1 otherwise.
138+
row.append("true" if v else "false")
139+
else:
140+
# str() on datetime gives "YYYY-MM-DD HH:MM:SS[.f][+HH:MM]",
141+
# which DuckDB's CAST accepts without ambiguity.
142+
row.append(str(v))
143+
writer.writerow(row)
144+
145+
full_name = (
146+
f'"{table.schema}"."{table.name}"'
147+
if table.schema
148+
else f'"{table.name}"'
149+
)
150+
insert_cols = ", ".join(
151+
f'"{col_by_key[k].name}"' for k in all_col_keys
152+
)
153+
select_exprs = ", ".join(
154+
self._select_expr(k, col_by_key[k]) for k in all_col_keys
155+
)
156+
# DuckDB requires forward slashes in file paths on all platforms.
157+
safe_path = csv_path.replace("\\", "/")
158+
sql = text(
159+
f"INSERT INTO {full_name} ({insert_cols}) "
160+
f"SELECT {select_exprs} "
161+
f"FROM read_csv('{safe_path}', header=true, nullstr='', all_varchar=true)"
162+
)
163+
conn.execute(sql)
164+
finally:
165+
if os.path.exists(csv_path):
166+
os.unlink(csv_path)

src/xml2db/document.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,11 @@ def insert_into_temp_tables(self, max_lines: int = -1) -> None:
393393
start_idx = 0
394394
while start_idx < len(data):
395395
with self.model.engine.begin() as conn:
396-
conn.execute(query, data[start_idx : (start_idx + max_lines)])
396+
self.model.dialect.bulk_insert(
397+
conn,
398+
query.table,
399+
data[start_idx : (start_idx + max_lines)],
400+
)
397401
start_idx = start_idx + max_lines
398402

399403
def merge_into_target_tables(self, single_transaction: bool = True) -> int:

tests/test_bulk_insert.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""Unit tests for dialect bulk_insert implementations."""
2+
import datetime
3+
4+
import pytest
5+
6+
pytest.importorskip("duckdb", reason="duckdb not installed")
7+
8+
from sqlalchemy import (
9+
BigInteger,
10+
Boolean,
11+
Column,
12+
DateTime,
13+
Double,
14+
Integer,
15+
LargeBinary,
16+
MetaData,
17+
SmallInteger,
18+
String,
19+
Table,
20+
create_engine,
21+
select,
22+
text,
23+
)
24+
25+
from xml2db.dialect.base import DatabaseDialect
26+
from xml2db.dialect.duckdb import DuckDBDialect
27+
28+
29+
@pytest.fixture()
30+
def duckdb_engine():
31+
return create_engine("duckdb:///:memory:")
32+
33+
34+
def _make_table(engine, name, *extra_cols):
35+
"""Create a simple test table and return the SQLAlchemy Table object."""
36+
meta = MetaData()
37+
table = Table(
38+
name,
39+
meta,
40+
Column("id", Integer, key="id"),
41+
Column("label", String(100), key="label"),
42+
*extra_cols,
43+
)
44+
meta.create_all(engine)
45+
return table
46+
47+
48+
def _roundtrip(engine, table, records):
49+
"""Insert records via DuckDBDialect.bulk_insert and read them back."""
50+
dialect = DuckDBDialect()
51+
with engine.begin() as conn:
52+
dialect.bulk_insert(conn, table, records)
53+
with engine.connect() as conn:
54+
return conn.execute(select(table)).mappings().all()
55+
56+
57+
# ---------------------------------------------------------------------------
58+
# Base dialect falls back to SQLAlchemy executemany
59+
# ---------------------------------------------------------------------------
60+
61+
62+
def test_base_dialect_bulk_insert(duckdb_engine):
63+
table = _make_table(duckdb_engine, "base_test")
64+
records = [{"id": 1, "label": "hello"}, {"id": 2, "label": "world"}]
65+
DatabaseDialect().bulk_insert(
66+
duckdb_engine.connect().__enter__(), table, records
67+
)
68+
# Just check the method is importable and has the right signature.
69+
70+
71+
# ---------------------------------------------------------------------------
72+
# DuckDB dialect: basic types
73+
# ---------------------------------------------------------------------------
74+
75+
76+
def test_duckdb_bulk_insert_basic(duckdb_engine):
77+
table = _make_table(duckdb_engine, "basic")
78+
records = [{"id": 1, "label": "hello"}, {"id": 2, "label": None}]
79+
rows = _roundtrip(duckdb_engine, table, records)
80+
assert len(rows) == 2
81+
assert rows[0]["id"] == 1
82+
assert rows[0]["label"] == "hello"
83+
assert rows[1]["label"] is None
84+
85+
86+
def test_duckdb_bulk_insert_numeric_types(duckdb_engine):
87+
meta = MetaData()
88+
table = Table(
89+
"numeric_types",
90+
meta,
91+
Column("i", Integer, key="i"),
92+
Column("bi", BigInteger, key="bi"),
93+
Column("si", SmallInteger, key="si"),
94+
Column("d", Double, key="d"),
95+
)
96+
meta.create_all(duckdb_engine)
97+
records = [{"i": 1, "bi": 10**15, "si": 32767, "d": 3.14}]
98+
rows = _roundtrip(duckdb_engine, table, records)
99+
assert rows[0]["i"] == 1
100+
assert rows[0]["bi"] == 10**15
101+
assert rows[0]["si"] == 32767
102+
assert abs(rows[0]["d"] - 3.14) < 1e-9
103+
104+
105+
def test_duckdb_bulk_insert_boolean(duckdb_engine):
106+
meta = MetaData()
107+
table = Table(
108+
"bool_test",
109+
meta,
110+
Column("id", Integer, key="id"),
111+
Column("flag", Boolean, key="flag"),
112+
)
113+
meta.create_all(duckdb_engine)
114+
records = [{"id": 1, "flag": True}, {"id": 2, "flag": False}, {"id": 3, "flag": None}]
115+
rows = _roundtrip(duckdb_engine, table, records)
116+
assert rows[0]["flag"] is True
117+
assert rows[1]["flag"] is False
118+
assert rows[2]["flag"] is None
119+
120+
121+
def test_duckdb_bulk_insert_datetime(duckdb_engine):
122+
meta = MetaData()
123+
table = Table(
124+
"dt_test",
125+
meta,
126+
Column("id", Integer, key="id"),
127+
Column("ts", DateTime(timezone=True), key="ts"),
128+
)
129+
meta.create_all(duckdb_engine)
130+
dt = datetime.datetime(2023, 9, 27, 14, 35, 54, 274602)
131+
records = [{"id": 1, "ts": dt}, {"id": 2, "ts": None}]
132+
rows = _roundtrip(duckdb_engine, table, records)
133+
# Value must survive the CSV round-trip and be returned as a datetime-like object.
134+
assert rows[0]["ts"] is not None
135+
assert rows[1]["ts"] is None
136+
137+
138+
def test_duckdb_bulk_insert_binary(duckdb_engine):
139+
meta = MetaData()
140+
table = Table(
141+
"binary_test",
142+
meta,
143+
Column("id", Integer, key="id"),
144+
Column("hash", LargeBinary(32), key="hash"),
145+
)
146+
meta.create_all(duckdb_engine)
147+
payload = b"\xde\xad\xbe\xef" * 8
148+
records = [{"id": 1, "hash": payload}, {"id": 2, "hash": None}]
149+
rows = _roundtrip(duckdb_engine, table, records)
150+
assert bytes(rows[0]["hash"]) == payload
151+
assert rows[1]["hash"] is None
152+
153+
154+
def test_duckdb_bulk_insert_scalar_column_default(duckdb_engine):
155+
"""Columns with Python-side scalar defaults absent from records must be applied."""
156+
meta = MetaData()
157+
table = Table(
158+
"default_test",
159+
meta,
160+
Column("id", Integer, key="id"),
161+
Column("flag", Boolean, default=False, key="flag"),
162+
)
163+
meta.create_all(duckdb_engine)
164+
# Records do NOT contain 'flag'; the default must be applied.
165+
records = [{"id": 1}, {"id": 2}]
166+
rows = _roundtrip(duckdb_engine, table, records)
167+
assert rows[0]["flag"] is False
168+
assert rows[1]["flag"] is False
169+
170+
171+
def test_duckdb_bulk_insert_empty(duckdb_engine):
172+
table = _make_table(duckdb_engine, "empty_test")
173+
dialect = DuckDBDialect()
174+
with engine.begin() if False else duckdb_engine.begin() as conn:
175+
dialect.bulk_insert(conn, table, [])
176+
with duckdb_engine.connect() as conn:
177+
count = conn.execute(text("SELECT COUNT(*) FROM empty_test")).scalar()
178+
assert count == 0

0 commit comments

Comments
 (0)