|
1 | 1 | from datetime import datetime, timezone |
2 | 2 |
|
3 | | -from unittest.mock import AsyncMock |
| 3 | +from unittest.mock import AsyncMock, MagicMock, patch |
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 |
|
7 | | -from symphony.bdk.core.auth.auth_session import AuthSession, OboAuthSession, AppAuthSession |
| 7 | +from symphony.bdk.core.auth.auth_session import ( |
| 8 | + AuthSession, |
| 9 | + OboAuthSession, |
| 10 | + AppAuthSession, |
| 11 | + SKD_FLAG_NAME, |
| 12 | +) |
8 | 13 | from symphony.bdk.core.auth.exception import AuthInitializationError |
9 | 14 | from symphony.bdk.gen.login_model.token import Token |
10 | 15 | from symphony.bdk.gen.login_model.extension_app_tokens import ExtensionAppTokens |
11 | 16 |
|
12 | 17 |
|
| 18 | +@pytest.fixture |
| 19 | +def mock_authenticator(): |
| 20 | + authenticator = MagicMock() |
| 21 | + authenticator.agent_version_service = AsyncMock() |
| 22 | + authenticator.retrieve_session_token = AsyncMock( |
| 23 | + return_value="session_token_string" |
| 24 | + ) |
| 25 | + authenticator.retrieve_key_manager_token = AsyncMock(return_value="km_token_string") |
| 26 | + return authenticator |
| 27 | + |
| 28 | + |
| 29 | +@pytest.fixture |
| 30 | +def auth_session(mock_authenticator): |
| 31 | + return AuthSession(mock_authenticator) |
| 32 | + |
| 33 | + |
13 | 34 | @pytest.mark.asyncio |
14 | 35 | async def test_refresh(): |
15 | 36 | mock_bot_authenticator = AsyncMock() |
@@ -57,8 +78,14 @@ async def test_auth_token(): |
57 | 78 | @pytest.mark.asyncio |
58 | 79 | async def test_refresh_obo(): |
59 | 80 | mock_obo_authenticator = AsyncMock() |
60 | | - mock_obo_authenticator.retrieve_obo_session_token_by_user_id.side_effect = ["session_token1", "session_token2"] |
61 | | - mock_obo_authenticator.retrieve_obo_session_token_by_username.side_effect = ["session_token3", "session_token4"] |
| 81 | + mock_obo_authenticator.retrieve_obo_session_token_by_user_id.side_effect = [ |
| 82 | + "session_token1", |
| 83 | + "session_token2", |
| 84 | + ] |
| 85 | + mock_obo_authenticator.retrieve_obo_session_token_by_username.side_effect = [ |
| 86 | + "session_token3", |
| 87 | + "session_token4", |
| 88 | + ] |
62 | 89 |
|
63 | 90 | obo_session = OboAuthSession(mock_obo_authenticator, user_id=1234) |
64 | 91 |
|
@@ -95,14 +122,96 @@ async def test_app_auth_session(): |
95 | 122 | expire_at = 1539636528288 |
96 | 123 |
|
97 | 124 | ext_app_authenticator = AsyncMock() |
98 | | - ext_app_authenticator.authenticate_and_retrieve_tokens.return_value = \ |
99 | | - ExtensionAppTokens(app_id="app_id", app_token=retrieved_app_token, symphony_token=symphony_token, |
100 | | - expire_at=expire_at) |
| 125 | + ext_app_authenticator.authenticate_and_retrieve_tokens.return_value = ( |
| 126 | + ExtensionAppTokens( |
| 127 | + app_id="app_id", |
| 128 | + app_token=retrieved_app_token, |
| 129 | + symphony_token=symphony_token, |
| 130 | + expire_at=expire_at, |
| 131 | + ) |
| 132 | + ) |
101 | 133 |
|
102 | 134 | session = AppAuthSession(ext_app_authenticator, input_app_token) |
103 | 135 | await session.refresh() |
104 | 136 |
|
105 | | - ext_app_authenticator.authenticate_and_retrieve_tokens.assert_called_once_with(input_app_token) |
| 137 | + ext_app_authenticator.authenticate_and_retrieve_tokens.assert_called_once_with( |
| 138 | + input_app_token |
| 139 | + ) |
106 | 140 | assert session.app_token == retrieved_app_token |
107 | 141 | assert session.symphony_token == symphony_token |
108 | 142 | assert session.expire_at == expire_at |
| 143 | + |
| 144 | + |
| 145 | +@pytest.mark.asyncio |
| 146 | +async def test_skd_disabled_if_claim_is_missing(auth_session): |
| 147 | + # Given: The token claims do not contain the SKD flag |
| 148 | + with patch( |
| 149 | + "symphony.bdk.core.auth.auth_session.extract_token_claims", return_value={} |
| 150 | + ): |
| 151 | + # When: skd_enabled is checked |
| 152 | + is_enabled = await auth_session.skd_enabled |
| 153 | + # Then: The result is False |
| 154 | + assert is_enabled is False |
| 155 | + |
| 156 | + |
| 157 | +@pytest.mark.asyncio |
| 158 | +async def test_skd_disabled_if_agent_not_supported(auth_session, mock_authenticator): |
| 159 | + # Given: The token has the SKD flag but the agent does not support it |
| 160 | + mock_authenticator.agent_version_service.is_skd_supported.return_value = False |
| 161 | + claims_with_skd = {SKD_FLAG_NAME: True} |
| 162 | + with patch( |
| 163 | + "symphony.bdk.core.auth.auth_session.extract_token_claims", |
| 164 | + return_value=claims_with_skd, |
| 165 | + ): |
| 166 | + # When: skd_enabled is checked |
| 167 | + is_enabled = await auth_session.skd_enabled |
| 168 | + # Then: The result is False and the agent version was checked |
| 169 | + assert is_enabled is False |
| 170 | + mock_authenticator.agent_version_service.is_skd_supported.assert_called_once() |
| 171 | + |
| 172 | + |
| 173 | +@pytest.mark.asyncio |
| 174 | +async def test_skd_enabled_when_fully_supported(auth_session, mock_authenticator): |
| 175 | + # Given: The token has the SKD flag AND the agent supports it |
| 176 | + mock_authenticator.agent_version_service.is_skd_supported.return_value = True |
| 177 | + claims_with_skd = {SKD_FLAG_NAME: True} |
| 178 | + with patch( |
| 179 | + "symphony.bdk.core.auth.auth_session.extract_token_claims", |
| 180 | + return_value=claims_with_skd, |
| 181 | + ): |
| 182 | + # When: skd_enabled is checked |
| 183 | + is_enabled = await auth_session.skd_enabled |
| 184 | + # Then: The result is True and the agent version was checked |
| 185 | + assert is_enabled is True |
| 186 | + mock_authenticator.agent_version_service.is_skd_supported.assert_called_once() |
| 187 | + |
| 188 | + |
| 189 | +@pytest.mark.asyncio |
| 190 | +async def test_km_token_is_empty_when_skd_enabled(auth_session, mock_authenticator): |
| 191 | + # Given: SKD is fully enabled |
| 192 | + mock_authenticator.agent_version_service.is_skd_supported.return_value = True |
| 193 | + claims_with_skd = {SKD_FLAG_NAME: True} |
| 194 | + with patch( |
| 195 | + "symphony.bdk.core.auth.auth_session.extract_token_claims", |
| 196 | + return_value=claims_with_skd, |
| 197 | + ): |
| 198 | + # When: The key manager token is requested |
| 199 | + km_token = await auth_session.key_manager_token |
| 200 | + # Then: The token is an empty string and the real retrieval method was NOT called |
| 201 | + assert km_token == "" |
| 202 | + mock_authenticator.retrieve_key_manager_token.assert_not_called() |
| 203 | + |
| 204 | + |
| 205 | +@pytest.mark.asyncio |
| 206 | +async def test_km_token_is_retrieved_when_skd_disabled( |
| 207 | + auth_session, mock_authenticator |
| 208 | +): |
| 209 | + # Given: SKD is disabled because the token claim is missing |
| 210 | + with patch( |
| 211 | + "symphony.bdk.core.auth.auth_session.extract_token_claims", return_value={} |
| 212 | + ): |
| 213 | + # When: The key manager token is requested |
| 214 | + km_token = await auth_session.key_manager_token |
| 215 | + # Then: The real token is returned and the retrieval method was called |
| 216 | + assert km_token == "km_token_string" |
| 217 | + mock_authenticator.retrieve_key_manager_token.assert_called_once() |
0 commit comments