Skip to content

Commit 7fd6b7d

Browse files
authored
Feat: make strict dependency resolution optional in python models (#3654)
1 parent ab6bc17 commit 7fd6b7d

4 files changed

Lines changed: 96 additions & 17 deletions

File tree

docs/concepts/models/python_models.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def execute(
200200
Table resolution for these statements occurs at the virtual layer. This means that table names, including the `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.test_model` and `@this_model` would resolve to `db__dev.test_model` and not to the physical table name.
201201

202202
## Dependencies
203+
203204
In order to fetch data from an upstream model, you first get the table name using `context`'s `resolve_table` method. This returns the appropriate table name for the current runtime [environment](../environments.md):
204205

205206
```python linenums="1"
@@ -230,6 +231,21 @@ def execute(
230231
context.resolve_table("docs_example.another_dependency")
231232
```
232233

234+
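User-defined [global variables](global-variables) can also be used in `resolve_table` calls, as long as the `depends_on` keyword argument is present and contains the required dependencies. Because the value of such an argument cannot be resolved when the model is parsed, the dependency cannot be inferred from the code and must be declared explicitly. This is shown in the following example: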
235+
236+
```python linenums="1"
237+
@model(
238+
"@schema_name.test_model2",
239+
kind="FULL",
240+
columns={"id": "INT"},
241+
depends_on=["@schema_name.test_model1"],
242+
)
243+
def execute(context, **kwargs):
244+
schema_name = context.var("schema_name")
245+
table = context.resolve_table(f"{schema_name}.test_model1")
246+
select_query = exp.select("*").from_(table)
247+
return context.fetchdf(select_query)
248+
```
233249

234250
## Returning empty dataframes
235251

sqlmesh/core/model/common.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def make_python_env(
2929
used_variables: t.Optional[t.Set[str]] = None,
3030
path: t.Optional[str | Path] = None,
3131
python_env: t.Optional[t.Dict[str, Executable]] = None,
32+
strict_resolution: bool = True,
3233
) -> t.Dict[str, Executable]:
3334
python_env = {} if python_env is None else python_env
3435
variables = variables or {}
@@ -81,15 +82,25 @@ def make_python_env(
8182
build_env(used_macro.func, env=env, name=name, path=module_path)
8283

8384
python_env.update(serialize_env(env, path=module_path))
84-
return _add_variables_to_python_env(python_env, used_variables, variables)
85+
return _add_variables_to_python_env(
86+
python_env,
87+
used_variables,
88+
variables,
89+
strict_resolution=strict_resolution,
90+
)
8591

8692

8793
def _add_variables_to_python_env(
8894
python_env: t.Dict[str, Executable],
8995
used_variables: t.Optional[t.Set[str]],
9096
variables: t.Optional[t.Dict[str, t.Any]],
97+
strict_resolution: bool = True,
9198
) -> t.Dict[str, Executable]:
92-
_, python_used_variables = parse_dependencies(python_env, None)
99+
_, python_used_variables = parse_dependencies(
100+
python_env,
101+
None,
102+
strict_resolution=strict_resolution,
103+
)
93104
used_variables = (used_variables or set()) | python_used_variables
94105

95106
variables = {k: v for k, v in (variables or {}).items() if k in used_variables}
@@ -100,12 +111,17 @@ def _add_variables_to_python_env(
100111

101112

102113
def parse_dependencies(
103-
python_env: t.Dict[str, Executable], entrypoint: t.Optional[str]
114+
python_env: t.Dict[str, Executable], entrypoint: t.Optional[str], strict_resolution: bool = True
104115
) -> t.Tuple[t.Set[str], t.Set[str]]:
105-
"""Parses the source of a model function and finds upstream table dependencies and referenced variables based on calls to context / evaluator.
116+
"""
117+
Parses the source of a model function and finds upstream table dependencies
118+
and referenced variables based on calls to context / evaluator.
106119
107120
Args:
108121
python_env: A dictionary of Python definitions.
122+
entrypoint: The name of the function.
123+
strict_resolution: If true, the arguments of `table` and `resolve_table` calls must
124+
be resolvable at parse time, otherwise an exception will be raised.
109125
110126
Returns:
111127
A tuple containing the set of upstream table dependencies and the set of referenced variables.
@@ -140,9 +156,11 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
140156
expression = to_source(first_arg)
141157
return eval(expression, env)
142158
except Exception:
143-
raise ConfigError(
144-
f"Error resolving dependencies for '{executable.path}'. Argument '{expression.strip()}' must be resolvable at parse time."
145-
)
159+
if strict_resolution:
160+
raise ConfigError(
161+
f"Error resolving dependencies for '{executable.path}'. "
162+
f"Argument '{expression.strip()}' must be resolvable at parse time."
163+
)
146164

147165
if func.value.id == "context" and func.attr in ("table", "resolve_table"):
148166
depends_on.add(get_first_arg("model_name"))

sqlmesh/core/model/definition.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,28 +1971,48 @@ def create_python_model(
19711971
python_env: The Python environment of all objects referenced by the model implementation.
19721972
path: An optional path to the model definition file.
19731973
depends_on: The custom set of model's upstream dependencies.
1974+
variables: The variables to pass to the model.
19741975
"""
19751976
# Find dependencies for python models by parsing code if they are not explicitly defined
19761977
# Also remove self-references that are found
19771978

19781979
dialect = kwargs.get("dialect")
1980+
renderer_kwargs = {
1981+
"module_path": module_path,
1982+
"macros": macros,
1983+
"jinja_macros": jinja_macros,
1984+
"variables": variables,
1985+
"path": path,
1986+
"dialect": dialect,
1987+
"default_catalog": kwargs.get("default_catalog"),
1988+
}
1989+
19791990
name_renderer = _meta_renderer(
19801991
expression=d.parse_one(name, dialect=dialect),
1981-
module_path=module_path,
1982-
macros=macros,
1983-
jinja_macros=jinja_macros,
1984-
variables=variables,
1985-
path=path,
1986-
dialect=dialect,
1987-
default_catalog=kwargs.get("default_catalog"),
1992+
**renderer_kwargs, # type: ignore
19881993
)
19891994
name = t.cast(t.List[exp.Expression], name_renderer.render())[0].sql(dialect=dialect)
19901995

1996+
dependencies_unspecified = depends_on is None
1997+
19911998
parsed_depends_on, referenced_variables = (
1992-
parse_dependencies(python_env, entrypoint) if python_env is not None else (set(), set())
1999+
parse_dependencies(python_env, entrypoint, strict_resolution=dependencies_unspecified)
2000+
if python_env is not None
2001+
else (set(), set())
19932002
)
1994-
if depends_on is None:
2003+
if dependencies_unspecified:
19952004
depends_on = parsed_depends_on - {name}
2005+
else:
2006+
depends_on_renderer = _meta_renderer(
2007+
expression=exp.Array(
2008+
expressions=[d.parse_one(dep, dialect=dialect) for dep in depends_on or []]
2009+
),
2010+
**renderer_kwargs, # type: ignore
2011+
)
2012+
depends_on = {
2013+
dep.sql(dialect=dialect)
2014+
for dep in t.cast(t.List[exp.Expression], depends_on_renderer.render())[0].expressions
2015+
}
19962016

19972017
variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables}
19982018
if variables:
@@ -2161,6 +2181,7 @@ def _create_model(
21612181
used_variables=used_variables,
21622182
path=path,
21632183
python_env=python_env,
2184+
strict_resolution=depends_on is None,
21642185
)
21652186

21662187
env: t.Dict[str, t.Any] = {}

tests/core/test_model.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1923,6 +1923,30 @@ def my_model(context, **kwargs):
19231923
assert m.depends_on == {'"foo"."bar"'}
19241924

19251925

1926+
def test_python_model_variable_dependencies() -> None:
1927+
@model(
1928+
name="bla.test_model_var_dep",
1929+
kind="full",
1930+
columns={'"col"': "int"},
1931+
depends_on={"@schema_name.table_name"},
1932+
)
1933+
def my_model(context, **kwargs):
1934+
# Even though the argument is not statically resolvable, no error
1935+
# is raised, because the `depends_on` property is present
1936+
schema_name = context.var("schema_name")
1937+
table = context.resolve_table(f"{schema_name}.table_name")
1938+
1939+
return context.fetchdf(exp.select("*").from_(table))
1940+
1941+
m = model.get_registry()["bla.test_model_var_dep"].model(
1942+
module_path=Path("."),
1943+
path=Path("."),
1944+
variables={"schema_name": "foo"},
1945+
)
1946+
1947+
assert m.depends_on == {'"foo"."table_name"'}
1948+
1949+
19261950
def test_python_model_with_session_properties():
19271951
@model(
19281952
name="python_model_prop",
@@ -5321,7 +5345,7 @@ def a_model(context):
53215345
# column type not parseable by default dialect and no explicit dialect: error
53225346
model._dialect = "snowflake"
53235347

5324-
with pytest.raises(ParseError, match="No expression was parsed from 'DateTime64\(9\)'"):
5348+
with pytest.raises(ParseError, match="No expression was parsed from 'DateTime64\\(9\\)'"):
53255349

53265350
@model("bad", columns={'"COL"': "DateTime64(9)"})
53275351
def a_model(context):

0 commit comments

Comments
 (0)