Skip to content

Commit 39a2084

Browse files
committed
fixup! fixup! style: fix provider mypy error
1 parent 4a1d866 commit 39a2084

5 files changed

Lines changed: 60 additions & 48 deletions

File tree

airflow-core/src/airflow/assets/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ def resolve_asset_manager() -> AssetManager:
575575
key="asset_manager_kwargs",
576576
fallback={},
577577
)
578+
if TYPE_CHECKING:
579+
assert isinstance(_asset_manager_kwargs, dict)
578580
return _asset_manager_class(**_asset_manager_kwargs)
579581

580582

airflow-core/src/airflow/config_templates/airflow_local_settings.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from __future__ import annotations
2121

2222
import os
23-
from typing import TYPE_CHECKING, Any
23+
from typing import TYPE_CHECKING, Any, cast
2424
from urllib.parse import urlsplit
2525

2626
from airflow.configuration import conf
@@ -159,59 +159,63 @@ def _default_conn_name_from(mod_path, hook_name):
159159
"logging/remote_task_handler_kwargs must be a JSON object (a python dict), we got "
160160
f"{type(remote_task_handler_kwargs)}"
161161
)
162+
_handler_kwargs = cast("dict[str, Any]", remote_task_handler_kwargs)
162163
delete_local_copy = conf.getboolean("logging", "delete_local_logs")
163164

164165
if remote_base_log_folder.startswith("s3://"):
165166
from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO
166167

