|
1 | 1 | #include "duckdb_python/pyconnection/pyconnection.hpp" |
2 | 2 |
|
| 3 | +#include "duckdb/catalog/catalog.hpp" |
| 4 | +#include "duckdb/catalog/default/default_types.hpp" |
3 | 5 | #include "duckdb/common/arrow/arrow.hpp" |
4 | 6 | #include "duckdb/common/enums/profiler_format.hpp" |
5 | 7 | #include "duckdb/common/types.hpp" |
@@ -52,6 +54,17 @@ shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr; |
52 | 54 | PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global |
53 | 55 | std::string DuckDBPyConnection::formatted_python_version = ""; |
54 | 56 |
|
| 57 | +static shared_ptr<PythonRegisteredObjectState> GetPythonRegisteredObjectState(ClientContext &context) { |
| 58 | + return context.registered_state->GetOrCreate<PythonRegisteredObjectState>(PythonRegisteredObjectState::Key); |
| 59 | +} |
| 60 | + |
| 61 | +static bool TemporaryObjectExists(ClientContext &context, const string &name) { |
| 62 | + auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); |
| 63 | + EntryLookupInfo lookup_info(CatalogType::TABLE_ENTRY, name); |
| 64 | + auto entry = catalog.GetEntry(context, DEFAULT_SCHEMA, lookup_info, OnEntryNotFound::RETURN_NULL); |
| 65 | + return entry != nullptr; |
| 66 | +} |
| 67 | + |
55 | 68 | DuckDBPyConnection::~DuckDBPyConnection() { |
56 | 69 | try { |
57 | 70 | py::gil_scoped_release gil; |
@@ -743,11 +756,16 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterPythonObject(const st |
743 | 756 | const py::object &python_object) { |
744 | 757 | auto &connection = con.GetConnection(); |
745 | 758 | auto &client = *connection.context; |
746 | | - auto object = PythonReplacementScan::ReplacementObject(python_object, name, client); |
747 | | - auto view_rel = make_shared_ptr<ViewRelation>(connection.context, std::move(object), name); |
748 | | - bool replace = registered_objects.count(name); |
749 | | - view_rel->CreateView(name, replace, true); |
750 | | - registered_objects.insert(name); |
| 759 | + auto registered_state = GetPythonRegisteredObjectState(client); |
| 760 | + if (!registered_state->Contains(name)) { |
| 761 | + bool temp_object_exists = false; |
| 762 | + client.RunFunctionInTransaction([&]() { temp_object_exists = TemporaryObjectExists(client, name); }, false); |
| 763 | + if (temp_object_exists) { |
| 764 | + throw CatalogException("View with name \"%s\" already exists!", name); |
| 765 | + } |
| 766 | + } |
| 767 | + PythonReplacementScan::ReplacementObject(python_object, name, client); |
| 768 | + registered_state->Register(name, python_object); |
751 | 769 | return shared_from_this(); |
752 | 770 | } |
753 | 771 |
|
@@ -1821,15 +1839,12 @@ unordered_set<string> DuckDBPyConnection::GetTableNames(const string &query, boo |
1821 | 1839 |
|
1822 | 1840 | shared_ptr<DuckDBPyConnection> DuckDBPyConnection::UnregisterPythonObject(const string &name) { |
1823 | 1841 | auto &connection = con.GetConnection(); |
1824 | | - if (!registered_objects.count(name)) { |
| 1842 | + auto registered_state = GetPythonRegisteredObjectState(*connection.context); |
| 1843 | + if (!registered_state->Contains(name)) { |
1825 | 1844 | return shared_from_this(); |
1826 | 1845 | } |
1827 | 1846 | D_ASSERT(py::gil_check()); |
1828 | | - py::gil_scoped_release release; |
1829 | | - // FIXME: DROP TEMPORARY VIEW? doesn't exist? |
1830 | | - const auto quoted_name = SQLQuotedIdentifier::ToString(name); |
1831 | | - connection.Query("DROP VIEW " + quoted_name + ""); |
1832 | | - registered_objects.erase(name); |
| 1847 | + registered_state->Unregister(name); |
1833 | 1848 | return shared_from_this(); |
1834 | 1849 | } |
1835 | 1850 |
|
|
0 commit comments