Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 89 additions & 10 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
self.correlation_id = correlation_id
self._schema_differ_overrides = schema_differ_overrides
self._query_execution_tracker = query_execution_tracker
self._data_object_cache: t.Dict[str, t.Optional[DataObject]] = {}

def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
extra_kwargs = {
Expand Down Expand Up @@ -983,6 +984,13 @@ def _create_table(
),
track_rows_processed=track_rows_processed,
)
# Extract table name to clear cache
table_name = (
table_name_or_schema.this
if isinstance(table_name_or_schema, exp.Schema)
else table_name_or_schema
)
self._clear_data_object_cache(table_name)

def _build_create_table_exp(
self,
Expand Down Expand Up @@ -1074,6 +1082,7 @@ def clone_table(
**kwargs,
)
)
self._clear_data_object_cache(target_table_name)

def drop_data_object(self, data_object: DataObject, ignore_if_not_exists: bool = True) -> None:
"""Drops a data object of arbitrary type.
Expand Down Expand Up @@ -1139,6 +1148,7 @@ def _drop_object(
drop_args["cascade"] = cascade

self.execute(exp.Drop(this=exp.to_table(name), kind=kind, exists=exists, **drop_args))
self._clear_data_object_cache(name)

def get_alter_operations(
self,
Expand Down Expand Up @@ -1329,6 +1339,8 @@ def create_view(
quote_identifiers=self.QUOTE_IDENTIFIERS_IN_VIEWS,
)

self._clear_data_object_cache(view_name)

# Register table comment with commands if the engine doesn't support doing it in CREATE
if (
table_description
Expand Down Expand Up @@ -1458,8 +1470,14 @@ def columns(
}

def table_exists(self, table_name: TableName) -> bool:
    """Returns True if the given table exists in the target engine.

    The in-memory data object cache is consulted first: a cached ``None`` entry
    means the object was previously looked up and found to be missing, so no
    round trip to the engine is needed. On a cache miss, existence is probed
    with a ``DESCRIBE TABLE`` statement; any execution failure is treated as
    "does not exist".
    """
    table = exp.to_table(table_name)
    data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
    try:
        cached = self._data_object_cache[data_object_cache_key]
    except KeyError:
        pass
    else:
        logger.debug("Table existence cache hit: %s", data_object_cache_key)
        return cached is not None

    try:
        self.execute(exp.Describe(this=table, kind="TABLE"))
    except Exception:
        return False
    return True
Expand Down Expand Up @@ -2278,15 +2296,59 @@ def get_data_objects(
if object_names is not None:
if not object_names:
return []
object_names_list = list(object_names)
batches = [
object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE]
for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE)
]
return [
obj for batch in batches for obj in self._get_data_objects(schema_name, set(batch))
]
return self._get_data_objects(schema_name)

# Check cache for each object name
target_schema = to_schema(schema_name)
cached_objects = []
missing_names = set()

for name in object_names:
cache_key = _get_data_object_cache_key(
target_schema.catalog, target_schema.db, name
)
if cache_key in self._data_object_cache:
logger.debug("Data object cache hit: %s", cache_key)
data_object = self._data_object_cache[cache_key]
# If the cached value is None, the table was previously looked up but not found
if data_object:
cached_objects.append(data_object)
else:
logger.debug("Data object cache miss: %s", cache_key)
missing_names.add(name)

# Fetch missing objects from database
if missing_names:
object_names_list = list(missing_names)
batches = [
object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE]
for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE)
]
fetched_objects = [
obj
for batch in batches
for obj in self._get_data_objects(schema_name, set(batch))
]

for obj in fetched_objects:
cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name)
self._data_object_cache[cache_key] = obj

fetched_object_names = {obj.name for obj in fetched_objects}
for missing_name in missing_names - fetched_object_names:
cache_key = _get_data_object_cache_key(
target_schema.catalog, target_schema.db, missing_name
)
self._data_object_cache[cache_key] = None

return cached_objects + fetched_objects

return cached_objects

fetched_objects = self._get_data_objects(schema_name)
for obj in fetched_objects:
cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name)
self._data_object_cache[cache_key] = obj
return fetched_objects

def fetchone(
self,
Expand Down Expand Up @@ -2693,6 +2755,17 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An

return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore

def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None:
    """Clears the cache entry for the given table name, or clears the entire cache if table_name is None."""
    if table_name is not None:
        # Normalize to a table expression so the key matches how entries were stored.
        parsed = exp.to_table(table_name)
        cache_key = _get_data_object_cache_key(parsed.catalog, parsed.db, parsed.name)
        logger.debug("Clearing data object cache key: %s", cache_key)
        self._data_object_cache.pop(cache_key, None)
    else:
        logger.debug("Clearing entire data object cache")
        self._data_object_cache.clear()

def _get_data_objects(
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
) -> t.List[DataObject]:
Expand Down Expand Up @@ -2940,3 +3013,9 @@ def _decoded_str(value: t.Union[str, bytes]) -> str:
if isinstance(value, bytes):
return value.decode("utf-8")
return value


def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str:
"""Returns a cache key for a data object based on its fully qualified name."""
catalog = catalog or ""
return f"{catalog}.{schema_name}.{object_name}"
10 changes: 9 additions & 1 deletion sqlmesh/core/engine_adapter/base_postgres.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import typing as t
import logging

from sqlglot import exp

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.engine_adapter.base import EngineAdapter, _get_data_object_cache_key
from sqlmesh.core.engine_adapter.shared import (
CatalogSupport,
CommentCreationTable,
Expand All @@ -20,6 +21,9 @@
from sqlmesh.core.engine_adapter._typing import QueryOrDF


logger = logging.getLogger(__name__)


class BasePostgresEngineAdapter(EngineAdapter):
DEFAULT_BATCH_SIZE = 400
COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
Expand Down Expand Up @@ -75,6 +79,10 @@ def table_exists(self, table_name: TableName) -> bool:
Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553
"""
table = exp.to_table(table_name)
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
if data_object_cache_key in self._data_object_cache:
logger.debug("Table existence cache hit: %s", data_object_cache_key)
return self._data_object_cache[data_object_cache_key] is not None

