Skip to content

Commit fad0c60

Browse files
authored
Fix: Automatically refresh the MWAA CLI token (#1564)
1 parent c37b0aa commit fad0c60

4 files changed

Lines changed: 82 additions & 45 deletions

File tree

docs/integrations/airflow.md

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,3 @@ default_scheduler:
8484
type: mwaa
8585
environment: <The MWAA Environment Name>
8686
```
87-
88-
Alternatively, the Airflow Webserver URL and the MWAA CLI token can be provided directly instead of the environment name:
89-
```yaml linenums="1"
90-
default_scheduler:
91-
type: mwaa
92-
airflow_url: https://<Airflow Webserver Host>/
93-
auth_token: <The MWAA CLI Token>
94-
```

sqlmesh/core/config/scheduler.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,6 @@ class MWAASchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig):
220220
221221
Args:
222222
environment: The name of the MWAA environment.
223-
airflow_url: The URL of the Airflow Webserver.
224-
auth_token: The MWAA authentication token.
225223
dag_run_poll_interval_secs: Determines how often a running DAG can be polled (in seconds).
226224
dag_creation_poll_interval_secs: Determines how often SQLMesh should check whether a DAG has been created (in seconds).
227225
dag_creation_max_retry_attempts: Determines the maximum number of attempts that SQLMesh will make while checking for
@@ -230,9 +228,7 @@ class MWAASchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig):
230228
ddl_concurrent_tasks: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc).
231229
"""
232230

233-
environment: t.Optional[str] = None
234-
airflow_url: t.Optional[str] = None
235-
auth_token: t.Optional[str] = None
231+
environment: str
236232
dag_run_poll_interval_secs: int = 10
237233
dag_creation_poll_interval_secs: int = 30
238234
dag_creation_max_retry_attempts: int = 10
@@ -245,20 +241,10 @@ class MWAASchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig):
245241
_concurrent_tasks_validator = concurrent_tasks_validator
246242

247243
def create_plan_evaluator(self, context: Context) -> PlanEvaluator:
248-
from sqlmesh.schedulers.airflow.mwaa_client import (
249-
MWAAClient,
250-
url_and_auth_token_for_environment,
251-
)
252-
253-
if self.environment:
254-
airflow_url, auth_token = url_and_auth_token_for_environment(self.environment)
255-
else:
256-
assert self.airflow_url and self.auth_token # Make mypy happy
257-
airflow_url = self.airflow_url
258-
auth_token = self.auth_token
244+
from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient
259245

260246
return MWAAPlanEvaluator(
261-
client=MWAAClient(airflow_url, auth_token, console=context.console),
247+
client=MWAAClient(self.environment, console=context.console),
262248
state_sync=context.state_sync,
263249
console=context.console,
264250
dag_run_poll_interval_secs=self.dag_run_poll_interval_secs,
@@ -270,19 +256,6 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator:
270256
users=context.users,
271257
)
272258

273-
@model_validator(mode="before")
274-
@classmethod
275-
def _ensure_environment_or_url_with_auth_token(
276-
cls, values: t.Dict[str, t.Any]
277-
) -> t.Dict[str, t.Any]:
278-
if not values.get("environment"):
279-
if not values.get("airflow_url") or not values.get("auth_token"):
280-
raise ValueError(
281-
"Either 'environment' or 'airflow_url' and 'auth_token' must be specified for the MWAA scheduler config."
282-
)
283-
284-
return values
285-
286259

287260
SchedulerConfig = Annotated[
288261
t.Union[

sqlmesh/schedulers/airflow/mwaa_client.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,31 @@
22

33
import base64
44
import json
5+
import logging
56
import typing as t
67
from urllib.parse import urljoin
78

89
from requests import Session
910

1011
from sqlmesh.core.console import Console
1112
from sqlmesh.schedulers.airflow.client import BaseAirflowClient, raise_for_status
13+
from sqlmesh.utils.date import now_timestamp
1214
from sqlmesh.utils.errors import NotFoundError
1315

16+
logger = logging.getLogger(__name__)
17+
18+
19+
TOKEN_TTL_MS = 30 * 1000
20+
1421

1522
class MWAAClient(BaseAirflowClient):
16-
def __init__(self, airflow_url: str, auth_token: str, console: t.Optional[Console] = None):
23+
def __init__(self, environment: str, console: t.Optional[Console] = None):
24+
airflow_url, auth_token = url_and_auth_token_for_environment(environment)
1725
super().__init__(airflow_url, console)
1826

19-
self._session = Session()
20-
self._session.headers.update(
21-
{"Authorization": f"Bearer {auth_token}", "Content-Type": "text/plain"}
22-
)
27+
self._environment = environment
28+
self._last_token_refresh_ts = now_timestamp()
29+
self.__session: Session = _create_session(auth_token)
2330

2431
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
2532
dag_runs = self._list_dag_runs(dag_id)
@@ -49,10 +56,27 @@ def _post(self, data: str) -> t.Tuple[str, str]:
4956
cli_stderr = base64.b64decode(response_body["stderr"]).decode("utf8").strip()
5057
return cli_stdout, cli_stderr
5158

59+
@property
def _session(self) -> Session:
    """Return the HTTP session, re-authenticating it when the CLI token is stale.

    The age of the current token is measured against ``TOKEN_TTL_MS``. Once it
    is exceeded, a fresh token is fetched for ``self._environment`` and a new
    session carrying that token replaces the current one.

    Returns:
        The session to use for the next request.
    """
    current_ts = now_timestamp()
    if current_ts - self._last_token_refresh_ts > TOKEN_TTL_MS:
        # Fetch the token first: if this raises, the existing session and
        # refresh timestamp are left untouched.
        _, auth_token = url_and_auth_token_for_environment(self._environment)
        stale_session = self.__session
        self.__session = _create_session(auth_token)
        self._last_token_refresh_ts = current_ts
        # Release the replaced session's pooled connections instead of
        # abandoning them on every refresh.
        # NOTE(review): assumes no other thread is mid-request on the old
        # session — confirm against callers.
        stale_session.close()
    return self.__session
67+
68+
69+
def _create_session(auth_token: str) -> Session:
    """Build a session whose default headers carry the given MWAA CLI token.

    Args:
        auth_token: The token placed in the ``Authorization`` bearer header.

    Returns:
        A new ``Session`` with the ``Authorization`` and ``Content-Type``
        headers set.
    """
    default_headers = {
        "Authorization": f"Bearer {auth_token}",
        "Content-Type": "text/plain",
    }
    authenticated = Session()
    authenticated.headers.update(default_headers)
    return authenticated
73+
5274

5375
def url_and_auth_token_for_environment(environment_name: str) -> t.Tuple[str, str]:
5476
import boto3
5577

78+
logger.info("Fetching the MWAA CLI token")
79+
5680
client = boto3.client("mwaa")
5781
cli_token = client.create_cli_token(Name=environment_name)
5882

tests/schedulers/airflow/test_mwaa_client.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@ def test_get_first_dag_run_id(mocker: MockerFixture):
1616
list_runs_mock = mocker.patch("requests.Session.post")
1717
list_runs_mock.return_value = list_runs_response_mock
1818

19-
client = MWAAClient("https://test_airflow_host", "test_token")
19+
url_and_auth_token_mock = mocker.patch(
20+
"sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment"
21+
)
22+
url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token")
23+
24+
client = MWAAClient("test_environment")
2025

2126
assert client.get_first_dag_run_id("test_dag_id") == "test_run_id"
2227

2328
list_runs_mock.assert_called_once_with(
2429
"https://test_airflow_host/aws_mwaa/cli",
2530
data="dags list-runs -o json -d test_dag_id",
2631
)
32+
url_and_auth_token_mock.assert_called_once_with("test_environment")
2733

2834

2935
def test_get_dag_run_state(mocker: MockerFixture):
@@ -43,14 +49,56 @@ def test_get_dag_run_state(mocker: MockerFixture):
4349
list_runs_mock = mocker.patch("requests.Session.post")
4450
list_runs_mock.return_value = list_runs_response_mock
4551

46-
client = MWAAClient("https://test_airflow_host", "test_token")
52+
url_and_auth_token_mock = mocker.patch(
53+
"sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment"
54+
)
55+
url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token")
56+
57+
client = MWAAClient("test_environment")
4758

4859
assert client.get_dag_run_state("test_dag_id", "test_run_id_b") == "failed"
4960

5061
list_runs_mock.assert_called_once_with(
5162
"https://test_airflow_host/aws_mwaa/cli",
5263
data="dags list-runs -o json -d test_dag_id",
5364
)
65+
url_and_auth_token_mock.assert_called_once_with("test_environment")
66+
67+
68+
def test_token_refresh(mocker: MockerFixture):
    """The client fetches a new token only when the mocked clock is >30s past the last refresh."""
    encoded_runs = _encode_output(json.dumps([{"run_id": "test_run_id", "state": "success"}]))
    response = mocker.Mock()
    response.status_code = 200
    response.json.return_value = {"stdout": encoded_runs, "stderr": ""}
    mocker.patch("requests.Session.post", return_value=response)

    token_fetch_mock = mocker.patch(
        "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment",
        return_value=("https://test_airflow_host", "test_token"),
    )
    clock_mock = mocker.patch(
        "sqlmesh.schedulers.airflow.mwaa_client.now_timestamp",
        return_value=0,
    )

    # One fetch happens here, while the mocked clock reads 0.
    client = MWAAClient("test_environment")

    # Mocked clock readings in milliseconds. Fetches are expected at 31000
    # (31s after construction) and at 63000 (32s after the previous refresh);
    # the other calls fall inside the 30s window.
    for clock_ms in (0, 15000, 31000, 45000, 63000):
        clock_mock.return_value = clock_ms
        client.get_first_dag_run_id("test_dag_id")

    assert token_fetch_mock.call_count == 3
54102

55103

56104
def _encode_output(out: str) -> str:

0 commit comments

Comments
 (0)