Skip to content

Commit d0833c0

Browse files
authored
Feat: traverse python objects in post-order to support custom decorators (#3762)
1 parent 8aff33e commit d0833c0

2 files changed

Lines changed: 89 additions & 43 deletions

File tree

sqlmesh/utils/metaprogramming.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -278,68 +278,74 @@ def build_env(
278278
name: Name of the object in the env.
279279
path: The module path to serialize. Other modules will not be walked and treated as imports.
280280
"""
281+
# We don't rely on `env` to keep track of visited objects, because it's populated in post-order
282+
visited: t.Set[str] = set()
283+
284+
def walk(obj: t.Any, name: str) -> None:
285+
obj_module = inspect.getmodule(obj)
286+
if name in visited or (obj_module and obj_module.__name__ == "builtins"):
287+
return
288+
289+
visited.add(name)
290+
if name not in env:
291+
if hasattr(obj, c.SQLMESH_MACRO):
292+
# We only need to add the undecorated code of @macro() functions in env, which
293+
# is accessible through the `__wrapped__` attribute added by functools.wraps
294+
obj = obj.__wrapped__
295+
elif callable(obj) and not isinstance(obj, SERIALIZABLE_CALLABLES):
296+
obj = getattr(obj, "__wrapped__", None)
297+
name = getattr(obj, "__name__", "")
298+
299+
# Callable class instances shouldn't be serialized (e.g. tenacity.Retrying).
300+
# We still want to walk the callables they decorate, though
301+
if not isinstance(obj, SERIALIZABLE_CALLABLES) or name in env:
302+
return
303+
304+
if (
305+
not obj_module
306+
or not hasattr(obj_module, "__file__")
307+
or not _is_relative_to(obj_module.__file__, path)
308+
):
309+
env[name] = obj
310+
return
311+
elif env[name] != obj:
312+
raise SQLMeshError(
313+
f"Cannot store {obj} in environment, duplicate definitions found for '{name}'"
314+
)
281315

282-
obj_module = inspect.getmodule(obj)
283-
284-
if obj_module and obj_module.__name__ == "builtins":
285-
return
286-
287-
def walk(obj: t.Any) -> None:
288316
if inspect.isclass(obj):
289317
for var in decorator_vars(obj):
290318
if obj_module and var in obj_module.__dict__:
291-
build_env(
292-
obj_module.__dict__[var],
293-
env=env,
294-
name=var,
295-
path=path,
296-
)
319+
walk(obj_module.__dict__[var], var)
297320

298321
for base in obj.__bases__:
299-
build_env(base, env=env, name=base.__qualname__, path=path)
322+
walk(base, base.__qualname__)
300323

301324
for k, v in obj.__dict__.items():
302325
if k.startswith("__"):
303326
continue
304-
# traverse methods in a class to find global references
327+
328+
# Traverse methods in a class to find global references
305329
if isinstance(v, (classmethod, staticmethod)):
306330
v = v.__func__
331+
307332
if callable(v):
308-
# if the method is a part of the object, walk it
309-
# else it is a global function and we just store it
333+
# Walk the method if it's part of the object, else it's a global function and we just store it
310334
if v.__qualname__.startswith(obj.__qualname__):
311-
walk(v)
335+
for k, v in func_globals(v).items():
336+
walk(v, k)
312337
else:
313-
build_env(v, env=env, name=v.__name__, path=path)
338+
walk(v, v.__name__)
314339
elif callable(obj):
315340
for k, v in func_globals(obj).items():
316-
build_env(v, env=env, name=k, path=path)
317-
318-
if name not in env:
319-
if hasattr(obj, c.SQLMESH_MACRO):
320-
# We only need to add the undecorated code of @macro() functions in env, which
321-
# is accessible through the `__wrapped__` attribute added by functools.wraps
322-
obj = obj.__wrapped__
323-
elif callable(obj) and not isinstance(obj, SERIALIZABLE_CALLABLES):
324-
obj = getattr(obj, "__wrapped__", None)
325-
name = getattr(obj, "__name__", "")
326-
327-
# Callable class instances shouldn't be serialized (e.g. tenacity.Retrying).
328-
# We still want to walk the callables they decorate, though
329-
if not isinstance(obj, SERIALIZABLE_CALLABLES) or name in env:
330-
return
341+
walk(v, k)
331342

343+
# We store the object in the environment after its dependencies, because otherwise we
344+
# could crash at environment hydration time, since dicts are ordered and the top-level
345+
# objects would be loaded before their dependencies.
332346
env[name] = obj
333-
if (
334-
obj_module
335-
and hasattr(obj_module, "__file__")
336-
and _is_relative_to(obj_module.__file__, path)
337-
):
338-
walk(env[name])
339-
elif env[name] != obj:
340-
raise SQLMeshError(
341-
f"Cannot store {obj} in environment, duplicate definitions found for '{name}'"
342-
)
347+
348+
walk(obj, name)
343349

344350

345351
class ExecutableKind(str, Enum):

tests/utils/test_metaprogramming.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,18 @@ def fetch_data():
115115
return "'test data'"
116116

117117

118+
def custom_decorator(_func):
119+
def wrapper(*args, **kwargs):
120+
return _func(*args, **kwargs)
121+
122+
return wrapper
123+
124+
125+
@custom_decorator
126+
def function_with_custom_decorator():
127+
return
128+
129+
118130
def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int:
119131
"""DOC STRING"""
120132
sqlglot.parse_one("1")
@@ -123,6 +135,7 @@ def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2)
123135
noop_metadata()
124136
normalize_model_name("test")
125137
fetch_data()
138+
function_with_custom_decorator()
126139

127140
def closure(z: int) -> int:
128141
return z + Z
@@ -147,6 +160,7 @@ def test_func_globals() -> None:
147160
"expressions": exp,
148161
"fetch_data": fetch_data,
149162
"test_context_manager": test_context_manager,
163+
"function_with_custom_decorator": function_with_custom_decorator,
150164
}
151165
assert func_globals(other_func) == {
152166
"X": 1,
@@ -180,6 +194,7 @@ def test_normalize_source() -> None:
180194
noop_metadata()
181195
normalize_model_name('test')
182196
fetch_data()
197+
function_with_custom_decorator()
183198
184199
def closure(z: int):
185200
return z + Z
@@ -226,6 +241,7 @@ def test_serialize_env() -> None:
226241
noop_metadata()
227242
normalize_model_name('test')
228243
fetch_data()
244+
function_with_custom_decorator()
229245
230246
def closure(z: int):
231247
return z + Z
@@ -354,4 +370,28 @@ def fetch_data():
354370
path="test_metaprogramming.py",
355371
alias="f",
356372
),
373+
"function_with_custom_decorator": Executable(
374+
name="wrapper",
375+
path="test_metaprogramming.py",
376+
payload="""def wrapper(*args, **kwargs):
377+
return _func(*args, **kwargs)""",
378+
alias="function_with_custom_decorator",
379+
),
380+
"custom_decorator": Executable(
381+
name="custom_decorator",
382+
path="test_metaprogramming.py",
383+
payload="""def custom_decorator(_func):
384+
385+
def wrapper(*args, **kwargs):
386+
return _func(*args, **kwargs)
387+
return wrapper""",
388+
),
389+
"_func": Executable(
390+
name="function_with_custom_decorator",
391+
path="test_metaprogramming.py",
392+
payload="""@custom_decorator
393+
def function_with_custom_decorator():
394+
return""",
395+
alias="_func",
396+
),
357397
}

0 commit comments

Comments
 (0)