88
99import sqlmesh
1010from sqlmesh .engines import commands
11- from sqlmesh .schedulers .airflow .operators .targets import BaseTarget
11+ from sqlmesh .schedulers .airflow .operators .targets import (
12+ BaseTarget ,
13+ SnapshotEvaluationTarget ,
14+ )
1215
1316
1417class SQLMeshSparkSubmitOperator (BaseOperator ):
@@ -54,7 +57,7 @@ def __init__(
5457 super ().__init__ (** kwargs )
5558 self ._target = target
5659 self ._application_name = application_name
57- self ._spark_conf = spark_conf
60+ self ._spark_conf = spark_conf or {}
5861 self ._total_executor_cores = total_executor_cores
5962 self ._executor_cores = executor_cores
6063 self ._executor_memory = executor_memory
@@ -77,24 +80,59 @@ def execute(self, context: Context) -> None:
7780 payload_fd .write (command_payload )
7881
7982 if self ._hook is None :
83+ if (
84+ isinstance (self ._target , SnapshotEvaluationTarget )
85+ and self ._target .snapshot .is_model
86+ ):
87+ session_properties = self ._target .snapshot .model .session_properties
88+ executor_cores : t .Optional [int ] = session_properties .pop ( # type: ignore
89+ "spark.executor.cores" , self ._executor_cores
90+ )
91+ executor_memory : t .Optional [str ] = session_properties .pop ( # type: ignore
92+ "spark.executor.memory" , self ._executor_memory
93+ )
94+ driver_memory : t .Optional [str ] = session_properties .pop ( # type: ignore
95+ "spark.driver.memory" , self ._driver_memory
96+ )
97+ num_executors : t .Optional [int ] = session_properties .pop ( # type: ignore
98+ "spark.executor.instances" , self ._num_executors
99+ )
100+ spark_conf : t .Dict [str , t .Any ] = {** self ._spark_conf , ** session_properties }
101+ else :
102+ executor_cores = self ._executor_cores
103+ executor_memory = self ._executor_memory
104+ driver_memory = self ._driver_memory
105+ num_executors = self ._num_executors
106+ spark_conf = self ._spark_conf
107+
80108 self ._hook = self ._get_hook (
81109 self ._target .command_type ,
82110 payload_file_path ,
83111 self ._target .ddl_concurrent_tasks ,
112+ spark_conf ,
113+ executor_cores ,
114+ executor_memory ,
115+ driver_memory ,
116+ num_executors ,
84117 )
85118 self ._hook .submit (self ._application )
86119 self ._target .post_hook (context )
87120
def on_kill(self) -> None:
    """Forward the task-kill signal to the underlying Spark submit hook.

    If no hook has been created yet (e.g. the task is killed before
    ``execute`` ran), lazily build one so the kill can still be delivered;
    none of the submit-time arguments are needed just to kill the job, so
    they are all passed as ``None``.
    """
    if self._hook is None:
        # _get_hook takes eight optional arguments; supply None for each.
        self._hook = self._get_hook(*([None] * 8))
    self._hook.on_kill()
92125
93126 def _get_hook (
94127 self ,
95128 command_type : t .Optional [commands .CommandType ],
96129 command_payload_file_path : t .Optional [str ],
97130 ddl_concurrent_tasks : t .Optional [int ],
131+ spark_conf : t .Optional [t .Dict [str , t .Any ]],
132+ executor_cores : t .Optional [int ],
133+ executor_memory : t .Optional [str ],
134+ driver_memory : t .Optional [str ],
135+ num_executors : t .Optional [int ],
98136 ) -> SparkSubmitHook :
99137 application_args = {
100138 "dialect" : "spark" ,
@@ -105,17 +143,17 @@ def _get_hook(
105143 else None ,
106144 }
107145 return SparkSubmitHook (
108- conf = self . _spark_conf ,
146+ conf = spark_conf ,
109147 conn_id = self ._connection_id ,
110148 total_executor_cores = self ._total_executor_cores ,
111- executor_cores = self . _executor_cores ,
112- executor_memory = self . _executor_memory ,
113- driver_memory = self . _driver_memory ,
149+ executor_cores = executor_cores ,
150+ executor_memory = executor_memory ,
151+ driver_memory = driver_memory ,
114152 keytab = self ._keytab ,
115153 principal = self ._principal ,
116154 proxy_user = self ._proxy_user ,
117155 name = self ._application_name ,
118- num_executors = self . _num_executors ,
156+ num_executors = num_executors ,
119157 application_args = [f"--{ k } ={ v } " for k , v in application_args .items () if v is not None ],
120158 files = command_payload_file_path ,
121159 )
0 commit comments