Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
self.correlation_id = correlation_id
self._schema_differ_overrides = schema_differ_overrides
self._query_execution_tracker = query_execution_tracker
self._data_object_cache: t.Dict[str, t.Optional[DataObject]] = {}

def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
extra_kwargs = {
Expand Down Expand Up @@ -983,6 +984,13 @@ def _create_table(
),
track_rows_processed=track_rows_processed,
)
# Extract table name to clear cache
table_name = (
table_name_or_schema.this
if isinstance(table_name_or_schema, exp.Schema)
else table_name_or_schema
)
self._clear_data_object_cache(table_name)

def _build_create_table_exp(
self,
Expand Down Expand Up @@ -1074,6 +1082,7 @@ def clone_table(
**kwargs,
)
)
self._clear_data_object_cache(target_table_name)

def drop_data_object(self, data_object: DataObject, ignore_if_not_exists: bool = True) -> None:
"""Drops a data object of arbitrary type.
Expand Down Expand Up @@ -1139,6 +1148,7 @@ def _drop_object(
drop_args["cascade"] = cascade

self.execute(exp.Drop(this=exp.to_table(name), kind=kind, exists=exists, **drop_args))
self._clear_data_object_cache(name)

def get_alter_operations(
self,
Expand Down Expand Up @@ -1329,6 +1339,8 @@ def create_view(
quote_identifiers=self.QUOTE_IDENTIFIERS_IN_VIEWS,
)

self._clear_data_object_cache(view_name)

# Register table comment with commands if the engine doesn't support doing it in CREATE
if (
table_description
Expand Down Expand Up @@ -2278,14 +2290,52 @@ def get_data_objects(
if object_names is not None:
if not object_names:
return []
object_names_list = list(object_names)
batches = [
object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE]
for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE)
]
return [
obj for batch in batches for obj in self._get_data_objects(schema_name, set(batch))
]

# Check cache for each object name
target_schema = to_schema(schema_name)
cached_objects = []
missing_names = set()

for name in object_names:
cache_key = _get_data_object_cache_key(
target_schema.catalog, target_schema.db, name
)
if cache_key in self._data_object_cache:
data_object = self._data_object_cache[cache_key]
# If the object is none, then the table was previously looked for but not found
if data_object:
cached_objects.append(data_object)
else:
missing_names.add(name)

# Fetch missing objects from database
if missing_names:
object_names_list = list(missing_names)
batches = [
object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE]
for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE)
]
fetched_objects = [
obj
for batch in batches
for obj in self._get_data_objects(schema_name, set(batch))
]

# Cache the fetched objects
for obj in fetched_objects:
cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name)
self._data_object_cache[cache_key] = obj

fetched_object_names = {obj.name for obj in fetched_objects}
for missing_name in missing_names - fetched_object_names:
cache_key = _get_data_object_cache_key(
target_schema.catalog, target_schema.db, missing_name
)
self._data_object_cache[cache_key] = None

return cached_objects + fetched_objects

return cached_objects
return self._get_data_objects(schema_name)

def fetchone(
Expand Down Expand Up @@ -2693,6 +2743,15 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An

return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore

def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None:
    """Invalidate cached data object metadata.

    Args:
        table_name: The table whose cache entry should be removed. When omitted,
            the entire cache is dropped instead.
    """
    if table_name is None:
        self._data_object_cache.clear()
        return
    target = exp.to_table(table_name)
    key = _get_data_object_cache_key(target.catalog, target.db, target.name)
    self._data_object_cache.pop(key, None)

def _get_data_objects(
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
) -> t.List[DataObject]:
Expand Down Expand Up @@ -2940,3 +2999,11 @@ def _decoded_str(value: t.Union[str, bytes]) -> str:
if isinstance(value, bytes):
return value.decode("utf-8")
return value


def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str:
"""Returns a cache key for a data object based on its fully qualified name."""
catalog_part = catalog.lower() if catalog else ""
schema_part = schema_name.lower()
object_part = object_name.lower()
return f"{catalog_part}.{schema_part}.{object_part}"
67 changes: 62 additions & 5 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def promote(
]
self._create_schemas(gateway_table_pairs=gateway_table_pairs)

# Fetch the view data objects for the promoted snapshots to get them cached
self._get_virtual_data_objects(target_snapshots, environment_naming_info)

deployability_index = deployability_index or DeployabilityIndex.all_deployable()
with self.concurrent_context():
concurrent_apply_to_snapshots(
Expand Down Expand Up @@ -425,7 +428,9 @@ def get_snapshots_to_create(
target_snapshots: Target snapshots.
deployability_index: Determines snapshots that are deployable / representative in the context of this creation.
"""
existing_data_objects = self._get_data_objects(target_snapshots, deployability_index)
existing_data_objects = self._get_physical_data_objects(
target_snapshots, deployability_index
)
snapshots_to_create = []
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
Expand Down Expand Up @@ -482,7 +487,7 @@ def migrate(
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
"""
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
target_data_objects = self._get_data_objects(target_snapshots, deployability_index)
target_data_objects = self._get_physical_data_objects(target_snapshots, deployability_index)
if not target_data_objects:
return

Expand Down Expand Up @@ -1472,7 +1477,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
and adapter.table_exists(snapshot.table_name())
)

def _get_data_objects(
def _get_physical_data_objects(
self,
target_snapshots: t.Iterable[Snapshot],
deployability_index: DeployabilityIndex,
Expand All @@ -1488,15 +1493,67 @@ def _get_data_objects(
A dictionary of snapshot IDs to existing data objects of their physical tables. If the data object
for a snapshot is not found, it will not be included in the dictionary.
"""
return self._get_data_objects(
target_snapshots,
lambda s: exp.to_table(
s.table_name(deployability_index.is_deployable(s)), dialect=s.model.dialect
),
)

def _get_virtual_data_objects(
    self,
    target_snapshots: t.Iterable[Snapshot],
    environment_naming_info: EnvironmentNamingInfo,
) -> t.Dict[SnapshotId, DataObject]:
    """Returns a dictionary of snapshot IDs to existing data objects of their virtual views.

    Args:
        target_snapshots: Target snapshots.
        environment_naming_info: The environment naming info of the target virtual environment.

    Returns:
        A dictionary of snapshot IDs to existing data objects of their virtual views.
        Snapshots whose view data object is not found are omitted.
    """

    def _view_table(snapshot: Snapshot) -> exp.Table:
        # Gateway-managed environments resolve views through the snapshot's own
        # gateway adapter; otherwise the default adapter's dialect applies.
        if environment_naming_info.gateway_managed:
            adapter = self.get_adapter(snapshot.model_gateway)
        else:
            adapter = self.adapter
        view_name = snapshot.qualified_view_name.for_environment(
            environment_naming_info, dialect=adapter.dialect
        )
        return exp.to_table(view_name, dialect=adapter.dialect)

    return self._get_data_objects(target_snapshots, _view_table)

def _get_data_objects(
self,
target_snapshots: t.Iterable[Snapshot],
table_name_callable: t.Callable[[Snapshot], exp.Table],
) -> t.Dict[SnapshotId, DataObject]:
"""Returns a dictionary of snapshot IDs to existing data objects.

Args:
target_snapshots: Target snapshots.
table_name_callable: A function that takes a snapshot and returns the table to look for.

Returns:
A dictionary of snapshot IDs to existing data objects. If the data object for a snapshot is not found,
it will not be included in the dictionary.
"""
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
defaultdict(lambda: defaultdict(set))
)
snapshots_by_table_name: t.Dict[str, Snapshot] = {}
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
continue
is_deployable = deployability_index.is_deployable(snapshot)
table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect)
table = table_name_callable(snapshot)
table_schema = d.schema_(table.db, catalog=table.catalog)
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
snapshots_by_table_name[table.name] = snapshot
Expand Down
1 change: 1 addition & 0 deletions tests/core/engine_adapter/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture):
)
mocker.patch.object(adapter, "_get_data_objects", return_value=[])
adapter.cursor.execute.reset_mock()
adapter._clear_data_object_cache()

adapter.s3_warehouse_location = "s3://foo"
adapter.replace_query(
Expand Down
Loading