2 changes: 2 additions & 0 deletions airflow-core/src/airflow/assets/manager.py
@@ -575,6 +575,8 @@ def resolve_asset_manager() -> AssetManager:
         key="asset_manager_kwargs",
         fallback={},
     )
+    if TYPE_CHECKING:
+        assert isinstance(_asset_manager_kwargs, dict)
     return _asset_manager_class(**_asset_manager_kwargs)
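Note on the pattern above: conf.getjson is annotated to return a JSON-shaped union, which cannot be **-unpacked directly. Putting the isinstance assert under "if TYPE_CHECKING:" narrows the type for mypy while costing nothing at runtime, since TYPE_CHECKING is False when the module actually executes. A minimal sketch of the idea, with load_json_config as a hypothetical stand-in for conf.getjson:

from typing import TYPE_CHECKING, Any

def load_json_config() -> dict[str, Any] | list[Any] | str | float | bool | None:
    # Hypothetical stand-in, annotated like conf.getjson's JSON union.
    return {"fallback": True}

kwargs = load_json_config()
if TYPE_CHECKING:
    # mypy treats TYPE_CHECKING as True, so the assert narrows kwargs to
    # dict[str, Any]; at runtime the constant is False and the block is skipped.
    assert isinstance(kwargs, dict)

print(dict(**kwargs))  # now type-checks: kwargs is known to be a dict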


45 changes: 26 additions & 19 deletions airflow-core/src/airflow/config_templates/airflow_local_settings.py
@@ -20,7 +20,7 @@
 from __future__ import annotations

 import os
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 from urllib.parse import urlsplit

 from airflow.configuration import conf
@@ -159,59 +159,63 @@ def _default_conn_name_from(mod_path, hook_name):
             "logging/remote_task_handler_kwargs must be a JSON object (a python dict), we got "
             f"{type(remote_task_handler_kwargs)}"
         )
+    _handler_kwargs = cast("dict[str, Any]", remote_task_handler_kwargs)
     delete_local_copy = conf.getboolean("logging", "delete_local_logs")

     if remote_base_log_folder.startswith("s3://"):
         from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO

         _default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook")
         REMOTE_TASK_LOG = S3RemoteLogIO(
-            **(
+            **cast(
+                "dict[str, Any]",
                 {
                     "base_log_folder": BASE_LOG_FOLDER,
                     "remote_base": remote_base_log_folder,
                     "delete_local_copy": delete_local_copy,
                 }
-                | remote_task_handler_kwargs
+                | _handler_kwargs,
             )
         )
-        remote_task_handler_kwargs = {}
+        _handler_kwargs = {}

     elif remote_base_log_folder.startswith("cloudwatch://"):
         from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO

         _default_conn_name_from("airflow.providers.amazon.aws.hooks.logs", "AwsLogsHook")
         url_parts = urlsplit(remote_base_log_folder)
         REMOTE_TASK_LOG = CloudWatchRemoteLogIO(
-            **(
+            **cast(
+                "dict[str, Any]",
                 {
                     "base_log_folder": BASE_LOG_FOLDER,
                     "remote_base": remote_base_log_folder,
                     "delete_local_copy": delete_local_copy,
                     "log_group_arn": url_parts.netloc + url_parts.path,
                 }
-                | remote_task_handler_kwargs
+                | _handler_kwargs,
             )
         )
-        remote_task_handler_kwargs = {}
+        _handler_kwargs = {}
     elif remote_base_log_folder.startswith("gs://"):
         from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO

         _default_conn_name_from("airflow.providers.google.cloud.hooks.gcs", "GCSHook")
         key_path = conf.get_mandatory_value("logging", "google_key_path", fallback=None)

         REMOTE_TASK_LOG = GCSRemoteLogIO(
-            **(
+            **cast(
+                "dict[str, Any]",
                 {
                     "base_log_folder": BASE_LOG_FOLDER,
                     "remote_base": remote_base_log_folder,
                     "delete_local_copy": delete_local_copy,
                     "gcp_key_path": key_path,
                 }
-                | remote_task_handler_kwargs
+                | _handler_kwargs,
             )
         )
-        remote_task_handler_kwargs = {}
+        _handler_kwargs = {}
     elif remote_base_log_folder.startswith("wasb"):
         from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO

@@ -224,17 +228,18 @@ def _default_conn_name_from(mod_path, hook_name):
         wasb_remote_base = remote_base_log_folder.removeprefix("wasb://")

         REMOTE_TASK_LOG = WasbRemoteLogIO(
-            **(
+            **cast(
+                "dict[str, Any]",
                 {
                     "base_log_folder": BASE_LOG_FOLDER,
                     "remote_base": wasb_remote_base,
                     "delete_local_copy": delete_local_copy,
                     "wasb_container": wasb_log_container,
                 }
-                | remote_task_handler_kwargs
+                | _handler_kwargs,
             )
         )
-        remote_task_handler_kwargs = {}
+        _handler_kwargs = {}
     elif remote_base_log_folder.startswith("stackdriver://"):
         key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None)
         # stackdriver:///airflow-tasks => airflow-tasks
@@ -255,32 +260,34 @@ def _default_conn_name_from(mod_path, hook_name):
         _default_conn_name_from("airflow.providers.alibaba.cloud.hooks.oss", "OSSHook")

         REMOTE_TASK_LOG = OSSRemoteLogIO(
-            **(
+            **cast(
+                "dict[str, Any]",
                 {
                     "base_log_folder": BASE_LOG_FOLDER,
                     "remote_base": remote_base_log_folder,
                     "delete_local_copy": delete_local_copy,
                 }
-                | remote_task_handler_kwargs
+                | _handler_kwargs,
             )
         )
-        remote_task_handler_kwargs = {}
+        _handler_kwargs = {}
     elif remote_base_log_folder.startswith("hdfs://"):
         from airflow.providers.apache.hdfs.log.hdfs_task_handler import HdfsRemoteLogIO

         _default_conn_name_from("airflow.providers.apache.hdfs.hooks.webhdfs", "WebHDFSHook")

         REMOTE_TASK_LOG = HdfsRemoteLogIO(
-            **(
+            **cast(
+                "dict[str, Any]",
                 {
                     "base_log_folder": BASE_LOG_FOLDER,
                     "remote_base": urlsplit(remote_base_log_folder).path,
                     "delete_local_copy": delete_local_copy,
                 }
-                | remote_task_handler_kwargs
+                | _handler_kwargs,
             )
         )
-        remote_task_handler_kwargs = {}
+        _handler_kwargs = {}
     elif ELASTICSEARCH_HOST:
         from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchRemoteLogIO

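The repeated **cast("dict[str, Any]", {...} | _handler_kwargs) shape exists because mypy infers a narrow value type for the dict literal (roughly dict[str, object] for mixed entries), and **-unpacking such a dict into an __init__ with individually typed parameters is rejected. Casting the merged mapping to dict[str, Any] resolves that at each call site. A minimal sketch of the failure mode and the fix, with LogHandler as a hypothetical stand-in for S3RemoteLogIO and friends:

from typing import Any, cast

class LogHandler:
    # Hypothetical handler mirroring the shape of the RemoteLogIO classes.
    def __init__(self, base_log_folder: str, delete_local_copy: bool, retries: int = 0) -> None:
        self.base_log_folder = base_log_folder
        self.delete_local_copy = delete_local_copy
        self.retries = retries

extra_kwargs: dict[str, Any] = {"retries": 3}

# Without the cast, mypy types the merge as roughly dict[str, object] and
# rejects ** unpacking it against the typed signature above.
handler = LogHandler(
    **cast(
        "dict[str, Any]",
        {"base_log_folder": "/logs", "delete_local_copy": True} | extra_kwargs,
    )
)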
@@ -21,7 +21,7 @@
 import itertools
 import os
 from datetime import timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from unittest import mock

 import pendulum
@@ -151,7 +151,7 @@ def create_task_instances(
         assert dag_version

         for mi in map_indexes:
-            kwargs = self.ti_init | {"map_index": mi}
+            kwargs: dict[str, Any] = self.ti_init | {"map_index": mi}
             ti = TaskInstance(task=tasks[i], **kwargs, dag_version_id=dag_version.id)
             session.add(ti)
             ti.dag_run = dr
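The added annotation works around the same inference problem as the casts above: self.ti_init | {"map_index": mi} would otherwise be typed from its operands (roughly dict[str, object]), and the subsequent ** unpacking into TaskInstance fails under the new mypy. Declaring the target as dict[str, Any] widens the merge up front. A short sketch of the pattern:

from typing import Any

ti_init = {"pool": "default_pool", "max_tries": 2}  # inferred as dict[str, object]

# Annotating the assignment target makes the merged dict a dict[str, Any],
# so a later **kwargs unpacking into a typed __init__ still type-checks.
kwargs: dict[str, Any] = ti_init | {"map_index": 0}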
44 changes: 22 additions & 22 deletions dev/breeze/uv.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion devel-common/pyproject.toml
@@ -113,7 +113,7 @@ dependencies = [
 "mypy" = [
     # Mypy dependencies
     # TODO: upgrade to newer versions of MyPy continuously as they are released
-    "mypy==1.19.1",
+    "mypy==1.20.0",
     "types-Deprecated>=1.2.9.20240311",
     "types-Markdown>=3.6.0.20240316",
     "types-PyMySQL>=1.1.0.20240425",
@@ -52,7 +52,7 @@ def __init__(
         if result_processor is not None:
             self.result_processor = result_processor
         self.method_name = method_name
-        self.method_params = method_params
+        self.method_params = method_params or {}

     def poke(self, context: Context) -> Any:
         hook = JiraHook(jira_conn_id=self.jira_conn_id)
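The "or {}" coalescing suggests method_params is declared as an optional dict (dict | None = None) in the signature, and the stricter mypy flags later uses such as ** unpacking while the attribute can still be None. Normalizing at assignment keeps the attribute a plain dict; the same change is applied to the GitHub sensor below. A small sketch of the pattern with a hypothetical client call:

from typing import Any

def fetch(resource: str, *, limit: int = 10) -> list[str]:
    # Hypothetical stand-in for the Jira/GitHub client method being wrapped.
    return [resource] * limit

class MethodSensor:
    def __init__(self, method_params: dict[str, Any] | None = None) -> None:
        # Coalesce None to an empty dict so the attribute is always a dict
        # and **self.method_params neither fails type checking nor raises.
        self.method_params = method_params or {}

    def poke(self) -> bool:
        return bool(fetch("issues", **self.method_params))

assert MethodSensor().poke()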
@@ -55,7 +55,7 @@ class EdgeExecutor(BaseExecutor):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.last_reported_state: dict[TaskInstanceKey, TaskInstanceState] = {}
+        self.last_reported_state: dict[TaskInstanceKey, TaskInstanceState | str] = {}

         # Check if self has the ExecutorConf set on the self.conf attribute with all required methods.
         # In Airflow 2.x, ExecutorConf exists but lacks methods like getint, getboolean, getsection, etc.
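Widening the value type to TaskInstanceState | str presumably reflects that states reaching the executor may arrive as raw strings (for example, deserialized from an Airflow 2.x payload) rather than enum members. Since TaskInstanceState subclasses str, comparisons keep working either way; a minimal sketch with a trimmed-down enum:

from enum import Enum

class TaskInstanceState(str, Enum):
    # Trimmed-down stand-in for airflow.utils.state.TaskInstanceState.
    RUNNING = "running"
    SUCCESS = "success"

last_reported_state: dict[str, TaskInstanceState | str] = {}

def record(key: str, state: TaskInstanceState | str) -> None:
    # Accepts both enum members and raw strings under the widened annotation.
    last_reported_state[key] = state

record("task-1", TaskInstanceState.RUNNING)
record("task-2", "queued")  # raw string, e.g. from an older API payload
assert last_reported_state["task-1"] == "running"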
@@ -54,7 +54,7 @@ def __init__(
         if result_processor is not None:
             self.result_processor = result_processor
         self.method_name = method_name
-        self.method_params = method_params
+        self.method_params = method_params or {}

     def poke(self, context: Context) -> bool:
         hook = GithubHook(github_conn_id=self.github_conn_id)
@@ -22,7 +22,7 @@
 import json
 from collections.abc import Collection, Iterable, Sequence
 from datetime import datetime
-from typing import Any
+from typing import TYPE_CHECKING, Any

 from dateutil import parser
 from google.api_core.exceptions import NotFound
@@ -32,6 +32,9 @@
 from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent

+if TYPE_CHECKING:
+    from airflow.utils.state import TerminalTIState
+

 class CloudComposerExecutionTrigger(BaseTrigger):
     """The trigger handles the async communication with the Google Cloud Composer."""
@@ -183,7 +186,7 @@ def __init__(
         composer_dag_id: str,
         start_date: datetime,
         end_date: datetime,
-        allowed_states: list[str],
+        allowed_states: list[str] | list[TerminalTIState],
         composer_dag_run_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -358,7 +361,7 @@ def __init__(
         environment_id: str,
         start_date: datetime,
         end_date: datetime,
-        allowed_states: list[str],
+        allowed_states: list[str] | list[TerminalTIState],
         skipped_states: list[str],
         failed_states: list[str],
         composer_external_dag_id: str,
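On the TerminalTIState changes: the import lives under TYPE_CHECKING so the provider never imports the core symbol at runtime, while annotations like list[str] | list[TerminalTIState] still type-check. This relies on lazy annotation evaluation (from __future__ import annotations, which Airflow sources enable file-wide). A minimal self-contained sketch of the pattern:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for annotations only; never executed at runtime, so there is
    # no runtime dependency on the module providing TerminalTIState.
    from airflow.utils.state import TerminalTIState

def wait_for_states(allowed_states: list[str] | list[TerminalTIState]) -> None:
    # With lazy evaluation the annotation is just a string at runtime.
    print([str(s) for s in allowed_states])

wait_for_states(["success", "failed"])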