Skip to content

Commit 7794063

Browse files
committed
Add copy_records_to_table for COPY FROM STDIN bulk-load
Closes #166. The existing binary_copy_to_table required callers to pre-encode PostgreSQL's binary COPY wire format, leaving no ergonomic bulk-load path comparable to asyncpg's copy_records_to_table or psycopg3's cursor.copy(...). The new method on Connection and Transaction accepts an iterable of records, introspects column types from the target table, and streams rows via tokio-postgres' BinaryCopyInWriter using the same PythonDTO conversions used by execute().
1 parent 91d611a commit 7794063

3 files changed

Lines changed: 363 additions & 3 deletions

File tree

python/psqlpy/_internal/__init__.pyi

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,32 @@ class Transaction:
820820
number of inserted rows;
821821
"""
822822

823+
async def copy_records_to_table(
    self: Self,
    table_name: str,
    records: typing.Iterable[Sequence[Any]],
    columns: Sequence[str] | None = None,
    schema_name: str | None = None,
) -> int:
    """Bulk-load `records` into a table over binary COPY FROM STDIN.

    The driver looks up the column types of the target table itself, so
    records carry plain Python values that go through the same
    conversions as `execute` parameters. The API follows
    `asyncpg.Connection.copy_records_to_table`.

    ### Parameters:
    - `table_name`: name of the target table.
    - `records`: iterable of records; each record is a sequence holding
        one value per target column, ordered like `columns` (or like the
        table's own columns when `columns` is `None`).
    - `columns`: names of the columns to fill. With `None`, every column
        of the table is targeted in its declared order.
    - `schema_name`: schema that contains `table_name`, if any.

    ### Returns:
    number of inserted rows;
    """
848+
823849
async def connect(
824850
dsn: str | None = None,
825851
username: str | None = None,
@@ -1146,6 +1172,32 @@ class Connection:
11461172
number of inserted rows;
11471173
"""
11481174

1175+
async def copy_records_to_table(
    self: Self,
    table_name: str,
    records: typing.Iterable[Sequence[Any]],
    columns: Sequence[str] | None = None,
    schema_name: str | None = None,
) -> int:
    """Bulk-load `records` into a table over binary COPY FROM STDIN.

    The driver looks up the column types of the target table itself, so
    records carry plain Python values that go through the same
    conversions as `execute` parameters. The API follows
    `asyncpg.Connection.copy_records_to_table`.

    ### Parameters:
    - `table_name`: name of the target table.
    - `records`: iterable of records; each record is a sequence holding
        one value per target column, ordered like `columns` (or like the
        table's own columns when `columns` is `None`).
    - `columns`: names of the columns to fill. With `None`, every column
        of the table is targeted in its declared order.
    - `schema_name`: schema that contains `table_name`, if any.

    ### Returns:
    number of inserted rows;
    """
1200+
11491201
class ConnectionPoolStatus:
11501202
max_size: int
11511203
size: int

python/tests/test_copy_records.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import typing
2+
from datetime import datetime, timezone
3+
4+
import pytest
5+
from psqlpy import ConnectionPool
6+
7+
# Module-level pytest marker: applies the `anyio` mark to the tests in this file.
pytestmark = pytest.mark.anyio
8+
9+
10+
async def _setup_target_table(psql_pool: ConnectionPool, name: str) -> None:
    """Recreate the four-column scratch table `name` from a clean slate.

    Any existing table with that name is dropped first; the new table has
    `id`, `label`, `weight` and `created_at` columns.
    """
    drop_stmt = f"DROP TABLE IF EXISTS {name}"
    # The DDL text is kept exactly as it appears in the reviewed source.
    create_stmt = f"""
CREATE TABLE {name} (
id INTEGER,
label TEXT,
weight DOUBLE PRECISION,
created_at TIMESTAMPTZ
)
"""
    async with psql_pool.acquire() as conn:
        await conn.execute(drop_stmt)
        await conn.execute(create_stmt)
23+
24+
25+
async def _drop_target_table(psql_pool: ConnectionPool, name: str) -> None:
    """Remove the scratch table `name`; a missing table is not an error."""
    cleanup_stmt = f"DROP TABLE IF EXISTS {name}"
    async with psql_pool.acquire() as conn:
        await conn.execute(cleanup_stmt)
28+
29+
30+
async def test_copy_records_to_table_on_connection(
    psql_pool: ConnectionPool,
) -> None:
    """Load full-width records through a connection and read them back."""
    target: typing.Final = "copy_records_conn"
    await _setup_target_table(psql_pool, target)
    try:
        # One value per table column; the last record carries a NULL weight.
        records = [
            (1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)),
            (2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)),
            (3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)),
        ]

        async with psql_pool.acquire() as writer_conn:
            inserted = await writer_conn.copy_records_to_table(
                table_name=target,
                records=records,
            )

        assert inserted == len(records)

        async with psql_pool.acquire() as reader_conn:
            query_result = await reader_conn.execute(
                f"SELECT id, label, weight FROM {target} ORDER BY id",
            )
            fetched = [
                (row["id"], row["label"], row["weight"])
                for row in query_result.result()
            ]
            assert fetched == [
                (1, "alpha", 1.5),
                (2, "beta", 2.25),
                (3, "gamma", None),
            ]
    finally:
        await _drop_target_table(psql_pool, target)
