@@ -16,14 +16,20 @@ def test_get_first_dag_run_id(mocker: MockerFixture):
1616 list_runs_mock = mocker .patch ("requests.Session.post" )
1717 list_runs_mock .return_value = list_runs_response_mock
1818
19- client = MWAAClient ("https://test_airflow_host" , "test_token" )
19+ url_and_auth_token_mock = mocker .patch (
20+ "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment"
21+ )
22+ url_and_auth_token_mock .return_value = ("https://test_airflow_host" , "test_token" )
23+
24+ client = MWAAClient ("test_environment" )
2025
2126 assert client .get_first_dag_run_id ("test_dag_id" ) == "test_run_id"
2227
2328 list_runs_mock .assert_called_once_with (
2429 "https://test_airflow_host/aws_mwaa/cli" ,
2530 data = "dags list-runs -o json -d test_dag_id" ,
2631 )
32+ url_and_auth_token_mock .assert_called_once_with ("test_environment" )
2733
2834
2935def test_get_dag_run_state (mocker : MockerFixture ):
@@ -43,14 +49,56 @@ def test_get_dag_run_state(mocker: MockerFixture):
4349 list_runs_mock = mocker .patch ("requests.Session.post" )
4450 list_runs_mock .return_value = list_runs_response_mock
4551
46- client = MWAAClient ("https://test_airflow_host" , "test_token" )
52+ url_and_auth_token_mock = mocker .patch (
53+ "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment"
54+ )
55+ url_and_auth_token_mock .return_value = ("https://test_airflow_host" , "test_token" )
56+
57+ client = MWAAClient ("test_environment" )
4758
4859 assert client .get_dag_run_state ("test_dag_id" , "test_run_id_b" ) == "failed"
4960
5061 list_runs_mock .assert_called_once_with (
5162 "https://test_airflow_host/aws_mwaa/cli" ,
5263 data = "dags list-runs -o json -d test_dag_id" ,
5364 )
65+ url_and_auth_token_mock .assert_called_once_with ("test_environment" )
66+
67+
68+ def test_token_refresh (mocker : MockerFixture ):
69+ list_runs_response_mock = mocker .Mock ()
70+ list_runs_response_mock .json .return_value = {
71+ "stdout" : _encode_output (json .dumps ([{"run_id" : "test_run_id" , "state" : "success" }])),
72+ "stderr" : "" ,
73+ }
74+ list_runs_response_mock .status_code = 200
75+ list_runs_mock = mocker .patch ("requests.Session.post" )
76+ list_runs_mock .return_value = list_runs_response_mock
77+
78+ url_and_auth_token_mock = mocker .patch (
79+ "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment"
80+ )
81+ url_and_auth_token_mock .return_value = ("https://test_airflow_host" , "test_token" )
82+
83+ now_mock = mocker .patch ("sqlmesh.schedulers.airflow.mwaa_client.now_timestamp" )
84+ now_mock .return_value = 0
85+
86+ client = MWAAClient ("test_environment" )
87+ client .get_first_dag_run_id ("test_dag_id" )
88+
89+ now_mock .return_value = 15000 # 15 seconds later
90+ client .get_first_dag_run_id ("test_dag_id" )
91+
92+ now_mock .return_value = 31000 # 31 seconds later
93+ client .get_first_dag_run_id ("test_dag_id" )
94+
95+ now_mock .return_value = 45000 # 45 seconds later
96+ client .get_first_dag_run_id ("test_dag_id" )
97+
98+ now_mock .return_value = 63000 # 63 seconds later
99+ client .get_first_dag_run_id ("test_dag_id" )
100+
101+ assert url_and_auth_token_mock .call_count == 3
54102
55103
56104def _encode_output (out : str ) -> str :
0 commit comments