
Commit 1fcf609

Feat: Support custom session properties in the Spark Submit Airflow operator (#1482)
1 parent 55da1a5 commit 1fcf609

7 files changed: 113 additions & 12 deletions


examples/sushi/models/items.py

Lines changed: 6 additions & 0 deletions
@@ -69,6 +69,12 @@
         "float_prop": 1.0,
         "bool_prop": True,
     },
+    session_properties={
+        "string_prop": "some_value",
+        "int_prop": 1,
+        "float_prop": 1.0,
+        "bool_prop": True,
+    },
 )
 def execute(
     context: ExecutionContext,

sqlmesh/core/model/common.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def parse_properties(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Opt

 properties_validator = field_validator(
     "table_properties_",
+    "session_properties_",
     mode="before",
     check_fields=False,
 )(parse_properties)
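
The change above simply attaches the existing parse_properties validator to the new field as well. A minimal, self-contained sketch of that pattern, assuming pydantic v2 (the Meta class and the dict-normalizing body below are hypothetical stand-ins, not SQLMesh's real code):

import typing as t

from pydantic import BaseModel, field_validator


def parse_properties(cls: t.Type, v: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
    # Hypothetical normalization: accept a list/tuple of (key, value) pairs
    # and coerce it into a dict; pass anything else through unchanged.
    if isinstance(v, (list, tuple)):
        return dict(v)
    return v


class Meta(BaseModel):
    table_properties_: t.Optional[t.Dict[str, t.Any]] = None
    session_properties_: t.Optional[t.Dict[str, t.Any]] = None

    # One validator attached to both fields, mirroring the diff above.
    properties_validator = field_validator(
        "table_properties_",
        "session_properties_",
        mode="before",
        check_fields=False,
    )(parse_properties)


meta = Meta(table_properties_=[("a", 1)], session_properties_=[("b", 2)])
assert meta.session_properties_ == {"b": 2}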

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 0 deletions
@@ -744,6 +744,7 @@ def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str:
             str(self.disable_restatement),
             self.project,
             str(self.allow_partials),
+            self.session_properties_.sql() if self.session_properties_ else None,
         ]

         for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
@@ -1926,4 +1927,5 @@ def _refs_to_sql(values: t.Any) -> exp.Expression:
     "references": _refs_to_sql,
     "hash_raw_query": exp.convert,
     "table_properties_": lambda value: value,
+    "session_properties_": lambda value: value,
 }
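
Folding the rendered session_properties_ SQL into the hashed component list means that editing a model's session properties changes its metadata hash (and thus its snapshot fingerprint) without forcing a data rebuild. A minimal sketch of the effect, using a hypothetical CRC-based helper in place of SQLMesh's actual hashing code:

import typing as t
import zlib


def metadata_hash(components: t.List[t.Optional[str]]) -> str:
    # Hypothetical stand-in: join the components and CRC them.
    data = ";".join(c if c is not None else "" for c in components).encode("utf-8")
    return str(zlib.crc32(data))


# Adding the session_properties component shifts the hash even though
# every other metadata field is unchanged.
before = metadata_hash(["False", "my_project", "False"])
after = metadata_hash(["False", "my_project", "False", "('spark.executor.cores' = 2)"])
assert before != after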

sqlmesh/core/model/meta.py

Lines changed: 23 additions & 1 deletion
@@ -51,13 +51,14 @@ class ModelMeta(_Node, extra="allow"):
     hash_raw_query: bool = False
     physical_schema_override: t.Optional[str] = None
     table_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="table_properties")
+    session_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="session_properties")
     allow_partials: bool = False

     _table_properties: t.Dict[str, exp.Expression] = {}

     _bool_validator = bool_validator
     _model_kind_validator = model_kind_validator
-    _table_properties_validator = properties_validator
+    _properties_validator = properties_validator

     @field_validator("audits", mode="before")
     @classmethod
@@ -332,6 +333,27 @@ def table_properties(self) -> t.Dict[str, exp.Expression]:
             self._table_properties[expression.this.name] = expression.expression
         return self._table_properties

+    @property
+    def session_properties(self) -> t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]:
+        """A dictionary of session properties."""
+        if not self.session_properties_:
+            return {}
+
+        def _interpret_expr(
+            e: exp.Expression,
+        ) -> t.Union[exp.Expression, str, int, float, bool]:
+            if e.is_int:
+                return int(e.this)
+            if e.is_number:
+                return float(e.this)
+            if isinstance(e, (exp.Literal, exp.Boolean)):
+                return e.this
+            return e
+
+        return {
+            e.this.name: _interpret_expr(e.expression) for e in self.session_properties_.expressions
+        }
+
     @property
     def all_references(self) -> t.List[Reference]:
         """All references including grains."""

sqlmesh/schedulers/airflow/operators/spark_submit.py

Lines changed: 46 additions & 8 deletions
@@ -8,7 +8,10 @@

 import sqlmesh
 from sqlmesh.engines import commands
-from sqlmesh.schedulers.airflow.operators.targets import BaseTarget
+from sqlmesh.schedulers.airflow.operators.targets import (
+    BaseTarget,
+    SnapshotEvaluationTarget,
+)


 class SQLMeshSparkSubmitOperator(BaseOperator):
@@ -54,7 +57,7 @@ def __init__(
         super().__init__(**kwargs)
         self._target = target
         self._application_name = application_name
-        self._spark_conf = spark_conf
+        self._spark_conf = spark_conf or {}
         self._total_executor_cores = total_executor_cores
         self._executor_cores = executor_cores
         self._executor_memory = executor_memory
@@ -77,24 +80,59 @@ def execute(self, context: Context) -> None:
             payload_fd.write(command_payload)

         if self._hook is None:
+            if (
+                isinstance(self._target, SnapshotEvaluationTarget)
+                and self._target.snapshot.is_model
+            ):
+                session_properties = self._target.snapshot.model.session_properties
+                executor_cores: t.Optional[int] = session_properties.pop(  # type: ignore
+                    "spark.executor.cores", self._executor_cores
+                )
+                executor_memory: t.Optional[str] = session_properties.pop(  # type: ignore
+                    "spark.executor.memory", self._executor_memory
+                )
+                driver_memory: t.Optional[str] = session_properties.pop(  # type: ignore
+                    "spark.driver.memory", self._driver_memory
+                )
+                num_executors: t.Optional[int] = session_properties.pop(  # type: ignore
+                    "spark.executor.instances", self._num_executors
+                )
+                spark_conf: t.Dict[str, t.Any] = {**self._spark_conf, **session_properties}
+            else:
+                executor_cores = self._executor_cores
+                executor_memory = self._executor_memory
+                driver_memory = self._driver_memory
+                num_executors = self._num_executors
+                spark_conf = self._spark_conf
+
             self._hook = self._get_hook(
                 self._target.command_type,
                 payload_file_path,
                 self._target.ddl_concurrent_tasks,
+                spark_conf,
+                executor_cores,
+                executor_memory,
+                driver_memory,
+                num_executors,
             )
         self._hook.submit(self._application)
         self._target.post_hook(context)

     def on_kill(self) -> None:
         if self._hook is None:
-            self._hook = self._get_hook(None, None, None)
+            self._hook = self._get_hook(None, None, None, None, None, None, None, None)
         self._hook.on_kill()

     def _get_hook(
         self,
         command_type: t.Optional[commands.CommandType],
         command_payload_file_path: t.Optional[str],
         ddl_concurrent_tasks: t.Optional[int],
+        spark_conf: t.Optional[t.Dict[str, t.Any]],
+        executor_cores: t.Optional[int],
+        executor_memory: t.Optional[str],
+        driver_memory: t.Optional[str],
+        num_executors: t.Optional[int],
     ) -> SparkSubmitHook:
         application_args = {
             "dialect": "spark",
@@ -105,17 +143,17 @@ def _get_hook(
             else None,
         }
         return SparkSubmitHook(
-            conf=self._spark_conf,
+            conf=spark_conf,
             conn_id=self._connection_id,
             total_executor_cores=self._total_executor_cores,
-            executor_cores=self._executor_cores,
-            executor_memory=self._executor_memory,
-            driver_memory=self._driver_memory,
+            executor_cores=executor_cores,
+            executor_memory=executor_memory,
+            driver_memory=driver_memory,
             keytab=self._keytab,
             principal=self._principal,
             proxy_user=self._proxy_user,
             name=self._application_name,
-            num_executors=self._num_executors,
+            num_executors=num_executors,
             application_args=[f"--{k}={v}" for k, v in application_args.items() if v is not None],
             files=command_payload_file_path,
         )
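
The operator gives model-level session properties two roles: the four reserved spark.* keys are popped out and passed as dedicated SparkSubmitHook arguments (falling back to the operator's own defaults), and whatever remains is merged over the operator-level spark_conf, so a model's value wins on key collisions. A minimal, dependency-free sketch of that precedence (the literal values are hypothetical):

import typing as t

# Operator-level defaults, as passed to SQLMeshSparkSubmitOperator.
operator_conf: t.Dict[str, t.Any] = {"spark.sql.shuffle.partitions": "200"}
default_executor_cores: t.Optional[int] = 1

# Model-level session properties, as returned by model.session_properties.
session_properties: t.Dict[str, t.Any] = {
    "spark.executor.cores": 2,
    "spark.executor.memory": "1G",
    "spark.sql.shuffle.partitions": "400",
}

# Reserved keys become dedicated submit arguments...
executor_cores = session_properties.pop("spark.executor.cores", default_executor_cores)
executor_memory = session_properties.pop("spark.executor.memory", None)

# ...and the rest is layered over the operator conf, model values winning.
spark_conf = {**operator_conf, **session_properties}

assert executor_cores == 2
assert executor_memory == "1G"
assert spark_conf == {"spark.sql.shuffle.partitions": "400"}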

tests/core/test_model.py

Lines changed: 32 additions & 0 deletions
@@ -1997,6 +1997,38 @@ def test_model_table_properties(sushi_context):
     )


+def test_model_session_properties(sushi_context):
+    assert sushi_context.models["sushi.items"].session_properties == {
+        "string_prop": "some_value",
+        "int_prop": 1,
+        "float_prop": 1.0,
+        "bool_prop": True,
+    }
+    model = load_sql_based_model(
+        d.parse(
+            """
+            MODEL (
+                name test_schema.test_model,
+                session_properties (
+                    'spark.executor.cores' = 2,
+                    'spark.executor.memory' = '1G',
+                    some_bool = True,
+                    some_float = 0.1,
+                )
+            );
+            SELECT a FROM tbl;
+            """
+        )
+    )
+
+    assert model.session_properties == {
+        "spark.executor.cores": 2,
+        "spark.executor.memory": "1G",
+        "some_bool": True,
+        "some_float": 0.1,
+    }
+
+
 def test_model_jinja_macro_rendering():
     expressions = d.parse(
         """

tests/core/test_snapshot.py

Lines changed: 3 additions & 3 deletions
@@ -437,7 +437,7 @@ def test_fingerprint(model: Model, parent_model: Model):

     original_fingerprint = SnapshotFingerprint(
         data_hash="3811098861",
-        metadata_hash="541992912",
+        metadata_hash="3858405978",
     )

     assert fingerprint == original_fingerprint
@@ -484,7 +484,7 @@ def test_fingerprint_seed_model():

     expected_fingerprint = SnapshotFingerprint(
         data_hash="3270932819",
-        metadata_hash="2823924537",
+        metadata_hash="1017437962",
     )

     model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))
@@ -524,7 +524,7 @@ def test_fingerprint_jinja_macros(model: Model):
     )
     original_fingerprint = SnapshotFingerprint(
         data_hash="2864998504",
-        metadata_hash="541992912",
+        metadata_hash="3858405978",
     )

     fingerprint = fingerprint_from_node(model, nodes={})
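
These fingerprint updates are the expected fallout of the metadata_hash change in sqlmesh/core/model/definition.py above: the session-properties component is now part of every model's metadata hash, so the expected values shift even for models that declare no session properties.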
