@@ -1058,7 +1058,7 @@ def load_model(
10581058 if not name :
10591059 raise_config_error ("Model must have a name" , path )
10601060
1061- macro_references : t .Set [MacroReference ] = {
1061+ jinja_macro_references : t .Set [MacroReference ] = {
10621062 r
10631063 for references in [
10641064 * [extract_macro_references (e .sql (dialect = dialect )) for e in pre_statements ],
@@ -1067,41 +1067,40 @@ def load_model(
10671067 for r in references
10681068 }
10691069
1070+ common_kwargs = dict (
1071+ pre_statements = pre_statements ,
1072+ post_statements = post_statements ,
1073+ defaults = defaults ,
1074+ path = path ,
1075+ module_path = module_path ,
1076+ macros = macros ,
1077+ python_env = python_env ,
1078+ jinja_macros = jinja_macros ,
1079+ jinja_macro_references = jinja_macro_references ,
1080+ ** meta_fields ,
1081+ )
1082+
10701083 if query_or_seed_insert is not None and isinstance (
10711084 query_or_seed_insert , (exp .Subqueryable , d .JinjaQuery )
10721085 ):
1073- macro_references .update (extract_macro_references (query_or_seed_insert .sql (dialect = dialect )))
1086+ jinja_macro_references .update (
1087+ extract_macro_references (query_or_seed_insert .sql (dialect = dialect ))
1088+ )
10741089 return create_sql_model (
10751090 name ,
10761091 query_or_seed_insert ,
1077- pre_statements = pre_statements ,
1078- post_statements = post_statements ,
1079- defaults = defaults ,
1080- path = path ,
1081- module_path = module_path ,
10821092 time_column_format = time_column_format ,
1083- macros = macros ,
1084- jinja_macros = (jinja_macros or JinjaMacroRegistry ()).trim (macro_references ),
1085- python_env = python_env ,
1086- ** meta_fields ,
1093+ ** common_kwargs ,
10871094 )
10881095 else :
10891096 try :
10901097 seed_properties = {
1091- p .name .lower (): p .args .get ("value" ) for p in meta_fields .pop ("kind" ).expressions
1098+ p .name .lower (): p .args .get ("value" ) for p in common_kwargs .pop ("kind" ).expressions
10921099 }
10931100 return create_seed_model (
10941101 name ,
10951102 SeedKind (** seed_properties ),
1096- pre_statements = pre_statements ,
1097- post_statements = post_statements ,
1098- defaults = defaults ,
1099- path = path ,
1100- module_path = module_path ,
1101- macros = macros ,
1102- jinja_macros = (jinja_macros or JinjaMacroRegistry ()).trim (macro_references ),
1103- python_env = python_env ,
1104- ** meta_fields ,
1103+ ** common_kwargs ,
11051104 )
11061105 except Exception :
11071106 raise_config_error (
@@ -1123,6 +1122,8 @@ def create_sql_model(
11231122 time_column_format : str = c .DEFAULT_TIME_COLUMN_FORMAT ,
11241123 macros : t .Optional [MacroRegistry ] = None ,
11251124 python_env : t .Optional [t .Dict [str , Executable ]] = None ,
1125+ jinja_macros : t .Optional [JinjaMacroRegistry ] = None ,
1126+ jinja_macro_references : t .Optional [t .Set [MacroReference ]] = None ,
11261127 dialect : t .Optional [str ] = None ,
11271128 ** kwargs : t .Any ,
11281129) -> Model :
@@ -1156,6 +1157,7 @@ def create_sql_model(
11561157 if not python_env :
11571158 python_env = _python_env (
11581159 [* pre_statements , query , * post_statements ],
1160+ jinja_macro_references ,
11591161 module_path ,
11601162 macros or macro .get_registry (),
11611163 )
@@ -1167,6 +1169,8 @@ def create_sql_model(
11671169 path = path ,
11681170 time_column_format = time_column_format ,
11691171 python_env = python_env ,
1172+ jinja_macros = jinja_macros ,
1173+ jinja_macro_references = jinja_macro_references ,
11701174 dialect = dialect ,
11711175 query = query ,
11721176 pre_statements = pre_statements ,
@@ -1186,6 +1190,8 @@ def create_seed_model(
11861190 module_path : Path = Path (),
11871191 macros : t .Optional [MacroRegistry ] = None ,
11881192 python_env : t .Optional [t .Dict [str , Executable ]] = None ,
1193+ jinja_macros : t .Optional [JinjaMacroRegistry ] = None ,
1194+ jinja_macro_references : t .Optional [t .Set [MacroReference ]] = None ,
11891195 ** kwargs : t .Any ,
11901196) -> Model :
11911197 """Creates a Seed model.
@@ -1213,6 +1219,7 @@ def create_seed_model(
12131219 if not python_env :
12141220 python_env = _python_env (
12151221 [* pre_statements , * post_statements ],
1222+ jinja_macro_references ,
12161223 module_path ,
12171224 macros or macro .get_registry (),
12181225 )
@@ -1226,6 +1233,8 @@ def create_seed_model(
12261233 kind = seed_kind ,
12271234 depends_on = kwargs .pop ("depends_on" , set ()),
12281235 python_env = python_env ,
1236+ jinja_macros = jinja_macros ,
1237+ jinja_macro_references = jinja_macro_references ,
12291238 pre_statements = pre_statements ,
12301239 post_statements = post_statements ,
12311240 ** kwargs ,
@@ -1306,6 +1315,8 @@ def _create_model(
13061315 defaults : t .Optional [t .Dict [str , t .Any ]] = None ,
13071316 path : Path = Path (),
13081317 time_column_format : str = c .DEFAULT_TIME_COLUMN_FORMAT ,
1318+ jinja_macros : t .Optional [JinjaMacroRegistry ] = None ,
1319+ jinja_macro_references : t .Optional [t .Set [MacroReference ]] = None ,
13091320 depends_on : t .Optional [t .Set [str ]] = None ,
13101321 dialect : t .Optional [str ] = None ,
13111322 ** kwargs : t .Any ,
@@ -1314,11 +1325,16 @@ def _create_model(
13141325
13151326 dialect = dialect or ""
13161327
1328+ jinja_macros = jinja_macros or JinjaMacroRegistry ()
1329+ if jinja_macro_references is not None :
1330+ jinja_macros = jinja_macros .trim (jinja_macro_references )
1331+
13171332 try :
13181333 model = klass (
13191334 name = name ,
13201335 ** {
13211336 ** (defaults or {}),
1337+ "jinja_macros" : jinja_macros ,
13221338 "dialect" : dialect ,
13231339 "depends_on" : depends_on ,
13241340 ** kwargs ,
@@ -1403,27 +1419,25 @@ def _find_tables(expressions: t.List[exp.Expression]) -> t.Set[str]:
14031419
14041420def _python_env (
14051421 expressions : t .Union [exp .Expression , t .List [exp .Expression ]],
1422+ jinja_macro_references : t .Optional [t .Set [MacroReference ]],
14061423 module_path : Path ,
14071424 macros : MacroRegistry ,
14081425) -> t .Dict [str , Executable ]:
14091426 python_env : t .Dict [str , Executable ] = {}
14101427
14111428 used_macros = {}
14121429
1413- def _capture_expression_macros (expression : exp .Expression ) -> None :
1414- if isinstance (expression , d .Jinja ):
1415- for var in expression .expressions :
1416- if var in macros :
1417- used_macros [var ] = macros [var ]
1418- else :
1430+ expressions = ensure_list (expressions )
1431+ for expression in expressions :
1432+ if not isinstance (expression , d .Jinja ):
14191433 for macro_func in expression .find_all (d .MacroFunc ):
14201434 if macro_func .__class__ is d .MacroFunc :
14211435 name = macro_func .this .name .lower ()
14221436 used_macros [name ] = macros [name ]
14231437
1424- expressions = ensure_list ( expressions )
1425- for expression in expressions :
1426- _capture_expression_macros ( expression )
1438+ for macro_ref in jinja_macro_references or set ():
1439+ if macro_ref . package is None and macro_ref . name in macros :
1440+ used_macros [ macro_ref . name ] = macros [ macro_ref . name ]
14271441
14281442 for name , macro in used_macros .items ():
14291443 if not macro .func .__module__ .startswith ("sqlmesh." ):
0 commit comments