Skip to content

Commit 70fd198

Browse files
izeigermantobymao
andauthored
Feat: Cache the optimized model query (#975)
* Feat: Cache the optimized model query * fix failing tests * disable gc when loading * add optimized query test --------- Co-authored-by: tobymao <toby.mao@gmail.com>
1 parent c49c3db commit 70fd198

File tree

13 files changed

+282
-146
lines changed

13 files changed

+282
-146
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"requests",
4545
"rich",
4646
"ruamel.yaml",
47-
"sqlglot~=16.0.0",
47+
"sqlglot~=16.1.0",
4848
"fsspec",
4949
],
5050
extras_require={

sqlmesh/core/context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import abc
3636
import contextlib
37+
import gc
3738
import typing as t
3839
import unittest.result
3940
from io import StringIO
@@ -291,7 +292,7 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
291292

292293
self._models.update({model.name: model})
293294
self.dag.add(model.name, model.depends_on)
294-
update_model_schemas(self.dag, self._models)
295+
update_model_schemas(self.dag, self._models, self.path)
295296

296297
return model
297298

@@ -356,11 +357,13 @@ def refresh(self) -> None:
356357
def load(self, update_schemas: bool = True) -> Context:
357358
"""Load all files in the context's path."""
358359
with sys_path(*self.configs):
360+
gc.disable()
359361
project = self._loader.load(self, update_schemas)
360362
self._macros = project.macros
361363
self._models = project.models
362364
self._audits = project.audits
363365
self.dag = project.dag
366+
gc.enable()
364367

365368
return self
366369

sqlmesh/core/loader.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sqlmesh.core.model import (
2323
Model,
2424
ModelCache,
25+
OptimizedQueryCache,
2526
SeedModel,
2627
create_external_model,
2728
load_model,
@@ -38,8 +39,11 @@
3839
from sqlmesh.core.context import Context
3940

4041

41-
def update_model_schemas(dag: DAG[str], models: UniqueKeyDict[str, Model]) -> None:
42+
def update_model_schemas(
43+
dag: DAG[str], models: UniqueKeyDict[str, Model], context_path: Path
44+
) -> None:
4245
schema = MappingSchema(normalize=False)
46+
optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
4347

4448
for name in dag.sorted:
4549
model = models.get(name)
@@ -49,18 +53,18 @@ def update_model_schemas(dag: DAG[str], models: UniqueKeyDict[str, Model]) -> No
4953
continue
5054

5155
model.update_schema(schema)
56+
cache_hit = optimized_query_cache.with_optimized_query(model)
5257
schema.add_table(name, model.columns_to_types, dialect=model.dialect)
5358

54-
external = any(dep not in models for dep in model.depends_on)
55-
if external:
59+
if any(dep not in models for dep in model.depends_on):
5660
if "*" in model.columns_to_types:
5761
raise ConfigError(
5862
f"Can't expand SELECT * expression for model '{name}' at '{model._path}'."
5963
" Either specify external source projections expliticly or"
6064
' add source tables as "external models" using the command'
6165
" 'sqlmesh create_external_models'."
6266
)
63-
elif model.mapping_schema:
67+
elif model.mapping_schema and not cache_hit:
6468
try:
6569
validate_qualify_columns(model.render_query())
6670
except SqlglotError as e:
@@ -120,7 +124,7 @@ def load(self, context: Context, update_schemas: bool = True) -> LoadedProject:
120124
self._add_model_to_dag(model)
121125

122126
if update_schemas:
123-
update_model_schemas(self._dag, models)
127+
update_model_schemas(self._dag, models, self._context.path)
124128

125129
audits = self._load_audits()
126130

sqlmesh/core/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlmesh.core.model.cache import ModelCache
1+
from sqlmesh.core.model.cache import ModelCache, OptimizedQueryCache
22
from sqlmesh.core.model.common import parse_model_name
33
from sqlmesh.core.model.decorator import model
44
from sqlmesh.core.model.definition import (

sqlmesh/core/model/cache.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from sqlmesh.core.model.definition import Model, SqlModel
99
from sqlmesh.utils.cache import FileCache
10+
from sqlmesh.utils.hashing import crc32
1011
from sqlmesh.utils.pydantic import PydanticModel
1112

1213

@@ -44,14 +45,73 @@ def get_or_load(self, name: str, entry_id: str, loader: t.Callable[[], Model]) -
4445
cache_entry = self._file_cache.get(name, entry_id)
4546
if cache_entry:
4647
model = cache_entry.model
47-
model._query_renderer.update_cache(cache_entry.rendered_query)
48+
model._query_renderer.update_cache(cache_entry.rendered_query, optimized=False)
4849
return model
4950

5051
loaded_model = loader()
5152
if isinstance(loaded_model, SqlModel):
5253
new_entry = SqlModelCacheEntry(
53-
model=loaded_model, rendered_query=loaded_model.render_query()
54+
model=loaded_model, rendered_query=loaded_model.render_query(optimize=False)
5455
)
5556
self._file_cache.put(name, entry_id, new_entry)
5657

5758
return loaded_model
59+
60+
61+
class OptimizedQueryCacheEntry(PydanticModel):
62+
optimized_rendered_query: exp.Expression
63+
64+
65+
class OptimizedQueryCache:
66+
"""File-based cache implementation for optimized model queries.
67+
68+
Args:
69+
path: The path to the cache folder.
70+
"""
71+
72+
def __init__(self, path: Path):
73+
self.path = path
74+
self._file_cache: FileCache[OptimizedQueryCacheEntry] = FileCache(
75+
path, OptimizedQueryCacheEntry, prefix="optimized_query"
76+
)
77+
78+
def with_optimized_query(self, model: Model) -> bool:
79+
"""Adds an optimized query to the model's in-memory cache.
80+
81+
Args:
82+
model: The model to add the optimized query to.
83+
"""
84+
if not isinstance(model, SqlModel):
85+
return True
86+
87+
entry_id = self._entry_id(model)
88+
cache_entry = self._file_cache.get(model.name, entry_id)
89+
if cache_entry:
90+
model._query_renderer.update_cache(cache_entry.optimized_rendered_query, optimized=True)
91+
return True
92+
93+
new_entry = OptimizedQueryCacheEntry(
94+
optimized_rendered_query=model.render_query(optimize=True)
95+
)
96+
self._file_cache.put(model.name, entry_id, new_entry)
97+
return False
98+
99+
@staticmethod
100+
def _entry_id(model: SqlModel) -> str:
101+
data = OptimizedQueryCache._mapping_schema_hash_data(model.mapping_schema)
102+
data.append(model.render_query(optimize=False).sql())
103+
return crc32(data)
104+
105+
@staticmethod
106+
def _mapping_schema_hash_data(schema: t.Dict[str, t.Any]) -> t.List[str]:
107+
keys = sorted(schema) if all(isinstance(v, dict) for v in schema.values()) else schema
108+
109+
data = []
110+
for k in keys:
111+
data.append(k)
112+
if isinstance(schema[k], dict):
113+
data.extend(OptimizedQueryCache._mapping_schema_hash_data(schema[k]))
114+
else:
115+
data.append(str(schema[k]))
116+
117+
return data

sqlmesh/core/model/definition.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def depends_on(self) -> t.Set[str]:
433433
return self.depends_on_ - {self.name}
434434

435435
if self._depends_on is None:
436-
self._depends_on = _find_tables(self._render_all_sql()) - {self.name}
436+
self._depends_on = _find_tables(self.render_query(optimize=False)) - {self.name}
437437
return self._depends_on
438438

439439
@property
@@ -481,7 +481,7 @@ def depends_on_past(self) -> bool:
481481
if self._depends_on_past is None:
482482
self._depends_on_past = (
483483
self.kind.is_incremental_by_unique_key
484-
or self.name in _find_tables([self.render_query()])
484+
or self.name in _find_tables(self.render_query(optimize=False))
485485
)
486486
return self._depends_on_past
487487

@@ -529,14 +529,6 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
529529
"""
530530
raise NotImplementedError
531531

532-
def _render_all_sql(self) -> t.List[exp.Expression]:
533-
"""Renders all the SQL expressions of this model."""
534-
return [
535-
*self.render_pre_statements(),
536-
self.render_query(),
537-
*self.render_post_statements(),
538-
]
539-
540532

541533
class _SqlBasedModel(_Model):
542534
pre_statements_: t.Optional[t.List[exp.Expression]] = Field(
@@ -713,17 +705,18 @@ def column_descriptions(self) -> t.Dict[str, str]:
713705
if self._column_descriptions is None:
714706
self._column_descriptions = {
715707
select.alias: "\n".join(comment.strip() for comment in select.comments)
716-
for select in self.render_query().selects
708+
for select in self.render_query(optimize=False).selects
717709
if select.comments
718710
}
719711
return self._column_descriptions
720712

721713
def update_schema(self, schema: MappingSchema) -> None:
722714
super().update_schema(schema)
723715
self._columns_to_types = None
716+
self._query_renderer._optimized_cache = {}
724717

725718
def validate_definition(self) -> None:
726-
query = self._query_renderer.render()
719+
query = self._query_renderer.render(optimize=False)
727720

728721
if not isinstance(query, exp.Subqueryable):
729722
raise_config_error("Missing SELECT query in the model definition", self._path)
@@ -1444,7 +1437,7 @@ def _validate_model_fields(klass: t.Type[_Model], provided_fields: t.Set[str], p
14441437
raise_config_error(f"Invalid extra fields {extra_fields} in the model definition", path)
14451438

14461439

1447-
def _find_tables(expressions: t.List[exp.Expression]) -> t.Set[str]:
1440+
def _find_tables(expression: exp.Expression) -> t.Set[str]:
14481441
"""Find all tables referenced in a query.
14491442
14501443
Args:
@@ -1455,8 +1448,7 @@ def _find_tables(expressions: t.List[exp.Expression]) -> t.Set[str]:
14551448
"""
14561449
return {
14571450
exp.table_name(table)
1458-
for e in expressions
1459-
for scope in traverse_scope(e)
1451+
for scope in traverse_scope(expression)
14601452
for table in scope.tables
14611453
if isinstance(table.this, exp.Identifier) and exp.table_name(table) not in scope.cte_sources
14621454
}

0 commit comments

Comments
 (0)