Skip to content

Commit a8e5abf

Browse files
authored
Feat: Expose the gateway name as a macro (#1509)
1 parent 9c67234 commit a8e5abf

8 files changed

Lines changed: 72 additions & 8 deletions

File tree

examples/sushi/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131

3232
# A configuration used for SQLMesh tests.
3333
test_config = Config(
34-
default_connection=DuckDBConnectionConfig(),
34+
gateways={"in_memory": GatewayConfig(connection=DuckDBConnectionConfig())},
35+
default_gateway="in_memory",
3536
auto_categorize_changes=CategorizerConfig(sql=AutoCategorizationMode.SEMI),
3637
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
3738
)

sqlmesh/core/config/root.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import zlib
66

77
from pydantic import Field
8+
from sqlglot.helper import first
89

910
from sqlmesh.core import constants as c
1011
from sqlmesh.core.config import EnvironmentSuffixTarget
@@ -122,7 +123,7 @@ def get_gateway(self, name: t.Optional[str] = None) -> GatewayConfig:
122123
if "" in self.gateways:
123124
return self.gateways[""]
124125

125-
return next(iter(self.gateways.values()))
126+
return first(self.gateways.values())
126127

127128
if name not in self.gateways:
128129
raise ConfigError(f"Missing gateway with name '{name}'.")
@@ -152,6 +153,14 @@ def get_scheduler(self, gateway_name: t.Optional[str] = None) -> SchedulerConfig
152153
def get_state_schema(self, gateway_name: t.Optional[str] = None) -> t.Optional[str]:
153154
return self.get_gateway(gateway_name).state_schema
154155

156+
@property
157+
def default_gateway_name(self) -> str:
158+
if self.default_gateway:
159+
return self.default_gateway
160+
if "" in self.gateways:
161+
return ""
162+
return first(self.gateways)
163+
155164
@property
156165
def dialect(self) -> t.Optional[str]:
157166
return self.model_defaults.dialect

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def load(self, update_schemas: bool = True) -> Context:
388388
gc.disable()
389389
project = self._loader.load(self, update_schemas)
390390
self._macros = project.macros
391+
self._jinja_macros = project.jinja_macros
391392
self._models = project.models
392393
self._metrics = project.metrics
393394
self._standalone_audits.clear()

sqlmesh/core/loader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sqlmesh.utils.dag import DAG
3434
from sqlmesh.utils.errors import ConfigError
3535
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
36+
from sqlmesh.utils.metaprogramming import Executable
3637
from sqlmesh.utils.yaml import YAML
3738

3839
if t.TYPE_CHECKING:
@@ -256,6 +257,10 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
256257
macros = macro.get_registry()
257258
macro.set_registry(standard_macros)
258259

260+
gateway_name = self._context.gateway or self._context.config.default_gateway_name
261+
macros["gateway"] = Executable.value(gateway_name)
262+
jinja_macros.add_globals({"gateway": gateway_name})
263+
259264
return macros, jinja_macros
260265

261266
def _load_models(

sqlmesh/core/macros.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def __init__(
114114
self.macros[normalize_macro_name(k)] = self.env[v.name or k]
115115
elif v.is_import and getattr(self.env.get(k), "__sqlmesh_macro__", None):
116116
self.macros[normalize_macro_name(k)] = self.env[k]
117+
elif v.is_value:
118+
self.locals[k] = self.env[k]
117119

118120
def send(
119121
self, name: str, *args: t.Any

sqlmesh/core/model/definition.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,10 +1823,14 @@ def _python_env(
18231823
expressions = ensure_list(expressions)
18241824
for expression in expressions:
18251825
if not isinstance(expression, d.Jinja):
1826-
for macro_func in expression.find_all(d.MacroFunc):
1827-
if macro_func.__class__ is d.MacroFunc:
1828-
name = macro_func.this.name.lower()
1826+
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar):
1827+
if macro_func_or_var.__class__ is d.MacroFunc:
1828+
name = macro_func_or_var.this.name.lower()
18291829
used_macros[name] = macros[name]
1830+
elif macro_func_or_var.__class__ is d.MacroVar:
1831+
name = macro_func_or_var.name
1832+
if name in macros:
1833+
used_macros[name] = macros[name]
18301834

18311835
for macro_ref in jinja_macro_references or set():
18321836
if macro_ref.package is None and macro_ref.name in macros:

sqlmesh/utils/metaprogramming.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def __lt__(self, other: t.Any) -> bool:
304304

305305

306306
class Executable(PydanticModel):
307-
payload: t.Any
307+
payload: str
308308
kind: ExecutableKind = ExecutableKind.DEFINITION
309309
name: t.Optional[str] = None
310310
path: t.Optional[str] = None
@@ -326,6 +326,10 @@ def is_statement(self) -> bool:
326326
def is_value(self) -> bool:
327327
return self.kind == ExecutableKind.VALUE
328328

329+
@classmethod
330+
def value(cls, v: t.Any) -> Executable:
331+
return Executable(payload=repr(v), kind=ExecutableKind.VALUE)
332+
329333

330334
def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]:
331335
"""Serializes a python function into a self contained dictionary.
@@ -370,7 +374,7 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
370374
kind=ExecutableKind.IMPORT,
371375
)
372376
else:
373-
serialized[k] = Executable(payload=repr(v), kind=ExecutableKind.VALUE)
377+
serialized[k] = Executable.value(v)
374378

375379
return serialized
376380

tests/core/test_context.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def test():
287287
context = Context(paths=str(tmp_path), config=config)
288288

289289
assert ["db.actual_test"] == list(context.models)
290-
assert "test" == list(context._macros)[-1]
290+
assert "test" in context._macros
291291

292292

293293
def test_plan_apply(sushi_context) -> None:
@@ -477,3 +477,41 @@ def test_default_schema_and_config(sushi_context_pre_scheduling) -> None:
477477
MappingSchema({"a": {"col": "int"}}), default_schema="schema", default_catalog="catalog"
478478
)
479479
assert c.mapping_schema == {"catalog": {"schema": {"a": {"col": "int"}}}}
480+
481+
482+
def test_gateway_macro(sushi_context: Context) -> None:
483+
sushi_context.upsert_model(
484+
load_sql_based_model(
485+
parse(
486+
"""
487+
MODEL(name sushi.test_gateway_macro);
488+
SELECT @gateway AS gateway;
489+
"""
490+
),
491+
macros=sushi_context._macros,
492+
)
493+
)
494+
495+
assert (
496+
sushi_context.render("sushi.test_gateway_macro").sql()
497+
== "SELECT 'in_memory' AS \"gateway\""
498+
)
499+
500+
sushi_context.upsert_model(
501+
load_sql_based_model(
502+
parse(
503+
"""
504+
MODEL(name sushi.test_gateway_macro_jinja);
505+
JINJA_QUERY_BEGIN;
506+
SELECT '{{ gateway }}' AS gateway_jinja;
507+
JINJA_END;
508+
"""
509+
),
510+
jinja_macros=sushi_context._jinja_macros,
511+
)
512+
)
513+
514+
assert (
515+
sushi_context.render("sushi.test_gateway_macro_jinja").sql()
516+
== "SELECT 'in_memory' AS \"gateway_jinja\""
517+
)

0 commit comments

Comments
 (0)