Skip to content

Commit 3b8746f

Browse files
committed
implement at the cpp level
1 parent 0ff6515 commit 3b8746f

7 files changed

Lines changed: 115 additions & 26 deletions

File tree

_duckdb-stubs/__init__.pyi

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ if typing.TYPE_CHECKING:
1313
import pandas
1414
import pyarrow.lib
1515
from collections.abc import Callable, Iterable, Sequence, Mapping
16-
from duckdb import sqltypes, func
16+
from duckdb import sqltypes, func, template
1717
from builtins import list as lst # needed to avoid mypy error on DuckDBPyRelation.list method shadowing
1818

1919
# the field_ids argument to to_parquet and write_parquet has a recursive structure
@@ -241,8 +241,12 @@ class DuckDBPyConnection:
241241
def dtype(self, type_str: str) -> sqltypes.DuckDBPyType: ...
242242
def duplicate(self) -> DuckDBPyConnection: ...
243243
def enum_type(self, name: str, type: sqltypes.DuckDBPyType, values: lst[typing.Any]) -> sqltypes.DuckDBPyType: ...
244-
def execute(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ...
245-
def executemany(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ...
244+
def execute(
245+
self, query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None
246+
) -> DuckDBPyConnection: ...
247+
def executemany(
248+
self, query: Statement | str | template.SqlTemplate | template.CompiledSql, parameters: object = None
249+
) -> DuckDBPyConnection: ...
246250
def extract_statements(self, query: str) -> lst[Statement]: ...
247251
def fetch_arrow_table(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.Table:
248252
"""Deprecated: use to_arrow_table() instead."""
@@ -331,7 +335,9 @@ class DuckDBPyConnection:
331335
union_by_name: bool = False,
332336
compression: str | None = None,
333337
) -> DuckDBPyRelation: ...
334-
def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ...
338+
def from_query(
339+
self, query: str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None
340+
) -> DuckDBPyRelation: ...
335341
def get_table_names(self, query: str, *, qualified: bool = False) -> set[str]: ...
336342
def install_extension(
337343
self,
@@ -360,7 +366,9 @@ class DuckDBPyConnection:
360366
def pl(
361367
self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: bool = False
362368
) -> polars.DataFrame | polars.LazyFrame: ...
363-
def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ...
369+
def query(
370+
self, query: str | template.SqlTemplate | template.CompiledSql, *, alias: str = "", params: object = None
371+
) -> DuckDBPyRelation: ...
364372
def query_progress(self) -> float: ...
365373
def read_csv(
366374
self,
@@ -462,7 +470,13 @@ class DuckDBPyConnection:
462470
def row_type(
463471
self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType]
464472
) -> sqltypes.DuckDBPyType: ...
465-
def sql(self, query: Statement | str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ...
473+
def sql(
474+
self,
475+
query: Statement | str | template.SqlTemplate | template.CompiledSql,
476+
*,
477+
alias: str = "",
478+
params: object = None,
479+
) -> DuckDBPyRelation: ...
466480
def sqltype(self, type_str: str) -> sqltypes.DuckDBPyType: ...
467481
def string_type(self, collation: str = "") -> sqltypes.DuckDBPyType: ...
468482
def struct_type(
@@ -1160,7 +1174,7 @@ def enum_type(
11601174
connection: DuckDBPyConnection | None = None,
11611175
) -> sqltypes.DuckDBPyType: ...
11621176
def execute(
1163-
query: Statement | str,
1177+
query: Statement | str | template.SqlTemplate | template.CompiledSql,
11641178
parameters: object = None,
11651179
*,
11661180
connection: DuckDBPyConnection | None = None,
@@ -1282,7 +1296,7 @@ def from_parquet(
12821296
connection: DuckDBPyConnection | None = None,
12831297
) -> DuckDBPyRelation: ...
12841298
def from_query(
1285-
query: Statement | str,
1299+
query: Statement | str | template.SqlTemplate | template.CompiledSql,
12861300
*,
12871301
alias: str = "",
12881302
params: object = None,
@@ -1350,7 +1364,7 @@ def project(
13501364
df: pandas.DataFrame, *args: _ExpressionLike, groups: str = "", connection: DuckDBPyConnection | None = None
13511365
) -> DuckDBPyRelation: ...
13521366
def query(
1353-
query: Statement | str,
1367+
query: Statement | str | template.SqlTemplate | template.CompiledSql,
13541368
*,
13551369
alias: str = "",
13561370
params: object = None,
@@ -1474,7 +1488,7 @@ def row_type(
14741488
def rowcount(*, connection: DuckDBPyConnection | None = None) -> int: ...
14751489
def set_default_connection(connection: DuckDBPyConnection) -> None: ...
14761490
def sql(
1477-
query: Statement | str,
1491+
query: Statement | str | template.SqlTemplate | template.CompiledSql,
14781492
*,
14791493
alias: str = "",
14801494
params: object = None,

src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
363363
FunctionNullHandling null_handling, PythonExceptionHandling exception_handling,
364364
bool side_effects);
365365
void RegisterArrowObject(const py::object &arrow_object, const string &name);
366+
pair<string, py::object> ExtractCompiledSqlAndParams(const py::object &query, py::object params);
366367
vector<unique_ptr<SQLStatement>> GetStatements(const py::object &query);
367368

368369
static PythonEnvironmentType environment;

src/duckdb_py/pyconnection.cpp

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -708,11 +708,74 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::ExecuteFromString(const strin
708708
return Execute(py::str(query));
709709
}
710710

711+
// Normalizes a query object that is either an already-compiled SQL object (exposes both `sql` and
// `params` attributes) or something that can produce one via a `compile()` method.
//
// Returns a {sql, params} pair:
//   - {empty string, params}: `query` is not template-like (no `sql`/`params` and no `compile`), or
//     the object returned by `compile()` still lacks `sql`/`params`. The caller's params are passed
//     through untouched; Execute and RunQuery check for the empty string and fall back to `query`.
//   - {compiled SQL, parameters to bind}: the compiled params, the caller's params, or a merge of
//     both, depending on which of them are present (see the branches below).
//
// Throws InvalidInputException when the compiled params are not dict-like, or when the caller's
// params are neither None, dict-like nor list-like. Throws py::value_error when a caller-supplied
// name duplicates a compiled parameter name, or when a non-empty positional list is supplied
// alongside compiled (named) parameters.
pair<string, py::object> DuckDBPyConnection::ExtractCompiledSqlAndParams(const py::object &query, py::object params) {
712+
py::object compiled = query;
713+
// An object that already has both `sql` and `params` is used as-is; anything else must offer compile().
if (!py::hasattr(compiled, "sql") || !py::hasattr(compiled, "params")) {
714+
if (!py::hasattr(query, "compile")) {
715+
// Not template-like: signal that with an empty SQL string and hand the params back unchanged.
return {string(), params};
716+
}
717+
compiled = query.attr("compile")();
718+
}
719+
720+
// Re-check after compile(): its result is only usable if it exposes both attributes.
if (!py::hasattr(compiled, "sql") || !py::hasattr(compiled, "params")) {
721+
return {string(), params};
722+
}
723+
724+
auto compiled_sql = py::cast<string>(compiled.attr("sql"));
725+
auto compiled_params_obj = compiled.attr("params");
726+
if (!py::is_dict_like(compiled_params_obj)) {
727+
throw InvalidInputException("Compiled SQL parameters must be a dictionary");
728+
}
729+
730+
auto compiled_params = py::cast<py::dict>(compiled_params_obj);
731+
// The compiled object binds nothing: the caller's params (possibly None) are used unchanged.
if (compiled_params.empty()) {
732+
return {compiled_sql, params};
733+
}
734+
735+
// No caller params: bind only the compiled ones.
if (params.is_none()) {
736+
return {compiled_sql, compiled_params};
737+
}
738+
739+
// Both sides are keyed by name: merge them, rejecting any key that appears on both sides.
if (py::is_dict_like(params)) {
740+
auto merged_params = py::dict();
741+
for (auto &item : compiled_params) {
742+
merged_params[item.first] = item.second;
743+
}
744+
auto provided_params = py::cast<py::dict>(params);
745+
for (auto &item : provided_params) {
746+
if (merged_params.contains(item.first)) {
747+
throw py::value_error("Cannot merge compiled SQL parameters with duplicate parameter names");
748+
}
749+
merged_params[item.first] = item.second;
750+
}
751+
return {compiled_sql, merged_params};
752+
}
753+
754+
// An empty positional list is treated like "no params"; a non-empty one cannot be combined
// with the compiled (named) parameters.
if (py::is_list_like(params)) {
755+
if (py::len(params) == 0) {
756+
return {compiled_sql, compiled_params};
757+
}
758+
throw py::value_error("Cannot merge compiled SQL named parameters with positional parameters");
759+
}
760+
761+
// params is neither None, dict-like nor list-like.
throw InvalidInputException("Prepared parameters can only be passed as a list or a dictionary");
762+
}
763+
711764
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Execute(const py::object &query, py::object params) {
712765
py::gil_scoped_acquire gil;
713766
con.SetResult(nullptr);
714767

715-
auto statements = GetStatements(query);
768+
auto normalized_query = ExtractCompiledSqlAndParams(query, params);
769+
auto &compiled_sql = normalized_query.first;
770+
auto &merged_params = normalized_query.second;
771+
vector<unique_ptr<SQLStatement>> statements;
772+
if (!compiled_sql.empty()) {
773+
statements = GetStatements(py::str(compiled_sql));
774+
params = merged_params;
775+
} else {
776+
statements = GetStatements(query);
777+
}
778+
716779
if (statements.empty()) {
717780
// TODO: should we throw?
718781
return nullptr;
@@ -1603,7 +1666,17 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::RunQuery(const py::object &quer
16031666
alias = "unnamed_relation_" + StringUtil::GenerateRandomName(16);
16041667
}
16051668

1606-
auto statements = GetStatements(query);
1669+
auto normalized_query = ExtractCompiledSqlAndParams(query, params);
1670+
auto &compiled_sql = normalized_query.first;
1671+
auto &merged_params = normalized_query.second;
1672+
vector<unique_ptr<SQLStatement>> statements;
1673+
if (!compiled_sql.empty()) {
1674+
statements = GetStatements(py::str(compiled_sql));
1675+
params = merged_params;
1676+
} else {
1677+
statements = GetStatements(query);
1678+
}
1679+
16071680
if (statements.empty()) {
16081681
// TODO: should we throw?
16091682
return nullptr;
@@ -1616,7 +1689,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::RunQuery(const py::object &quer
16161689

16171690
// Attempt to create a Relation for lazy execution if possible
16181691
shared_ptr<Relation> relation;
1619-
bool has_params = !py::none().is(params) && py::len(params) > 0;
1692+
bool has_params = !params.is_none() && py::len(params) > 0;
16201693
if (!has_params) {
16211694
// No params (or empty params) — use lazy QueryRelation path
16221695
{

src/duckdb_py/pyexpression/initialize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ void DuckDBPyExpression::Initialize(py::module_ &m) {
314314
Print the stringified version of the expression.
315315
)";
316316
expression.def("show", &DuckDBPyExpression::Print, docs);
317+
expression.def("__duckdb_template__", &DuckDBPyExpression::ToString);
317318

318319
docs = R"(
319320
Set the order by modifier to ASCENDING.

src/duckdb_py/pyrelation/initialize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ void DuckDBPyRelation::Initialize(py::handle &m) {
342342
.def("show", &DuckDBPyRelation::Print, "Display a summary of the data", py::kw_only(),
343343
py::arg("max_width") = py::none(), py::arg("max_rows") = py::none(), py::arg("max_col_width") = py::none(),
344344
py::arg("null_value") = py::none(), py::arg("render_mode") = py::none())
345+
.def("__duckdb_template__", &DuckDBPyRelation::ToSQL)
345346
.def("__str__", &DuckDBPyRelation::ToString)
346347
.def("__repr__", &DuckDBPyRelation::ToString);
347348

src/duckdb_py/typing/pytype.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ void DuckDBPyType::Initialize(py::handle &m) {
329329
auto type_module = py::class_<DuckDBPyType, shared_ptr<DuckDBPyType>>(m, "DuckDBPyType", py::module_local());
330330

331331
type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object");
332+
type_module.def("__duckdb_template__", &DuckDBPyType::ToString);
332333
type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"),
333334
py::is_operator());
334335
type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"),

tests/fast/test_template_e2e.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def test_module_level_sql_apis_accept_sql_template() -> None:
4141

4242
def test_module_level_execute_accepts_sql_template() -> None:
4343
conn = duckdb.connect()
44-
query = template("SELECT ", "hello")
45-
assert duckdb.execute(query, connection=conn).fetchone() == ("hello",)
44+
query = template("SELECT ", 5)
45+
assert duckdb.execute(query, connection=conn).fetchone() == (5,)
4646

4747

4848
def test_connection_sql_accepts_alias_kwarg_with_template() -> None:
@@ -61,11 +61,19 @@ def test_connection_sql_template_can_merge_additional_params() -> None:
6161

6262
def test_connection_sql_template_param_name_conflict_with_additional_params_raises() -> None:
6363
conn = duckdb.connect()
64-
query = template("SELECT ", param(10, "num"), " + $num")
64+
query = template("SELECT ", param(10, "num", exact=True), " + $num")
6565
with pytest.raises((duckdb.InvalidInputException, ValueError)):
6666
conn.sql(query, params={"num": 5}).fetchall()
6767

6868

69+
def test_cant_merge_with_positional_params() -> None:
70+
conn = duckdb.connect()
71+
# The interpolated value (10) is given no explicit parameter name here, yet combining the
# template with a positional params list should still raise.
72+
query = template("SELECT ", 10, " + ?")
73+
with pytest.raises(ValueError, match="Cannot merge compiled SQL named parameters with positional parameters"):
74+
conn.sql(query, params=[5]).fetchall()
75+
76+
6977
def test_sql_apis_accept_compiled_sql() -> None:
7078
conn = duckdb.connect()
7179
compiled = template("SELECT i FROM range(5) t(i) WHERE i >= ", 3, " ORDER BY i").compile()
@@ -83,16 +91,6 @@ def test_relation_interpolation_works_end_to_end() -> None:
8391
assert conn.sql(query).fetchall() == [(0,), (2,), (4,)]
8492

8593

86-
def test_interpolated_strings_are_parameterized_by_default() -> None:
87-
conn = duckdb.connect()
88-
conn.execute("CREATE TABLE names(name VARCHAR)")
89-
conn.execute("INSERT INTO names VALUES ('alice'), ('bob')")
90-
91-
untrusted = "alice' OR 1=1 --"
92-
query = template("SELECT count(*) FROM names WHERE name = ", untrusted)
93-
assert conn.sql(query).fetchone() == (0,)
94-
95-
9694
def test_builtin_duckdbpytype_object_interpolates_in_template() -> None:
9795
conn = duckdb.connect()
9896
integer_type = duckdb.sqltype("INTEGER")

0 commit comments

Comments
 (0)