Skip to content

Commit cb7f6af

Browse files
authored
fix RedshiftSQLHook._get_conn_params connection mutation with iam (#64991)
1 parent bb5a744 commit cb7f6af

2 files changed

Lines changed: 39 additions & 8 deletions

File tree

providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,20 @@ def _get_conn_params(self) -> dict[str, str | int]:
8888
conn_params: dict[str, str | int] = {}
8989

9090
if conn.extra_dejson.get("iam", False):
91-
conn.login, conn.password, conn.port = self.get_iam_token(conn)
92-
93-
if conn.login:
94-
conn_params["user"] = conn.login
95-
if conn.password:
96-
conn_params["password"] = conn.password
91+
login, password, port = self.get_iam_token(conn)
92+
else:
93+
login = conn.login
94+
password = conn.password
95+
port = conn.port
96+
97+
if login:
98+
conn_params["user"] = login
99+
if password:
100+
conn_params["password"] = password
101+
if port:
102+
conn_params["port"] = port
97103
if conn.host:
98104
conn_params["host"] = conn.host
99-
if conn.port:
100-
conn_params["port"] = conn.port
101105
if conn.schema:
102106
conn_params["database"] = conn.schema
103107

providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,33 @@ def test_get_iam_token(
242242
AutoCreate=False,
243243
)
244244

245+
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
246+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
247+
def test_get_conn_iam_does_not_mutate_connection(self, mock_connect, mock_aws_hook_conn):
248+
self.connection.extra = json.dumps(
249+
{"iam": True, "profile": "default", "cluster_identifier": "my-test-cluster"}
250+
)
251+
252+
mock_db_user = f"IAM:{LOGIN_USER}"
253+
mock_db_pass = "aws_token"
254+
255+
mock_aws_hook_conn.get_cluster_credentials.return_value = {
256+
"DbPassword": mock_db_pass,
257+
"DbUser": mock_db_user,
258+
}
259+
self.db_hook.get_conn()
260+
self.db_hook.get_conn()
261+
assert mock_aws_hook_conn.get_cluster_credentials.call_count == 2
262+
for call in mock_aws_hook_conn.get_cluster_credentials.call_args_list:
263+
assert call == mock.call(
264+
DbUser=LOGIN_USER,
265+
DbName=LOGIN_SCHEMA,
266+
ClusterIdentifier="my-test-cluster",
267+
AutoCreate=False,
268+
)
269+
270+
assert self.connection.login == LOGIN_USER
271+
245272
@mock.patch.dict("os.environ", AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
246273
@pytest.mark.parametrize(
247274
("connection_host", "connection_extra", "expected_identity"),

0 commit comments

Comments
 (0)