Skip to content

Commit 87bf681

Browse files
committed
Add copy_records_to_table for COPY FROM STDIN bulk-load
Closes #166. The existing binary_copy_to_table required callers to pre-encode PostgreSQL's binary COPY wire format, leaving no ergonomic bulk-load path comparable to asyncpg's copy_records_to_table or psycopg3's cursor.copy(...). The new method on Connection and Transaction accepts an iterable of records, introspects column types from the target table, and streams rows via tokio-postgres' BinaryCopyInWriter using the same PythonDTO conversions used by execute().
1 parent 1e56b25 commit 87bf681

3 files changed

Lines changed: 365 additions & 3 deletions

File tree

python/psqlpy/_internal/__init__.pyi

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,32 @@ class Transaction:
820820
number of inserted rows;
821821
"""
822822

823+
async def copy_records_to_table(
824+
self: Self,
825+
table_name: str,
826+
records: typing.Iterable[Sequence[Any]],
827+
columns: Sequence[str] | None = None,
828+
schema_name: str | None = None,
829+
) -> int:
830+
"""Copy records into a table using the binary COPY FROM STDIN protocol.
831+
832+
Column types are introspected from the target table, so each record
833+
may contain raw Python values (the same conversions used by
834+
`execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
835+
836+
### Parameters:
837+
- `table_name`: name of the table.
838+
- `records`: iterable of records (each a sequence of column values
839+
matching the order of `columns`, or of the table's columns when
840+
`columns` is `None`).
841+
- `columns`: sequence of column names to load into. When `None`,
842+
all columns of the table are used in their declared order.
843+
- `schema_name`: optional schema for `table_name`.
844+
845+
### Returns:
846+
number of inserted rows;
847+
"""
848+
823849
async def connect(
824850
dsn: str | None = None,
825851
username: str | None = None,
@@ -1146,6 +1172,32 @@ class Connection:
11461172
number of inserted rows;
11471173
"""
11481174

1175+
async def copy_records_to_table(
1176+
self: Self,
1177+
table_name: str,
1178+
records: typing.Iterable[Sequence[Any]],
1179+
columns: Sequence[str] | None = None,
1180+
schema_name: str | None = None,
1181+
) -> int:
1182+
"""Copy records into a table using the binary COPY FROM STDIN protocol.
1183+
1184+
Column types are introspected from the target table, so each record
1185+
may contain raw Python values (the same conversions used by
1186+
`execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
1187+
1188+
### Parameters:
1189+
- `table_name`: name of the table.
1190+
- `records`: iterable of records (each a sequence of column values
1191+
matching the order of `columns`, or of the table's columns when
1192+
`columns` is `None`).
1193+
- `columns`: sequence of column names to load into. When `None`,
1194+
all columns of the table are used in their declared order.
1195+
- `schema_name`: optional schema for `table_name`.
1196+
1197+
### Returns:
1198+
number of inserted rows;
1199+
"""
1200+
11491201
class ConnectionPoolStatus:
11501202
max_size: int
11511203
size: int

python/tests/test_copy_records.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import typing
2+
from datetime import datetime, timezone
3+
4+
import pytest
5+
from psqlpy import ConnectionPool
6+
from psqlpy.exceptions import PyToRustValueMappingError
7+
8+
pytestmark = pytest.mark.anyio
9+
10+
11+
async def _setup_target_table(psql_pool: ConnectionPool, name: str) -> None:
12+
async with psql_pool.acquire() as connection:
13+
await connection.execute(f"DROP TABLE IF EXISTS {name}")
14+
await connection.execute(
15+
f"""
16+
CREATE TABLE {name} (
17+
id INTEGER,
18+
label TEXT,
19+
weight DOUBLE PRECISION,
20+
created_at TIMESTAMPTZ
21+
)
22+
""",
23+
)
24+
25+
26+
async def _drop_target_table(psql_pool: ConnectionPool, name: str) -> None:
27+
async with psql_pool.acquire() as connection:
28+
await connection.execute(f"DROP TABLE IF EXISTS {name}")
29+
30+
31+
async def test_copy_records_to_table_on_connection(
32+
psql_pool: ConnectionPool,
33+
) -> None:
34+
target: typing.Final = "copy_records_conn"
35+
await _setup_target_table(psql_pool, target)
36+
try:
37+
records = [
38+
(1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)),
39+
(2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)),
40+
(3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)),
41+
]
42+
43+
async with psql_pool.acquire() as connection:
44+
inserted = await connection.copy_records_to_table(
45+
table_name=target,
46+
records=records,
47+
)
48+
49+
assert inserted == len(records)
50+
51+
async with psql_pool.acquire() as connection:
52+
result = await connection.execute(
53+
f"SELECT id, label, weight FROM {target} ORDER BY id",
54+
)
55+
rows = result.result()
56+
assert [(r["id"], r["label"], r["weight"]) for r in rows] == [
57+
(1, "alpha", 1.5),
58+
(2, "beta", 2.25),
59+
(3, "gamma", None),
60+
]
61+
finally:
62+
await _drop_target_table(psql_pool, target)
63+
64+
65+
async def test_copy_records_to_table_with_columns_subset(
66+
psql_pool: ConnectionPool,
67+
) -> None:
68+
target: typing.Final = "copy_records_subset"
69+
await _setup_target_table(psql_pool, target)
70+
try:
71+
records = [(10, "only-label"), (11, "another")]
72+
73+
async with psql_pool.acquire() as connection:
74+
inserted = await connection.copy_records_to_table(
75+
table_name=target,
76+
records=records,
77+
columns=["id", "label"],
78+
)
79+
80+
assert inserted == len(records)
81+
82+
async with psql_pool.acquire() as connection:
83+
result = await connection.execute(
84+
f"SELECT id, label, weight, created_at FROM {target} ORDER BY id",
85+
)
86+
rows = result.result()
87+
assert [(r["id"], r["label"]) for r in rows] == [
88+
(10, "only-label"),
89+
(11, "another"),
90+
]
91+
# Untouched columns must remain NULL
92+
assert all(r["weight"] is None and r["created_at"] is None for r in rows)
93+
finally:
94+
await _drop_target_table(psql_pool, target)
95+
96+
97+
async def test_copy_records_to_table_in_transaction(
98+
psql_pool: ConnectionPool,
99+
) -> None:
100+
target: typing.Final = "copy_records_tx"
101+
await _setup_target_table(psql_pool, target)
102+
try:
103+
records = [(100, "tx-row", 0.0, datetime(2026, 5, 1, tzinfo=timezone.utc))]
104+
105+
async with (
106+
psql_pool.acquire() as connection,
107+
connection.transaction() as tx,
108+
):
109+
inserted = await tx.copy_records_to_table(
110+
table_name=target,
111+
records=records,
112+
)
113+
114+
assert inserted == 1
115+
116+
async with psql_pool.acquire() as connection:
117+
result = await connection.execute(
118+
f"SELECT COUNT(*) AS c FROM {target}",
119+
)
120+
assert result.result()[0]["c"] == 1
121+
finally:
122+
await _drop_target_table(psql_pool, target)
123+
124+
125+
async def test_copy_records_to_table_rejects_record_arity_mismatch(
126+
psql_pool: ConnectionPool,
127+
) -> None:
128+
target: typing.Final = "copy_records_mismatch"
129+
await _setup_target_table(psql_pool, target)
130+
try:
131+
records = [(1, "missing-rest")] # table has 4 columns
132+
133+
async with psql_pool.acquire() as connection:
134+
with pytest.raises(PyToRustValueMappingError):
135+
await connection.copy_records_to_table(
136+
table_name=target,
137+
records=records,
138+
)
139+
finally:
140+
await _drop_target_table(psql_pool, target)
141+
142+
143+
async def test_copy_records_to_table_uses_schema_qualifier(
144+
psql_pool: ConnectionPool,
145+
) -> None:
146+
schema: typing.Final = "copy_records_schema"
147+
target: typing.Final = "tbl"
148+
149+
async with psql_pool.acquire() as connection:
150+
await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE")
151+
await connection.execute(f"CREATE SCHEMA {schema}")
152+
await connection.execute(
153+
f"CREATE TABLE {schema}.{target} (id INTEGER, label TEXT)",
154+
)
155+
156+
try:
157+
records = [(1, "schema-a"), (2, "schema-b")]
158+
async with psql_pool.acquire() as connection:
159+
inserted = await connection.copy_records_to_table(
160+
table_name=target,
161+
records=records,
162+
schema_name=schema,
163+
)
164+
165+
assert inserted == len(records)
166+
167+
async with psql_pool.acquire() as connection:
168+
result = await connection.execute(
169+
f"SELECT id, label FROM {schema}.{target} ORDER BY id",
170+
)
171+
assert [(r["id"], r["label"]) for r in result.result()] == records
172+
finally:
173+
async with psql_pool.acquire() as connection:
174+
await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE")

src/driver/common.rs

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ use super::{
1010
use pyo3::{pymethods, Py, PyAny};
1111

1212
use crate::{
13-
connection::traits::CloseTransaction,
13+
connection::traits::{CloseTransaction, Connection as _},
1414
exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError},
15+
value_converter::{dto::enums::PythonDTO, from_python::from_python_typed},
1516
};
1617

1718
use bytes::BytesMut;
1819
use futures_util::pin_mut;
19-
use pyo3::{buffer::PyBuffer, Python};
20-
use tokio_postgres::binary_copy::BinaryCopyInWriter;
20+
use pyo3::{buffer::PyBuffer, types::PyAnyMethods, Python};
21+
use tokio_postgres::{binary_copy::BinaryCopyInWriter, types::ToSql};
2122

2223
use crate::format_helpers::quote_ident;
2324

@@ -320,3 +321,138 @@ macro_rules! impl_binary_copy_method {
320321

321322
impl_binary_copy_method!(Connection);
322323
impl_binary_copy_method!(Transaction);
324+
325+
macro_rules! impl_copy_records_method {
326+
($name:ident) => {
327+
#[pymethods]
328+
impl $name {
329+
/// Copy a list of records into a table using the COPY FROM STDIN
330+
/// binary protocol.
331+
///
332+
/// Column types are introspected from the target table, so callers
333+
/// pass Python values directly (the same conversions used by
334+
/// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
335+
///
336+
/// # Errors
337+
/// May return error if there is some problem with DB communication,
338+
/// the table cannot be introspected, or a value cannot be converted.
339+
#[pyo3(signature = (table_name, records, columns=None, schema_name=None))]
340+
pub async fn copy_records_to_table(
341+
self_: pyo3::Py<Self>,
342+
table_name: String,
343+
records: Py<PyAny>,
344+
columns: Option<Vec<String>>,
345+
schema_name: Option<String>,
346+
) -> PSQLPyResult<u64> {
347+
let (db_client, raw_records) = Python::with_gil(
348+
|gil| -> PSQLPyResult<(Option<_>, Vec<Vec<pyo3::Py<PyAny>>>)> {
349+
let db_client = self_.borrow(gil).conn.clone();
350+
351+
let Some(db_client) = db_client else {
352+
return Ok((None, Vec::new()));
353+
};
354+
355+
let bound = records.bind(gil);
356+
let mut rows: Vec<Vec<pyo3::Py<PyAny>>> = Vec::new();
357+
for item in bound.try_iter()? {
358+
let row = item?;
359+
let mut row_vec: Vec<pyo3::Py<PyAny>> = Vec::new();
360+
for cell in row.try_iter()? {
361+
row_vec.push(cell?.unbind());
362+
}
363+
rows.push(row_vec);
364+
}
365+
366+
Ok((Some(db_client), rows))
367+
},
368+
)?;
369+
370+
let Some(db_client) = db_client else {
371+
return Ok(0);
372+
};
373+
374+
let full_table_name = match schema_name {
375+
Some(ref schema) => {
376+
format!("{}.{}", quote_ident(schema), quote_ident(&table_name))
377+
}
378+
None => quote_ident(&table_name),
379+
};
380+
381+
let columns_sql = match columns {
382+
Some(ref cols) if !cols.is_empty() => Some(
383+
cols.iter()
384+
.map(|c| quote_ident(c))
385+
.collect::<Vec<_>>()
386+
.join(", "),
387+
),
388+
_ => None,
389+
};
390+
391+
let introspect_qs = match &columns_sql {
392+
Some(cols) => format!("SELECT {} FROM {} WHERE false", cols, full_table_name),
393+
None => format!("SELECT * FROM {} WHERE false", full_table_name),
394+
};
395+
396+
let read_conn_g = db_client.read().await;
397+
398+
let stmt = read_conn_g.prepare(&introspect_qs, false).await?;
399+
let column_types: Vec<tokio_postgres::types::Type> =
400+
stmt.columns().iter().map(|c| c.type_().clone()).collect();
401+
402+
if column_types.is_empty() {
403+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
404+
"Cannot introspect column types from target table".into(),
405+
));
406+
}
407+
408+
let typed_rows: Vec<Vec<PythonDTO>> =
409+
Python::with_gil(|gil| -> PSQLPyResult<Vec<Vec<PythonDTO>>> {
410+
let mut typed: Vec<Vec<PythonDTO>> = Vec::with_capacity(raw_records.len());
411+
for (row_idx, row) in raw_records.iter().enumerate() {
412+
if row.len() != column_types.len() {
413+
return Err(RustPSQLDriverError::PyToRustValueConversionError(
414+
format!(
415+
"Record at index {} has {} fields, expected {}",
416+
row_idx,
417+
row.len(),
418+
column_types.len()
419+
),
420+
));
421+
}
422+
let mut row_dto: Vec<PythonDTO> = Vec::with_capacity(row.len());
423+
for (cell, ty) in row.iter().zip(column_types.iter()) {
424+
row_dto.push(from_python_typed(cell.bind(gil), ty)?);
425+
}
426+
typed.push(row_dto);
427+
}
428+
Ok(typed)
429+
})?;
430+
431+
let copy_qs = match &columns_sql {
432+
Some(cols) => format!(
433+
"COPY {}({}) FROM STDIN (FORMAT binary)",
434+
full_table_name, cols
435+
),
436+
None => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
437+
};
438+
439+
let sink = read_conn_g.copy_in(&copy_qs).await?;
440+
let writer = BinaryCopyInWriter::new(sink, &column_types);
441+
pin_mut!(writer);
442+
443+
for row in &typed_rows {
444+
let row_refs: Vec<&(dyn ToSql + Sync)> =
445+
row.iter().map(|v| v as &(dyn ToSql + Sync)).collect();
446+
writer.as_mut().write(&row_refs).await?;
447+
}
448+
449+
let rows_created = writer.as_mut().finish().await?;
450+
451+
Ok(rows_created)
452+
}
453+
}
454+
};
455+
}
456+
457+
impl_copy_records_method!(Connection);
458+
impl_copy_records_method!(Transaction);

0 commit comments

Comments
 (0)