Skip to content

Commit 62b9efc

Browse files
fix SqlCatalog list_namespaces()
1 parent 826a006 commit 62b9efc

2 files changed

Lines changed: 41 additions & 13 deletions

File tree

pyiceberg/catalog/sql.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -610,15 +610,29 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
610610
table_stmt = select(IcebergTables.table_namespace).where(IcebergTables.catalog_name == self.name)
611611
namespace_stmt = select(IcebergNamespaceProperties.namespace).where(IcebergNamespaceProperties.catalog_name == self.name)
612612
if namespace:
613-
namespace_str = Catalog.namespace_to_string(namespace, NoSuchNamespaceError)
614-
table_stmt = table_stmt.where(IcebergTables.table_namespace.like(namespace_str))
615-
namespace_stmt = namespace_stmt.where(IcebergNamespaceProperties.namespace.like(namespace_str))
613+
namespace_like = Catalog.namespace_to_string(namespace, NoSuchNamespaceError) + "%"
614+
table_stmt = table_stmt.where(IcebergTables.table_namespace.like(namespace_like))
615+
namespace_stmt = namespace_stmt.where(IcebergNamespaceProperties.namespace.like(namespace_like))
616616
stmt = union(
617617
table_stmt,
618618
namespace_stmt,
619619
)
620620
with Session(self.engine) as session:
621-
return [Catalog.identifier_to_tuple(namespace_col) for namespace_col in session.execute(stmt).scalars()]
621+
namespaces = [Catalog.identifier_to_tuple(namespace_col) for namespace_col in session.execute(stmt).scalars()]
622+
623+
sub_namespaces_level_length = 1
624+
if namespace:
625+
namespace_tuple = Catalog.identifier_to_tuple(namespace)
626+
sub_namespaces_level_length = len(namespace_tuple) + 1
627+
628+
# only get sub namespaces/children
629+
result = list({ns[:sub_namespaces_level_length] for ns in namespaces if len(ns) >= sub_namespaces_level_length})
630+
631+
if namespace:
632+
# exclude fuzzy matches when `namespace` contains `%` or `_`
633+
result = [ns for ns in result if ns[: len(namespace_tuple)] == namespace_tuple]
634+
635+
return result
622636

623637
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
624638
"""Get properties for a namespace.

tests/catalog/test_sql.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import os
1919
from pathlib import Path
20-
from typing import Any, Generator, List, cast
20+
from typing import Any, Generator, cast
2121

2222
import pyarrow as pa
2323
import pytest
@@ -1116,17 +1116,31 @@ def test_create_namespace_with_empty_identifier(catalog: SqlCatalog, empty_names
11161116
lazy_fixture("catalog_sqlite"),
11171117
],
11181118
)
1119-
@pytest.mark.parametrize("namespace_list", [lazy_fixture("database_list"), lazy_fixture("hierarchical_namespace_list")])
1120-
def test_list_namespaces(catalog: SqlCatalog, namespace_list: List[str]) -> None:
1119+
def test_list_namespaces(catalog: SqlCatalog) -> None:
1120+
namespace_list = ["db", "db.ns1", "db.ns1.ns2", "db.ns2", "db2", "db2.ns1", "db%"]
11211121
for namespace in namespace_list:
11221122
catalog.create_namespace(namespace)
1123-
# Test global list
1123+
11241124
ns_list = catalog.list_namespaces()
1125-
for namespace in namespace_list:
1126-
assert Catalog.identifier_to_tuple(namespace) in ns_list
1127-
# Test individual namespace list
1128-
assert len(one_namespace := catalog.list_namespaces(namespace)) == 1
1129-
assert Catalog.identifier_to_tuple(namespace) == one_namespace[0]
1125+
expected_list = [("db",), ("db2",), ("db%",)]
1126+
assert len(ns_list) == len(expected_list)
1127+
for ns in expected_list:
1128+
assert ns in ns_list
1129+
1130+
ns_list = catalog.list_namespaces("db")
1131+
expected_list = [("db", "ns1"), ("db", "ns2")]
1132+
assert len(ns_list) == len(expected_list)
1133+
for ns in expected_list:
1134+
assert ns in ns_list
1135+
1136+
ns_list = catalog.list_namespaces("db.ns1")
1137+
expected_list = [("db", "ns1", "ns2")]
1138+
assert len(ns_list) == len(expected_list)
1139+
for ns in expected_list:
1140+
assert ns in ns_list
1141+
1142+
ns_list = catalog.list_namespaces("db.ns1.ns2")
1143+
assert len(ns_list) == 0
11301144

11311145

11321146
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)