Skip to content

Commit a9bb62d

Browse files
committed
feat: implement table valued functions / user defined table functions
1 parent b494d1c commit a9bb62d

8 files changed

Lines changed: 567 additions & 4 deletions

File tree

duckdb/functional/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from _duckdb.functional import (
22
FunctionNullHandling,
33
PythonUDFType,
4+
PythonTVFType,
45
SPECIAL,
56
DEFAULT,
67
NATIVE,
7-
ARROW
8+
ARROW,
9+
TUPLES,
10+
ARROW_TABLE
811
)
912

1013
__all__ = [
1114
"FunctionNullHandling",
1215
"PythonUDFType",
16+
"PythonTVFType",
1317
"SPECIAL",
1418
"DEFAULT",
1519
"NATIVE",
16-
"ARROW"
20+
"ARROW",
21+
"TUPLES",
22+
"ARROW_TABLE"
1723
]

scripts/connection_methods.json

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,51 @@
107107
],
108108
"return": "DuckDBPyConnection"
109109
},
110+
{
111+
"name": "create_table_function",
112+
"function": "RegisterTableFunction",
113+
"docs": "Register a table valued function via Callable",
114+
"args": [
115+
{
116+
"name": "name",
117+
"type": "str"
118+
},
119+
{
120+
"name": "callable",
121+
"type": "Callable"
122+
}
123+
],
124+
"kwargs": [
125+
{
126+
"name": "parameters",
127+
"type": "Optional[Any]",
128+
"default": "None"
129+
},
130+
{
131+
"name": "schema",
132+
"type": "Optional[Any]",
133+
"default": "None"
134+
},
135+
{
136+
"name": "type",
137+
"type": "Optional[PythonTVFType]",
138+
"default": "PythonTVFType.TUPLES"
139+
}
140+
],
141+
"return": "DuckDBPyConnection"
142+
},
143+
{
144+
"name": "unregister_table_function",
145+
"function": "UnregisterTableFunction",
146+
"docs": "Unregister a table valued function",
147+
"args": [
148+
{
149+
"name": "name",
150+
"type": "str"
151+
}
152+
],
153+
"return": "DuckDBPyConnection"
154+
},
110155
{
111156
"name": [
112157
"sqltype",

src/duckdb_py/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_library(python_src OBJECT
2828
python_import_cache.cpp
2929
python_replacement_scan.cpp
3030
python_udf.cpp
31+
python_tvf.cpp
3132
)
3233

3334
target_link_libraries(python_src PRIVATE _duckdb_dependencies)

src/duckdb_py/functional/functional.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ void DuckDBPyFunctional::Initialize(py::module_ &parent) {
1111
.value("ARROW", duckdb::PythonUDFType::ARROW)
1212
.export_values();
1313

14+
py::enum_<duckdb::PythonTVFType>(m, "PythonTVFType")
15+
.value("TUPLES", duckdb::PythonTVFType::TUPLES)
16+
.value("ARROW_TABLE", duckdb::PythonTVFType::ARROW_TABLE)
17+
.export_values();
18+
1419
py::enum_<duckdb::FunctionNullHandling>(m, "FunctionNullHandling")
1520
.value("DEFAULT", duckdb::FunctionNullHandling::DEFAULT_NULL_HANDLING)
1621
.value("SPECIAL", duckdb::FunctionNullHandling::SPECIAL_HANDLING)
src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#pragma once
2+
3+
#include "duckdb/common/common.hpp"
4+
#include "duckdb/common/exception.hpp"
5+
#include "duckdb/common/string_util.hpp"
6+
7+
using duckdb::InvalidInputException;
8+
using duckdb::string;
9+
using duckdb::StringUtil;
10+
11+
namespace duckdb {
12+
13+
//! Kind of Python table-valued function, selected through the `type` argument of create_table_function.
//! The numeric values are the ones PythonTVFTypeFromInteger accepts.
enum class PythonTVFType : uint8_t {
	TUPLES = 0,
	ARROW_TABLE = 1
};
14+
15+
} // namespace duckdb
16+
17+
using duckdb::PythonTVFType;
18+
19+
namespace py = pybind11;
20+
21+
//! Maps the textual spelling of a table function type onto PythonTVFType.
//! The comparison is done on the lower-cased input; an empty string selects TUPLES.
//! Throws InvalidInputException (quoting the input as given) for any other spelling.
static PythonTVFType PythonTVFTypeFromString(const string &type) {
	const auto normalized = StringUtil::Lower(type);
	if (normalized == "arrow_table") {
		return PythonTVFType::ARROW_TABLE;
	}
	if (normalized.empty() || normalized == "tuples") {
		return PythonTVFType::TUPLES;
	}
	throw InvalidInputException("'%s' is not a recognized type for 'tvf_type'", type);
}
31+
32+
static PythonTVFType PythonTVFTypeFromInteger(int64_t value) {
33+
if (value == 0) {
34+
return PythonTVFType::TUPLES;
35+
} else if (value == 1) {
36+
return PythonTVFType::ARROW_TABLE;
37+
} else {
38+
throw InvalidInputException("'%d' is not a recognized type for 'tvf_type'", value);
39+
}
40+
}
41+
42+
namespace PYBIND11_NAMESPACE {
43+
namespace detail {
44+
45+
template <>
46+
struct type_caster<PythonTVFType> : public type_caster_base<PythonTVFType> {
47+
using base = type_caster_base<PythonTVFType>;
48+
PythonTVFType tmp;
49+
50+
public:
51+
bool load(handle src, bool convert) {
52+
if (base::load(src, convert)) {
53+
return true;
54+
} else if (py::isinstance<py::str>(src)) {
55+
tmp = PythonTVFTypeFromString(py::str(src));
56+
value = &tmp;
57+
return true;
58+
} else if (py::isinstance<py::int_>(src)) {
59+
tmp = PythonTVFTypeFromInteger(src.cast<int64_t>());
60+
value = &tmp;
61+
return true;
62+
}
63+
return false;
64+
}
65+
66+
static handle cast(PythonTVFType src, return_value_policy policy, handle parent) {
67+
return base::cast(src, policy, parent);
68+
}
69+
};
70+
71+
} // namespace detail
72+
} // namespace PYBIND11_NAMESPACE

src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "duckdb/function/scalar_function.hpp"
2424
#include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp"
2525
#include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp"
26+
#include "duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp"
2627
#include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp"
2728
#include "duckdb/common/shared_ptr.hpp"
2829

@@ -169,6 +170,8 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
169170
//! MemoryFileSystem used to temporarily store file-like objects for reading
170171
shared_ptr<ModifiedMemoryFileSystem> internal_object_filesystem;
171172
case_insensitive_map_t<unique_ptr<ExternalDependency>> registered_functions;
173+
//! Dependencies created by RegisterTableFunction, keyed by table function name;
//! entries are removed by UnregisterTableFunction and the map is cleared in Close
case_insensitive_map_t<unique_ptr<ExternalDependency>> registered_table_functions;
174+
172175
case_insensitive_set_t registered_objects;
173176

174177
public:
@@ -232,6 +235,13 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
232235
PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR,
233236
bool side_effects = false);
234237

238+
//! Registers the Python callable `function` as a table function called `name` and returns this connection.
//! `parameters`, `schema` and `type` are forwarded unchanged to CreateTableFunctionFromCallable.
//! Throws NotImplementedException if `name` is already present in registered_table_functions.
shared_ptr<DuckDBPyConnection> RegisterTableFunction(const string &name, const py::function &function,
239+
const py::object &parameters = py::none(),
240+
const py::object &schema = py::none(),
241+
PythonTVFType type = PythonTVFType::TUPLES);
242+
243+
//! Removes `name` from registered_table_functions and returns this connection.
//! Throws InvalidInputException if no table function of that name is registered.
shared_ptr<DuckDBPyConnection> UnregisterTableFunction(const string &name);
244+
235245
shared_ptr<DuckDBPyConnection> UnregisterUDF(const string &name);
236246

237247
shared_ptr<DuckDBPyConnection> ExecuteMany(const py::object &query, py::object params = py::list());
@@ -355,6 +365,11 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
355365
const shared_ptr<DuckDBPyType> &return_type, bool vectorized,
356366
FunctionNullHandling null_handling, PythonExceptionHandling exception_handling,
357367
bool side_effects);
368+
369+
//! Builds the duckdb::TableFunction that RegisterTableFunction wraps in a CreateTableFunctionInfo and registers.
//! NOTE(review): the definition is not visible from this header; presumably it lives in python_tvf.cpp - confirm.
duckdb::TableFunction CreateTableFunctionFromCallable(const std::string &name, const py::function &callable,
370+
const py::object &parameters, const py::object &schema,
371+
PythonTVFType type);
372+
358373
void RegisterArrowObject(const py::object &arrow_object, const string &name);
359374
vector<unique_ptr<SQLStatement>> GetStatements(const py::object &query);
360375

src/duckdb_py/pyconnection.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,6 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud
393393
auto scalar_function = CreateScalarUDF(name, udf, parameters_p, return_type_p, type == PythonUDFType::ARROW,
394394
null_handling, exception_handling, side_effects);
395395
CreateScalarFunctionInfo info(scalar_function);
396-
397396
context.RegisterFunction(info);
398397

399398
auto dependency = make_uniq<ExternalDependency>();
@@ -403,6 +402,57 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud
403402
return shared_from_this();
404403
}
405404