167168
_default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook")
168169
REMOTE_TASK_LOG = S3RemoteLogIO(
169-
**(
170+
**cast(
171+
"dict[str, Any]",
170172
{
171173
"base_log_folder": BASE_LOG_FOLDER,
172174
"remote_base": remote_base_log_folder,
173175
"delete_local_copy": delete_local_copy,
174176
}
175-
| remote_task_handler_kwargs
177+
| _handler_kwargs,
176178
)
177179
)
178-
remote_task_handler_kwargs = {}
180+
_handler_kwargs = {}
179181

180182
elif remote_base_log_folder.startswith("cloudwatch://"):
181183
from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO
182184

183185
_default_conn_name_from("airflow.providers.amazon.aws.hooks.logs", "AwsLogsHook")
184186
url_parts = urlsplit(remote_base_log_folder)
185187
REMOTE_TASK_LOG = CloudWatchRemoteLogIO(
186-
**(
188+
**cast(
189+
"dict[str, Any]",
187190
{
188191
"base_log_folder": BASE_LOG_FOLDER,
189192
"remote_base": remote_base_log_folder,
190193
"delete_local_copy": delete_local_copy,
191194
"log_group_arn": url_parts.netloc + url_parts.path,
192195
}
193-
| remote_task_handler_kwargs
196+
| _handler_kwargs,
194197
)
195198
)
196-
remote_task_handler_kwargs = {}
199+
_handler_kwargs = {}
197200
elif remote_base_log_folder.startswith("gs://"):
198201
from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO
199202

200203
_default_conn_name_from("airflow.providers.google.cloud.hooks.gcs", "GCSHook")
201204
key_path = conf.get_mandatory_value("logging", "google_key_path", fallback=None)
202205

203206
REMOTE_TASK_LOG = GCSRemoteLogIO(
204-
**(
207+
**cast(
208+
"dict[str, Any]",
205209
{
206210
"base_log_folder": BASE_LOG_FOLDER,
207211
"remote_base": remote_base_log_folder,
208212
"delete_local_copy": delete_local_copy,
209213
"gcp_key_path": key_path,
210214
}
211-
| remote_task_handler_kwargs
215+
| _handler_kwargs,
212216
)
213217
)
214-
remote_task_handler_kwargs = {}
218+
_handler_kwargs = {}
215219
elif remote_base_log_folder.startswith("wasb"):
216220
from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO
217221

@@ -224,17 +228,18 @@ def _default_conn_name_from(mod_path, hook_name):
224228
wasb_remote_base = remote_base_log_folder.removeprefix("wasb://")
225229

226230
REMOTE_TASK_LOG = WasbRemoteLogIO(
227-
**(
231+
**cast(
232+
"dict[str, Any]",
228233
{
229234
"base_log_folder": BASE_LOG_FOLDER,
230235
"remote_base": wasb_remote_base,
231236
"delete_local_copy": delete_local_copy,
232237
"wasb_container": wasb_log_container,
233238
}
234-
| remote_task_handler_kwargs
239+
| _handler_kwargs,
235240
)
236241
)
237-
remote_task_handler_kwargs = {}
242+
_handler_kwargs = {}
238243
elif remote_base_log_folder.startswith("stackdriver://"):
239244
key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None)
240245
# stackdriver:///airflow-tasks => airflow-tasks
@@ -255,32 +260,34 @@ def _default_conn_name_from(mod_path, hook_name):
255260
_default_conn_name_from("airflow.providers.alibaba.cloud.hooks.oss", "OSSHook")
256261

257262
REMOTE_TASK_LOG = OSSRemoteLogIO(
258-
**(
263+
**cast(
264+
"dict[str, Any]",
259265
{
260266
"base_log_folder": BASE_LOG_FOLDER,
261267
"remote_base": remote_base_log_folder,
262268
"delete_local_copy": delete_local_copy,
263269
}
264-
| remote_task_handler_kwargs
270+
| _handler_kwargs,
265271
)
266272
)
267-
remote_task_handler_kwargs = {}
273+
_handler_kwargs = {}
268274
elif remote_base_log_folder.startswith("hdfs://"):
269275
from airflow.providers.apache.hdfs.log.hdfs_task_handler import HdfsRemoteLogIO
270276

271277
_default_conn_name_from("airflow.providers.apache.hdfs.hooks.webhdfs", "WebHDFSHook")
272278

273279
REMOTE_TASK_LOG = HdfsRemoteLogIO(
274-
**(
280+
**cast(
281+
"dict[str, Any]",
275282
{
276283
"base_log_folder": BASE_LOG_FOLDER,
277284
"remote_base": urlsplit(remote_base_log_folder).path,
278285
"delete_local_copy": delete_local_copy,
279286
}
280-
| remote_task_handler_kwargs
287+
| _handler_kwargs,
281288
)
282289
)
283-
remote_task_handler_kwargs = {}
290+
_handler_kwargs = {}
284291
elif ELASTICSEARCH_HOST:
285292
from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchRemoteLogIO
286293

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import itertools
2222
import os
2323
from datetime import timedelta
24-
from typing import TYPE_CHECKING
24+
from typing import TYPE_CHECKING, Any
2525
from unittest import mock
2626

2727
import pendulum
@@ -151,7 +151,7 @@ def create_task_instances(
151151
assert dag_version
152152

153153
for mi in map_indexes:
154-
kwargs = self.ti_init | {"map_index": mi}
154+
kwargs: dict[str, Any] = self.ti_init | {"map_index": mi}
155155
ti = TaskInstance(task=tasks[i], **kwargs, dag_version_id=dag_version.id)
156156
session.add(ti)
157157
ti.dag_run = dr

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -225,28 +225,31 @@ def get_template_context(self) -> Context:
225225

226226
# Cache the context object, which ensures that all calls to get_template_context
227227
# are operating on the same context object.
228-
self._cached_template_context: Context = self._cached_template_context or {
229-
# From the Task Execution interface
230-
"dag": self.task.dag,
231-
"inlets": self.task.inlets,
232-
"map_index_template": self.task.map_index_template,
233-
"outlets": self.task.outlets,
234-
"run_id": self.run_id,
235-
"task": self.task,
236-
"task_instance": self,
237-
"ti": self,
238-
"outlet_events": OutletEventAccessors(),
239-
"inlet_events": InletEventsAccessors(self.task.inlets),
240-
"macros": MacrosAccessor(),
241-
"params": validated_params,
242-
# TODO: Make this go through Public API longer term.
243-
# "test_mode": task_instance.test_mode,
244-
"var": {
245-
"json": VariableAccessor(deserialize_json=True),
246-
"value": VariableAccessor(deserialize_json=False),
247-
},
248-
"conn": ConnectionAccessor(),
249-
}
228+
context = self._cached_template_context
229+
if context is None:
230+
context = {
231+
# From the Task Execution interface
232+
"dag": self.task.dag,
233+
"inlets": self.task.inlets,
234+
"map_index_template": self.task.map_index_template,
235+
"outlets": self.task.outlets,
236+
"run_id": self.run_id,
237+
"task": self.task,
238+
"task_instance": self,
239+
"ti": self,
240+
"outlet_events": OutletEventAccessors(),
241+
"inlet_events": InletEventsAccessors(self.task.inlets),
242+
"macros": MacrosAccessor(),
243+
"params": validated_params,
244+
# TODO: Make this go through Public API longer term.
245+
# "test_mode": task_instance.test_mode,
246+
"var": {
247+
"json": VariableAccessor(deserialize_json=True),
248+
"value": VariableAccessor(deserialize_json=False),
249+
},
250+
"conn": ConnectionAccessor(),
251+
}
252+
self._cached_template_context = context
250253
if from_server:
251254
dag_run = from_server.dag_run
252255
context_from_server: Context = {
@@ -265,7 +268,7 @@ def get_template_context(self) -> Context:
265268
lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
266269
),
267270
}
268-
self._cached_template_context.update(context_from_server)
271+
context.update(context_from_server)
269272

270273
if logical_date := coerce_datetime(dag_run.logical_date):
271274
if TYPE_CHECKING:
@@ -276,7 +279,7 @@ def get_template_context(self) -> Context:
276279
ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
277280
ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
278281
# logical_date and data_interval either coexist or be None together
279-
self._cached_template_context.update(
282+
context.update(
280283
{
281284
# keys that depend on logical_date
282285
"logical_date": logical_date,
@@ -303,7 +306,7 @@ def get_template_context(self) -> Context:
303306
if upstream_map_indexes is not None:
304307
setattr(self, "_upstream_map_indexes", upstream_map_indexes)
305308

306-
return self._cached_template_context
309+
return context
307310

308311
def render_templates(
309312
self, context: Context | None = None, jinja_env: jinja2.Environment | None = None

task-sdk/src/airflow/sdk/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from collections.abc import Iterable
2222
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypeAlias
2323

24-
from airflow.sdk.api.datamodels._generated import TerminalTIState, WeightRule
24+
from airflow.sdk.api.datamodels._generated import WeightRule
2525
from airflow.sdk.bases.xcom import BaseXCom
2626
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
2727

@@ -160,7 +160,7 @@ def get_ti_count(
160160
task_group_id: str | None = None,
161161
logical_dates: list[AwareDatetime] | None = None,
162162
run_ids: list[str] | None = None,
163-
states: list[str] | list[TerminalTIState] | None = None,
163+
states: list[str] | None = None,
164164
) -> int: ...
165165

166166
@staticmethod

0 commit comments

Comments
 (0)