62+
63+
64+
async def test_copy_records_to_table_with_columns_subset(
    psql_pool: ConnectionPool,
) -> None:
    """Load only `id` and `label`; the remaining columns must stay NULL."""
    target: typing.Final = "copy_records_subset"
    await _setup_target_table(psql_pool, target)
    try:
        records = [(10, "only-label"), (11, "another")]

        async with psql_pool.acquire() as writer_conn:
            inserted = await writer_conn.copy_records_to_table(
                table_name=target,
                records=records,
                columns=["id", "label"],
            )

        assert inserted == len(records)

        async with psql_pool.acquire() as reader_conn:
            query_result = await reader_conn.execute(
                f"SELECT id, label, weight, created_at FROM {target} ORDER BY id",
            )
            fetched = query_result.result()
            loaded_pairs = [(row["id"], row["label"]) for row in fetched]
            assert loaded_pairs == [
                (10, "only-label"),
                (11, "another"),
            ]
            # Columns left out of the COPY column list must not be populated.
            for row in fetched:
                assert row["weight"] is None
                assert row["created_at"] is None
    finally:
        await _drop_target_table(psql_pool, target)
94+
95+
96+
async def test_copy_records_to_table_in_transaction(
    psql_pool: ConnectionPool,
) -> None:
    """Load one record via `Transaction.copy_records_to_table`.

    The row count is checked on a freshly acquired connection after the
    transaction block has exited.
    """
    target: typing.Final = "copy_records_tx"
    await _setup_target_table(psql_pool, target)
    try:
        records = [(100, "tx-row", 0.0, datetime(2026, 5, 1, tzinfo=timezone.utc))]

        async with psql_pool.acquire() as conn, conn.transaction() as tx:
            inserted = await tx.copy_records_to_table(
                table_name=target,
                records=records,
            )

        assert inserted == 1

        async with psql_pool.acquire() as reader_conn:
            count_result = await reader_conn.execute(
                f"SELECT COUNT(*) AS c FROM {target}",
            )
            assert count_result.result()[0]["c"] == 1
    finally:
        await _drop_target_table(psql_pool, target)
120+
121+
122+
async def test_copy_records_to_table_rejects_record_arity_mismatch(
    psql_pool: ConnectionPool,
) -> None:
    """A record with too few fields must be rejected and write nothing.

    The table has four columns and no `columns` argument is given, so a
    two-field record does not match the target column count.
    """
    target: typing.Final = "copy_records_mismatch"
    await _setup_target_table(psql_pool, target)
    try:
        records = [(1, "missing-rest")]  # table has 4 columns

        async with psql_pool.acquire() as connection:
            with pytest.raises(Exception):  # noqa: PT011
                await connection.copy_records_to_table(
                    table_name=target,
                    records=records,
                )

        # Raising is not enough: a rejected load must not leave rows behind.
        async with psql_pool.acquire() as connection:
            result = await connection.execute(
                f"SELECT COUNT(*) AS c FROM {target}",
            )
            assert result.result()[0]["c"] == 0
    finally:
        await _drop_target_table(psql_pool, target)