405+
// Registers the Python callable `function` as a table function called `name`.
// `parameters`, `schema` and `type` are passed through to CreateTableFunctionFromCallable.
// Throws NotImplementedException when `name` is already tracked in registered_table_functions.
// Returns this connection so calls can be chained.
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterTableFunction(const string &name,
                                                                         const py::function &function,
                                                                         const py::object &parameters,
                                                                         const py::object &schema,
                                                                         PythonTVFType type) {
	auto &client_context = *con.GetConnection().context;

	// Any transaction that is still active is cancelled before the function is registered
	if (client_context.transaction.HasActiveTransaction()) {
		client_context.CancelTransaction();
	}

	const bool already_registered = registered_table_functions.count(name) > 0;
	if (already_registered) {
		throw NotImplementedException("A table function by the name of '%s' is already registered, "
		                              "please unregister it first",
		                              name);
	}

	auto tvf = CreateTableFunctionFromCallable(name, function, parameters, schema, type);
	CreateTableFunctionInfo create_info(tvf);
	// Replace an existing catalog entry of the same name instead of failing,
	// so that a re-registration swaps in the new callable
	create_info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT;
	client_context.RegisterFunction(create_info);

	// Track the callable under the function's name (presumably this keeps it alive while registered)
	auto callable_dependency = make_uniq<ExternalDependency>();
	callable_dependency->AddDependency("function", PythonDependencyItem::Create(function));
	registered_table_functions[name] = std::move(callable_dependency);

	return shared_from_this();
}
438+
439+
// Removes the table function `name` from this connection's registry and returns this connection.
// Throws InvalidInputException when `name` is not in registered_table_functions.
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::UnregisterTableFunction(const string &name) {
	auto entry = registered_table_functions.find(name);
	if (entry == registered_table_functions.end()) {
		throw InvalidInputException(
		    "No table function by the name of '%s' was found in the list of registered table functions", name);
	}

	// Only our own registry is modified here, so neither the connection nor its client context is needed.
	// TODO: Callable still exists in the function catalog, since duckdb doesn't (yet?) support removal
	registered_table_functions.erase(entry);

	return shared_from_this();
}
455+
406456
void DuckDBPyConnection::Initialize(py::handle &m) {
407457
auto connection_module =
408458
py::class_<DuckDBPyConnection, shared_ptr<DuckDBPyConnection>>(m, "DuckDBPyConnection", py::module_local());
@@ -411,6 +461,14 @@ void DuckDBPyConnection::Initialize(py::handle &m) {
411461
.def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), py::arg("traceback"));
412462
connection_module.def("__del__", &DuckDBPyConnection::Close);
413463

