Skip to content

Commit 8b0e3cf

Browse files
authored
Feat: Support expressions when selecting models for restatement (#1486)
1 parent 1fcf609 commit 8b0e3cf

2 files changed

Lines changed: 18 additions & 5 deletions

File tree

sqlmesh/core/context.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -761,11 +761,16 @@ def plan(
761761
if include_unmodified is None:
762762
include_unmodified = self.config.include_unmodified
763763

764+
model_selector = Selector(self.state_reader, self._models, self.path, dag=self.dag)
765+
764766
models_override: t.Optional[UniqueKeyDict[str, Model]] = None
765767
if select_models:
766-
models_override = Selector(
767-
self.state_reader, self._models, self.path, dag=self.dag
768-
).select_models(select_models, environment, fallback_env_name=create_from or c.PROD)
768+
models_override = model_selector.select_models(
769+
select_models, environment, fallback_env_name=create_from or c.PROD
770+
)
771+
772+
if restate_models is not None:
773+
restate_models = model_selector.expand_model_selections(restate_models)
769774

770775
# If no end date is specified, use the max interval end from prod
771776
# to prevent unintended evaluation of the entire DAG.

sqlmesh/core/selector.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def select_models(
6767
).values()
6868
}
6969

70-
all_selected_models = self._expand_model_selections(model_selections)
70+
all_selected_models = self.expand_model_selections(model_selections)
7171

7272
dag: DAG[str] = DAG()
7373
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -91,7 +91,15 @@ def select_models(
9191

9292
return models
9393

94-
def _expand_model_selections(self, model_selections: t.Iterable[str]) -> t.Set[str]:
94+
def expand_model_selections(self, model_selections: t.Iterable[str]) -> t.Set[str]:
95+
"""Expands a set of model selections into a set of model names.
96+
97+
Args:
98+
model_selections: A set of model selections.
99+
100+
Returns:
101+
A set of model names.
102+
"""
95103
result: t.Set[str] = set()
96104

97105
def _add_model(model_name: str, include_upstream: bool, include_downstream: bool) -> None:

0 commit comments

Comments
 (0)