138+
139+
140+
async def test_copy_records_to_table_uses_schema_qualifier(
    psql_pool: ConnectionPool,
) -> None:
    """Load into a table addressed through the `schema_name` argument."""
    schema: typing.Final = "copy_records_schema"
    target: typing.Final = "tbl"
    drop_schema_stmt = f"DROP SCHEMA IF EXISTS {schema} CASCADE"

    # Start from a fresh schema holding a two-column table.
    async with psql_pool.acquire() as conn:
        await conn.execute(drop_schema_stmt)
        await conn.execute(f"CREATE SCHEMA {schema}")
        await conn.execute(
            f"CREATE TABLE {schema}.{target} (id INTEGER, label TEXT)",
        )

    try:
        records = [(1, "schema-a"), (2, "schema-b")]
        async with psql_pool.acquire() as writer_conn:
            inserted = await writer_conn.copy_records_to_table(
                table_name=target,
                records=records,
                schema_name=schema,
            )

        assert inserted == 2

        async with psql_pool.acquire() as reader_conn:
            query_result = await reader_conn.execute(
                f"SELECT id, label FROM {schema}.{target} ORDER BY id",
            )
            fetched = [(row["id"], row["label"]) for row in query_result.result()]
            assert fetched == records
    finally:
        async with psql_pool.acquire() as conn:
            await conn.execute(drop_schema_stmt)

src/driver/common.rs

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ use super::{
1010
use pyo3::{pymethods, Py, PyAny};
1111

1212
use crate::{
13-
connection::traits::CloseTransaction,
13+
connection::traits::{CloseTransaction, Connection as _},
1414
exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError},
15+
value_converter::{dto::enums::PythonDTO, from_python::from_python_typed},
1516
};
1617

1718
use bytes::BytesMut;
1819
use futures_util::pin_mut;
19-
use pyo3::{buffer::PyBuffer, Python};
20-
use tokio_postgres::binary_copy::BinaryCopyInWriter;
20+
use pyo3::{buffer::PyBuffer, types::PyAnyMethods, Python};
21+
use tokio_postgres::{binary_copy::BinaryCopyInWriter, types::ToSql};
2122

2223
use crate::format_helpers::quote_ident;
2324