464+
connection_module.def("create_table_function", &DuckDBPyConnection::RegisterTableFunction,
465+
"Register a table valued function via Callable", py::arg("name"), py::arg("callable"),
466+
py::arg("parameters") = py::none(), py::arg("schema") = py::none(),
467+
py::arg("type") = PythonTVFType::TUPLES);
468+
469+
connection_module.def("unregister_table_function", &DuckDBPyConnection::UnregisterTableFunction,
470+
"Unregister a table valued function", py::arg("name"));
471+
414472
InitializeConnectionMethods(connection_module);
415473
connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription,
416474
"Get result set attributes, mainly column names");
@@ -1575,7 +1633,12 @@ unique_ptr<DuckDBPyRelation> DuckDBPyConnection::RunQuery(const py::object &quer
15751633
}
15761634
if (res->type == QueryResultType::STREAM_RESULT) {
15771635
auto &stream_result = res->Cast<StreamQueryResult>();
1578-
res = stream_result.Materialize();
1636+
{
1637+
// Release the GIL, as Materialize *may* need the GIL (TVFs, for instance)
1638+
D_ASSERT(py::gil_check());
1639+
py::gil_scoped_release release;
1640+
res = stream_result.Materialize();
1641+
}
15791642
}
15801643
auto &materialized_result = res->Cast<MaterializedQueryResult>();
15811644
relation = make_shared_ptr<MaterializedRelation>(connection.context, materialized_result.TakeCollection(),
@@ -1826,6 +1889,7 @@ void DuckDBPyConnection::Close() {
18261889
// https://peps.python.org/pep-0249/#Connection.close
18271890
cursors.ClearCursors();
18281891
registered_functions.clear();
1892+
registered_table_functions.clear();
18291893
}
18301894

18311895
void DuckDBPyConnection::Interrupt() {

0 commit comments

Comments
 (0)