Skip to content

Commit 32b37c1

Browse files
committed
Use non-view registry for Python register
1 parent ab63b5f commit 32b37c1

5 files changed

Lines changed: 97 additions & 12 deletions

File tree

src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,6 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
169169
//! MemoryFileSystem used to temporarily store file-like objects for reading
170170
shared_ptr<ModifiedMemoryFileSystem> internal_object_filesystem;
171171
case_insensitive_map_t<unique_ptr<ExternalDependency>> registered_functions;
172-
case_insensitive_set_t registered_objects;
173172

174173
public:
175174
explicit DuckDBPyConnection() {

src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,25 @@
44
#include "duckdb/common/case_insensitive_map.hpp"
55
#include "duckdb/parser/tableref.hpp"
66
#include "duckdb/function/replacement_scan.hpp"
7+
#include "duckdb_python/python_dependency.hpp"
78
#include "duckdb_python/pybind11/pybind_wrapper.hpp"
89

910
namespace duckdb {
1011

12+
class PythonRegisteredObjectState : public ClientContextState {
13+
public:
14+
static constexpr const char *Key = "python_registered_objects";
15+
16+
void Register(const string &name, const py::object &object);
17+
void Unregister(const string &name);
18+
py::object Get(const string &name);
19+
bool Contains(const string &name);
20+
21+
private:
22+
mutex lock;
23+
case_insensitive_map_t<shared_ptr<DependencyItem>> registered_objects;
24+
};
25+
1126
struct PythonReplacementScan {
1227
public:
1328
static unique_ptr<TableRef> Replace(ClientContext &context, ReplacementScanInput &input,

src/duckdb_py/pyconnection.cpp

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
#include "duckdb_python/pyconnection/pyconnection.hpp"
22

3+
#include "duckdb/catalog/catalog.hpp"
4+
#include "duckdb/catalog/default/default_types.hpp"
35
#include "duckdb/common/arrow/arrow.hpp"
46
#include "duckdb/common/enums/profiler_format.hpp"
57
#include "duckdb/common/types.hpp"
@@ -52,6 +54,17 @@ shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr;
5254
PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global
5355
std::string DuckDBPyConnection::formatted_python_version = "";
5456

57+
static shared_ptr<PythonRegisteredObjectState> GetPythonRegisteredObjectState(ClientContext &context) {
58+
return context.registered_state->GetOrCreate<PythonRegisteredObjectState>(PythonRegisteredObjectState::Key);
59+
}
60+
61+
static bool TemporaryObjectExists(ClientContext &context, const string &name) {
62+
auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG);
63+
EntryLookupInfo lookup_info(CatalogType::TABLE_ENTRY, name);
64+
auto entry = catalog.GetEntry(context, DEFAULT_SCHEMA, lookup_info, OnEntryNotFound::RETURN_NULL);
65+
return entry != nullptr;
66+
}
67+
5568
DuckDBPyConnection::~DuckDBPyConnection() {
5669
try {
5770
py::gil_scoped_release gil;
@@ -743,11 +756,16 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterPythonObject(const st
743756
const py::object &python_object) {
744757
auto &connection = con.GetConnection();
745758
auto &client = *connection.context;
746-
auto object = PythonReplacementScan::ReplacementObject(python_object, name, client);
747-
auto view_rel = make_shared_ptr<ViewRelation>(connection.context, std::move(object), name);
748-
bool replace = registered_objects.count(name);
749-
view_rel->CreateView(name, replace, true);
750-
registered_objects.insert(name);
759+
auto registered_state = GetPythonRegisteredObjectState(client);
760+
if (!registered_state->Contains(name)) {
761+
bool temp_object_exists = false;
762+
client.RunFunctionInTransaction([&]() { temp_object_exists = TemporaryObjectExists(client, name); }, false);
763+
if (temp_object_exists) {
764+
throw CatalogException("View with name \"%s\" already exists!", name);
765+
}
766+
}
767+
PythonReplacementScan::ReplacementObject(python_object, name, client);
768+
registered_state->Register(name, python_object);
751769
return shared_from_this();
752770
}
753771

@@ -1821,15 +1839,12 @@ unordered_set<string> DuckDBPyConnection::GetTableNames(const string &query, boo
18211839

18221840
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::UnregisterPythonObject(const string &name) {
18231841
auto &connection = con.GetConnection();
1824-
if (!registered_objects.count(name)) {
1842+
auto registered_state = GetPythonRegisteredObjectState(*connection.context);
1843+
if (!registered_state->Contains(name)) {
18251844
return shared_from_this();
18261845
}
18271846
D_ASSERT(py::gil_check());
1828-
py::gil_scoped_release release;
1829-
// FIXME: DROP TEMPORARY VIEW? doesn't exist?
1830-
const auto quoted_name = SQLQuotedIdentifier::ToString(name);
1831-
connection.Query("DROP VIEW " + quoted_name + "");
1832-
registered_objects.erase(name);
1847+
registered_state->Unregister(name);
18331848
return shared_from_this();
18341849
}
18351850

src/duckdb_py/python_replacement_scan.cpp

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,34 @@
1616

1717
namespace duckdb {
1818

19+
void PythonRegisteredObjectState::Register(const string &name, const py::object &object) {
20+
py::gil_scoped_acquire gil;
21+
lock_guard<mutex> guard(lock);
22+
registered_objects[name] = PythonDependencyItem::Create(object);
23+
}
24+
25+
void PythonRegisteredObjectState::Unregister(const string &name) {
26+
py::gil_scoped_acquire gil;
27+
lock_guard<mutex> guard(lock);
28+
registered_objects.erase(name);
29+
}
30+
31+
py::object PythonRegisteredObjectState::Get(const string &name) {
32+
py::gil_scoped_acquire gil;
33+
lock_guard<mutex> guard(lock);
34+
auto entry = registered_objects.find(name);
35+
if (entry == registered_objects.end()) {
36+
return py::none();
37+
}
38+
auto &dependency = entry->second->Cast<PythonDependencyItem>();
39+
return dependency.object->obj;
40+
}
41+
42+
bool PythonRegisteredObjectState::Contains(const string &name) {
43+
lock_guard<mutex> guard(lock);
44+
return registered_objects.find(name) != registered_objects.end();
45+
}
46+
1947
static void CreateArrowScan(const string &name, py::object entry, TableFunctionRef &table_function,
2048
vector<unique_ptr<ParsedExpression>> &children, ClientProperties &client_properties,
2149
PyArrowObjectType type, DatabaseInstance &db) {
@@ -238,6 +266,16 @@ static unique_ptr<TableRef> ReplaceInternal(ClientContext &context, const string
238266
return nullptr;
239267
}
240268

269+
auto registered_objects =
270+
context.registered_state->Get<PythonRegisteredObjectState>(PythonRegisteredObjectState::Key);
271+
if (registered_objects) {
272+
py::gil_scoped_acquire acquire;
273+
auto entry = registered_objects->Get(table_name);
274+
if (!entry.is_none()) {
275+
return PythonReplacementScan::TryReplacementObject(entry, table_name, context);
276+
}
277+
}
278+
241279
lookup_result = context.TryGetCurrentSetting("python_scan_all_frames", result);
242280
D_ASSERT((bool)lookup_result);
243281
auto scan_all_frames = result.GetValue<bool>();

tests/fast/pandas/test_pandas_unregister.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import gc
22
import tempfile
3+
import weakref
34

45
import pandas as pd
56
import pytest
@@ -50,3 +51,20 @@ def test_pandas_unregister2(self, duckdb_cursor):
5051
with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"):
5152
connection.execute("SELECT * FROM dataframe;").fetchdf()
5253
connection.close()
54+
55+
def test_pandas_unregister_releases_object_inside_transaction(self, duckdb_cursor):
56+
duckdb_cursor.execute("CREATE TABLE t(i BIGINT)")
57+
duckdb_cursor.begin()
58+
59+
df = pd.DataFrame({"i": [1, 2, 3]})
60+
ref = weakref.ref(df)
61+
62+
duckdb_cursor.register("dataframe", df)
63+
duckdb_cursor.execute("INSERT INTO t SELECT * FROM dataframe")
64+
duckdb_cursor.unregister("dataframe")
65+
66+
del df
67+
gc.collect()
68+
69+
assert ref() is None
70+
duckdb_cursor.rollback()

0 commit comments

Comments
 (0)