@@ -320,3 +321,139 @@ macro_rules! impl_binary_copy_method {
320321

321322
impl_binary_copy_method!(Connection);
322323
impl_binary_copy_method!(Transaction);
324+
325+
macro_rules! impl_copy_records_method {
326+
($name:ident) => {
327+
#[pymethods]
328+
impl $name {
329+
/// Copy a list of records into a table using the COPY FROM STDIN
330+
/// binary protocol.
331+
///
332+
/// Column types are introspected from the target table, so callers
333+
/// pass Python values directly (the same conversions used by
334+
/// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
335+
///
336+
/// # Errors
337+
/// May return error if there is some problem with DB communication,
338+
/// the table cannot be introspected, or a value cannot be converted.
339+
#[pyo3(signature = (table_name, records, columns=None, schema_name=None))]
340+
pub async fn copy_records_to_table(
341+
self_: pyo3::Py<Self>,
342+
table_name: String,
343+
records: Py<PyAny>,
344+
columns: Option<Vec<String>>,
345+
schema_name: Option<String>,
346+
) -> PSQLPyResult<u64> {
347+
let (db_client, raw_records) = Python::with_gil(
348+
|gil| -> PSQLPyResult<(Option<_>, Vec<Vec<pyo3::Py<PyAny>>>)> {
349+
let db_client = self_.borrow(gil).conn.clone();
350+
351+
let Some(db_client) = db_client else {
352+
return Ok((None, Vec::new()));
353+
};
354+
355+
let bound = records.bind(gil);
356+
let mut rows: Vec<Vec<pyo3::Py<PyAny>>> = Vec::new();
357+
for item in bound.try_iter()? {
358+
let row = item?;
359+
let mut row_vec: Vec<pyo3::Py<PyAny>> = Vec::new();
360+
for cell in row.try_iter()? {
361+
row_vec.push(cell?.unbind());
362+
}
363+
rows.push(row_vec);
364+
}
365+
366+
Ok((Some(db_client), rows))
367+
},
368+
)?;
369+
370+
let Some(db_client) = db_client else {
371+
return Ok(0);
372+
};
373+
374+
let full_table_name = match schema_name {
375+
Some(ref schema) => {
376+
format!("{}.{}", quote_ident(schema), quote_ident(&table_name))
377+
}
378+
None => quote_ident(&table_name),
379+
};
380+
381+
let columns_sql = match columns {
382+
Some(ref cols) if !cols.is_empty() => Some(
383+
cols.iter()
384+
.map(|c| quote_ident(c))
385+
.collect::<Vec<_>>()
386+
.join(", "),
387+
),
388+
_ => None,
389+
};
390+
391+
let introspect_qs = match &columns_sql {
392+
Some(cols) => format!("SELECT {} FROM {} WHERE false", cols, full_table_name),
393+
None => format!("SELECT * FROM {} WHERE false", full_table_name),
394+
};
395+
396+
let read_conn_g = db_client.read().await;
397+
398+
let stmt = read_conn_g.prepare(&introspect_qs, false).await?;
399+
let column_types: Vec<tokio_postgres::types::Type> =
400+
stmt.columns().iter().map(|c| c.type_().clone()).collect();
401+
402+
if column_types.is_empty() {
403+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
404+
"Cannot introspect column types from target table".into(),
405+
));
406+
}
407+
408+
let typed_rows: Vec<Vec<PythonDTO>> =
409+
Python::with_gil(|gil| -> PSQLPyResult<Vec<Vec<PythonDTO>>> {
410+
let mut typed: Vec<Vec<PythonDTO>> =
411+
Vec::with_capacity(raw_records.len());
412+
for (row_idx, row) in raw_records.iter().enumerate() {
413+
if row.len() != column_types.len() {
414+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
415+
format!(
416+
"Record at index {} has {} fields, expected {}",
417+
row_idx,
418+
row.len(),
419+
column_types.len()
420+
),
421+
));
422+
}
423+
let mut row_dto: Vec<PythonDTO> = Vec::with_capacity(row.len());
424+
for (cell, ty) in row.iter().zip(column_types.iter()) {
425+
row_dto.push(from_python_typed(cell.bind(gil), ty)?);
426+
}
427+
typed.push(row_dto);
428+
}
429+
Ok(typed)
430+
})?;
431+
432+
let copy_qs = match &columns_sql {
433+
Some(cols) => format!(
434+
"COPY {}({}) FROM STDIN (FORMAT binary)",
435+
full_table_name, cols
436+
),
437+
None => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
438+
};
439+
440+
let sink = read_conn_g.copy_in(&copy_qs).await?;
441+
let writer = BinaryCopyInWriter::new(sink, &column_types);
442+
pin_mut!(writer);
443+
444+
for row in &typed_rows {
445+
let row_refs: Vec<&(dyn ToSql + Sync)> =
446+
row.iter().map(|v| v as &(dyn ToSql + Sync)).collect();
447+
writer.as_mut().write(&row_refs).await?;
448+
}
449+
450+
let rows_created = writer.as_mut().finish().await?;
451+
452+
Ok(rows_created)
453+
}
454+
}
455+
};
456+
}
457+
458+
impl_copy_records_method!(Connection);
459+
impl_copy_records_method!(Transaction);

0 commit comments

Comments
 (0)