Skip to content

Commit c37b0aa

Browse files
authored
fix: add proper catalog support to engine adapter (#1559)
1 parent aa9a5b5 commit c37b0aa

28 files changed

Lines changed: 750 additions & 152 deletions

sqlmesh/core/_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66

77
if t.TYPE_CHECKING:
88
TableName = t.Union[str, exp.Table]
9+
SchemaName = t.Union[str, exp.Table]

sqlmesh/core/dialect.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,3 +839,35 @@ def transform_values(
839839
yield exp.func("PARSE_JSON", f"'{value}'")
840840
else:
841841
yield value
842+
843+
844+
def to_schema(sql_path: str | exp.Table) -> exp.Table:
    """Coerce a SQL path or table expression into a schema expression.

    A "schema" here is an exp.Table whose table part (`this`) is empty,
    leaving only the db and (optionally) catalog parts populated.

    Args:
        sql_path: A dotted path string (e.g. "catalog.db") or an exp.Table.

    Returns:
        An exp.Table representing the schema.
    """
    # Already schema-shaped (no table part): return it unchanged.
    if isinstance(sql_path, exp.Table) and sql_path.this is None:
        return sql_path
    source = sql_path.copy() if isinstance(sql_path, exp.Table) else sql_path
    schema = exp.to_table(source)
    # Promote each naming part one level: db -> catalog, table name -> db,
    # and clear out the table slot.
    schema.set("catalog", schema.args.get("db"))
    schema.set("db", schema.args.get("this"))
    schema.set("this", None)
    return schema
852+
853+
854+
def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema.

    Args:
        db: Database name.
        catalog: Catalog name.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance.
    """
    db_identifier = exp.to_identifier(db, quoted=quoted) if db else None
    catalog_identifier = exp.to_identifier(catalog, quoted=quoted) if catalog else None
    return exp.Table(this=None, db=db_identifier, catalog=catalog_identifier)

sqlmesh/core/engine_adapter/base.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,23 @@
2525
from sqlglot.helper import ensure_list
2626
from sqlglot.optimizer.qualify_columns import quote_identifiers
2727

28-
from sqlmesh.core.dialect import add_table, select_from_values_for_batch_range
29-
from sqlmesh.core.engine_adapter.shared import DataObject
28+
from sqlmesh.core.dialect import (
29+
add_table,
30+
schema_,
31+
select_from_values_for_batch_range,
32+
to_schema,
33+
)
34+
from sqlmesh.core.engine_adapter.shared import DataObject, set_catalog
3035
from sqlmesh.core.model.kind import TimeColumn
3136
from sqlmesh.core.schema_diff import SchemaDiffer
3237
from sqlmesh.utils import double_escape
3338
from sqlmesh.utils.connection_pool import create_connection_pool
3439
from sqlmesh.utils.date import TimeLike, make_inclusive, to_ts
35-
from sqlmesh.utils.errors import SQLMeshError
40+
from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
3641
from sqlmesh.utils.pandas import columns_to_types_from_df
3742

3843
if t.TYPE_CHECKING:
39-
from sqlmesh.core._typing import TableName
44+
from sqlmesh.core._typing import SchemaName, TableName
4045
from sqlmesh.core.engine_adapter._typing import (
4146
DF,
4247
PySparkDataFrame,
@@ -76,6 +81,28 @@ def requires_condition(self) -> bool:
7681
return self.is_replace_where or self.is_delete_insert
7782

7883

84+
class CatalogSupport(Enum):
    """Level of catalog awareness that an engine adapter provides."""

    UNSUPPORTED = 1
    REQUIRES_SET_CATALOG = 2
    FULL_SUPPORT = 3

    @property
    def is_unsupported(self) -> bool:
        # Engine has no notion of catalogs at all.
        return self is CatalogSupport.UNSUPPORTED

    @property
    def is_requires_set_catalog(self) -> bool:
        # Catalogs exist but must be switched to before operating on them.
        return self is CatalogSupport.REQUIRES_SET_CATALOG

    @property
    def is_full_support(self) -> bool:
        # Fully-qualified catalog references are accepted directly.
        return self is CatalogSupport.FULL_SUPPORT

    @property
    def is_supported(self) -> bool:
        # Anything beyond UNSUPPORTED counts as supported.
        return not self.is_unsupported
104+
105+
79106
class SourceQuery:
80107
def __init__(
81108
self,
@@ -115,7 +142,6 @@ class EngineAdapter:
115142

116143
DIALECT = ""
117144
DEFAULT_BATCH_SIZE = 10000
118-
DEFAULT_SQL_GEN_KWARGS: t.Dict[str, str | bool | int] = {}
119145
ESCAPE_JSON = False
120146
SUPPORTS_TRANSACTIONS = True
121147
SUPPORTS_INDEXES = False
@@ -125,6 +151,7 @@ class EngineAdapter:
125151
SUPPORTS_CLONING = False
126152
SCHEMA_DIFFER = SchemaDiffer()
127153
SUPPORTS_TUPLE_IN = True
154+
CATALOG_SUPPORT = CatalogSupport.UNSUPPORTED
128155

129156
def __init__(
130157
self,
@@ -267,6 +294,14 @@ def close(self) -> t.Any:
267294
"""Closes all open connections and releases all allocated resources."""
268295
self._connection_pool.close_all()
269296

297+
def get_current_catalog(self) -> t.Optional[str]:
    """Returns the catalog name of the current connection.

    Base implementation is a stub; adapters with catalog support override it.
    """
    raise NotImplementedError
300+
301+
def set_current_catalog(self, catalog: str) -> None:
    """Sets the catalog name of the current connection.

    Base implementation is a stub; adapters with catalog support override it.
    """
    raise NotImplementedError
304+
270305
def replace_query(
271306
self,
272307
table_name: TableName,
@@ -516,7 +551,10 @@ def clone_table(
516551
this=exp.to_table(target_table_name),
517552
kind="TABLE",
518553
replace=replace,
519-
clone=exp.Clone(this=exp.to_table(source_table_name), **(clone_kwargs or {})),
554+
clone=exp.Clone(
555+
this=exp.to_table(source_table_name),
556+
**(clone_kwargs or {}),
557+
),
520558
**kwargs,
521559
)
522560
)
@@ -625,18 +663,18 @@ def create_view(
625663
)
626664
)
627665

666+
@set_catalog()
628667
def create_schema(
629668
self,
630-
schema_name: str,
631-
catalog_name: t.Optional[str] = None,
669+
schema_name: SchemaName,
632670
ignore_if_exists: bool = True,
633671
warn_on_error: bool = True,
634672
) -> None:
635673
"""Create a schema from a name or qualified table name."""
636674
try:
637675
self.execute(
638676
exp.Create(
639-
this=exp.table_(schema_name, catalog_name),
677+
this=to_schema(schema_name),
640678
kind="SCHEMA",
641679
exists=ignore_if_exists,
642680
)
@@ -646,13 +684,16 @@ def create_schema(
646684
raise
647685
logger.warning("Failed to create schema '%s': %s", schema_name, e)
648686

687+
@set_catalog()
649688
def drop_schema(
650-
self, schema_name: str, ignore_if_not_exists: bool = True, cascade: bool = False
689+
self,
690+
schema_name: SchemaName,
691+
ignore_if_not_exists: bool = True,
692+
cascade: bool = False,
651693
) -> None:
652-
"""Drop a schema from a name or qualified table name."""
653694
self.execute(
654695
exp.Drop(
655-
this=exp.table_(schema_name.split(".")[0]),
696+
this=to_schema(schema_name),
656697
kind="SCHEMA",
657698
exists=ignore_if_not_exists,
658699
cascade=cascade,
@@ -672,6 +713,7 @@ def drop_view(
672713
)
673714
)
674715

716+
@set_catalog()
675717
def columns(
676718
self, table_name: TableName, include_pseudo_columns: bool = False
677719
) -> t.Dict[str, exp.DataType]:
@@ -687,6 +729,7 @@ def columns(
687729
)
688730
}
689731

732+
@set_catalog()
690733
def table_exists(self, table_name: TableName) -> bool:
691734
try:
692735
self.execute(exp.Describe(this=exp.to_table(table_name), kind="TABLE"))
@@ -926,6 +969,8 @@ def scd_type_2(
926969
columns_to_types = columns_to_types or self.columns(target_table)
927970
if valid_from_name not in columns_to_types or valid_to_name not in columns_to_types:
928971
columns_to_types = self.columns(target_table)
972+
if not columns_to_types:
973+
raise SQLMeshError(f"Could not get columns_to_types. Does {target_table} exist?")
929974
if updated_at_name not in columns_to_types:
930975
raise SQLMeshError(
931976
f"Column {updated_at_name} not found in {target_table}. Table must contain an `updated_at` timestamp for SCD Type 2"
@@ -1154,11 +1199,20 @@ def merge(
11541199
match_expressions=[when_matched, when_not_matched],
11551200
)
11561201

1202+
@set_catalog()
def rename_table(
    self,
    old_table_name: TableName,
    new_table_name: TableName,
) -> None:
    """Rename a table, rejecting renames that would cross catalogs.

    Args:
        old_table_name: The table to rename.
        new_table_name: The desired new name.

    Raises:
        UnsupportedCatalogOperationError: If the target name specifies a
            catalog different from the source table's (explicit or current).
    """
    target = exp.to_table(new_table_name)
    if target.catalog:
        source = exp.to_table(old_table_name)
        # Fall back to the connection's current catalog when the source
        # name does not carry one explicitly.
        source_catalog = source.catalog or self.get_current_catalog()
        if source_catalog != target.catalog:
            raise UnsupportedCatalogOperationError(
                "Tried to rename table across catalogs which is not supported"
            )
    self.execute(exp.rename_table(old_table_name, new_table_name))
11631217

11641218
def fetchone(
@@ -1307,7 +1361,7 @@ def temp_table(
13071361
with self.transaction():
13081362
table = self._get_temp_table(name)
13091363
if table.db:
1310-
self.create_schema(table.db)
1364+
self.create_schema(schema_(table.args["db"], table.args.get("catalog")))
13111365
self._create_table_from_source_queries(
13121366
table, source_queries, columns_to_types, exists=True, **kwargs
13131367
)
@@ -1346,7 +1400,6 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An
13461400
"dialect": self.dialect,
13471401
"pretty": False,
13481402
"comments": False,
1349-
**self.DEFAULT_SQL_GEN_KWARGS,
13501403
**self.sql_gen_kwargs,
13511404
**kwargs,
13521405
}
@@ -1356,13 +1409,10 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An
13561409

13571410
return expression.sql(**sql_gen_kwargs) # type: ignore
13581411

1359-
def _get_data_objects(
1360-
self, schema_name: str, catalog_name: t.Optional[str] = None
1361-
) -> t.List[DataObject]:
1412+
def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
13621413
"""
13631414
Returns all the data objects that exist in the given schema and optionally catalog.
13641415
"""
1365-
13661416
raise NotImplementedError()
13671417

13681418
def _get_temp_table(

sqlmesh/core/engine_adapter/base_postgres.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
from sqlglot import exp
66

7-
from sqlmesh.core.engine_adapter.mixins import EngineAdapter
7+
from sqlmesh.core.engine_adapter import EngineAdapter
88
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
99
from sqlmesh.utils.errors import SQLMeshError
1010

1111
if t.TYPE_CHECKING:
12-
from sqlmesh.core._typing import TableName
12+
from sqlmesh.core._typing import SchemaName, TableName
1313
from sqlmesh.core.engine_adapter.base import QueryOrDF
1414

1515

@@ -92,13 +92,11 @@ def create_view(
9292
**create_kwargs,
9393
)
9494

95-
def _get_data_objects(
96-
self, schema_name: str, catalog_name: t.Optional[str] = None
97-
) -> t.List[DataObject]:
95+
def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
9896
"""
9997
Returns all the data objects that exist in the given schema and optionally catalog.
10098
"""
101-
catalog_name = f"'{catalog_name}'" if catalog_name else "NULL"
99+
catalog_name = f"'{self.get_current_catalog()}'"
102100
query = f"""
103101
SELECT
104102
{catalog_name} AS catalog_name,

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from sqlglot.helper import ensure_list
1010
from sqlglot.transforms import remove_precision_parameterized_types
1111

12-
from sqlmesh.core.engine_adapter.base import SourceQuery
12+
from sqlmesh.core.dialect import to_schema
13+
from sqlmesh.core.engine_adapter.base import CatalogSupport, SourceQuery
1314
from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin
1415
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
1516
from sqlmesh.core.node import IntervalUnit
@@ -26,7 +27,7 @@
2627
from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
2728
from google.cloud.bigquery.table import Table as BigQueryTable
2829

29-
from sqlmesh.core._typing import TableName
30+
from sqlmesh.core._typing import SchemaName, TableName
3031
from sqlmesh.core.engine_adapter._typing import DF, Query
3132
from sqlmesh.core.engine_adapter.base import QueryOrDF
3233

@@ -44,6 +45,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin):
4445
SUPPORTS_TRANSACTIONS = False
4546
SUPPORTS_MATERIALIZED_VIEWS = True
4647
SUPPORTS_CLONING = True
48+
CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT
4749

4850
# SQL is not supported for adding columns to structs: https://cloud.google.com/bigquery/docs/managing-table-schemas#api_1
4951
# Can explore doing this with the API in the future
@@ -131,10 +133,17 @@ def _end_session(self) -> None:
131133
def _is_session_active(self) -> bool:
132134
return self._session_id is not None
133135

136+
def get_current_catalog(self) -> t.Optional[str]:
    """Returns the catalog name of the current connection.

    For this adapter the catalog maps to the client's project.
    """
    return self.client.project
139+
140+
def set_current_catalog(self, catalog: str) -> None:
    """Sets the catalog name of the current connection.

    For this adapter the catalog maps to the client's project.
    """
    self.client.project = catalog
143+
134144
def create_schema(
135145
self,
136-
schema_name: str,
137-
catalog_name: t.Optional[str] = None,
146+
schema_name: SchemaName,
138147
ignore_if_exists: bool = True,
139148
warn_on_error: bool = True,
140149
) -> None:
@@ -144,7 +153,6 @@ def create_schema(
144153
try:
145154
super().create_schema(
146155
schema_name,
147-
catalog_name=catalog_name,
148156
ignore_if_exists=ignore_if_exists,
149157
warn_on_error=False,
150158
)
@@ -535,18 +543,16 @@ def execute(
535543
self.cursor._set_rowcount(query_results)
536544
self.cursor._set_description(query_results.schema)
537545

538-
def _get_data_objects(
539-
self, schema_name: str, catalog_name: t.Optional[str] = None
540-
) -> t.List[DataObject]:
546+
def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
541547
"""
542548
Returns all the data objects that exist in the given schema and optionally catalog.
543549
"""
544550
from google.api_core.exceptions import NotFound
545551
from google.cloud.bigquery import DatasetReference
546552

547-
dataset_ref = DatasetReference(
548-
project=catalog_name or self.client.project, dataset_id=schema_name
549-
)
553+
schema = to_schema(schema_name)
554+
catalog_name = schema.catalog or self.get_current_catalog()
555+
dataset_ref = DatasetReference(project=catalog_name, dataset_id=schema.db)
550556
try:
551557
return [
552558
DataObject(

0 commit comments

Comments (0)