sql = (
exp.select("1")
Expand Down
7 changes: 7 additions & 0 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlglot.transforms import remove_precision_parameterized_types

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key
from sqlmesh.core.engine_adapter.mixins import (
ClusteredByMixin,
RowDiffMixin,
Expand Down Expand Up @@ -744,6 +745,12 @@ def insert_overwrite_by_partition(
)

def table_exists(self, table_name: TableName) -> bool:
table = exp.to_table(table_name)
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
if data_object_cache_key in self._data_object_cache:
logger.debug("Table existence cache hit: %s", data_object_cache_key)
return self._data_object_cache[data_object_cache_key] is not None

try:
from google.cloud.exceptions import NotFound
except ModuleNotFoundError:
Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/core/engine_adapter/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import typing as t
import logging

from sqlglot import exp

Expand All @@ -13,6 +14,7 @@
InsertOverwriteStrategy,
MERGE_SOURCE_ALIAS,
MERGE_TARGET_ALIAS,
_get_data_object_cache_key,
)
from sqlmesh.core.engine_adapter.mixins import (
GetCurrentCatalogFromFunctionMixin,
Expand All @@ -36,6 +38,9 @@
from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF


logger = logging.getLogger(__name__)


@set_catalog()
class MSSQLEngineAdapter(
EngineAdapterWithIndexSupport,
Expand Down Expand Up @@ -144,6 +149,10 @@ def build_var_length_col(
def table_exists(self, table_name: TableName) -> bool:
"""MsSql doesn't support describe so we query information_schema."""
table = exp.to_table(table_name)
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
if data_object_cache_key in self._data_object_cache:
logger.debug("Table existence cache hit: %s", data_object_cache_key)
return self._data_object_cache[data_object_cache_key] is not None

sql = (
exp.select("1")
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PostgresEngineAdapter(
HAS_VIEW_BINDING = True
CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
SUPPORTS_REPLACE_TABLE = False
MAX_IDENTIFIER_LENGTH = 63
MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
Expand Down
67 changes: 62 additions & 5 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def promote(
]
self._create_schemas(gateway_table_pairs=gateway_table_pairs)

# Fetch the view data objects for the promoted snapshots to get them cached
self._get_virtual_data_objects(target_snapshots, environment_naming_info)

deployability_index = deployability_index or DeployabilityIndex.all_deployable()
with self.concurrent_context():
concurrent_apply_to_snapshots(
Expand Down Expand Up @@ -425,7 +428,9 @@ def get_snapshots_to_create(
target_snapshots: Target snapshots.
deployability_index: Determines snapshots that are deployable / representative in the context of this creation.
"""
existing_data_objects = self._get_data_objects(target_snapshots, deployability_index)
existing_data_objects = self._get_physical_data_objects(
target_snapshots, deployability_index
)
snapshots_to_create = []
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
Expand Down Expand Up @@ -482,7 +487,7 @@ def migrate(
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
"""
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
target_data_objects = self._get_data_objects(target_snapshots, deployability_index)
target_data_objects = self._get_physical_data_objects(target_snapshots, deployability_index)
if not target_data_objects:
return

Expand Down Expand Up @@ -1472,7 +1477,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
and adapter.table_exists(snapshot.table_name())
)

def _get_data_objects(
def _get_physical_data_objects(
self,
target_snapshots: t.Iterable[Snapshot],
deployability_index: DeployabilityIndex,
Expand All @@ -1488,15 +1493,67 @@ def _get_data_objects(
A dictionary of snapshot IDs to existing data objects of their physical tables. If the data object
for a snapshot is not found, it will not be included in the dictionary.
"""
return self._get_data_objects(
target_snapshots,
lambda s: exp.to_table(
s.table_name(deployability_index.is_deployable(s)), dialect=s.model.dialect
),
)

def _get_virtual_data_objects(
    self,
    target_snapshots: t.Iterable[Snapshot],
    environment_naming_info: EnvironmentNamingInfo,
) -> t.Dict[SnapshotId, DataObject]:
    """Returns a dictionary of snapshot IDs to existing data objects of their virtual views.

    Args:
        target_snapshots: Target snapshots.
        environment_naming_info: The environment naming info of the target virtual environment.

    Returns:
        A dictionary of snapshot IDs to existing data objects of their virtual views. If the data object
        for a snapshot is not found, it will not be included in the dictionary.
    """
    gateway_managed = environment_naming_info.gateway_managed

    def _view_table(snapshot: Snapshot) -> exp.Table:
        # When the environment is gateway managed, the view lives on the
        # snapshot's own gateway; otherwise the default adapter is used.
        adapter = self.get_adapter(snapshot.model_gateway) if gateway_managed else self.adapter
        view_name = snapshot.qualified_view_name.for_environment(
            environment_naming_info, dialect=adapter.dialect
        )
        return exp.to_table(view_name, dialect=adapter.dialect)

    return self._get_data_objects(target_snapshots, _view_table)

def _get_data_objects(
self,
target_snapshots: t.Iterable[Snapshot],
table_name_callable: t.Callable[[Snapshot], exp.Table],
) -> t.Dict[SnapshotId, DataObject]:
"""Returns a dictionary of snapshot IDs to existing data objects.

Args:
target_snapshots: Target snapshots.
table_name_callable: A function that takes a snapshot and returns the table to look for.

Returns:
A dictionary of snapshot IDs to existing data objects. If the data object for a snapshot is not found,
it will not be included in the dictionary.
"""
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
defaultdict(lambda: defaultdict(set))
)
snapshots_by_table_name: t.Dict[str, Snapshot] = {}
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
continue
is_deployable = deployability_index.is_deployable(snapshot)
table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect)
table = table_name_callable(snapshot)
table_schema = d.schema_(table.db, catalog=table.catalog)
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
snapshots_by_table_name[table.name] = snapshot
Expand Down
1 change: 1 addition & 0 deletions tests/core/engine_adapter/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture):
)
mocker.patch.object(adapter, "_get_data_objects", return_value=[])
adapter.cursor.execute.reset_mock()
adapter._clear_data_object_cache()

adapter.s3_warehouse_location = "s3://foo"
adapter.replace_query(
Expand Down
Loading