diff --git a/symphony/bdk/core/auth/auth_session.py b/symphony/bdk/core/auth/auth_session.py index d76d70c3..a0b53608 100644 --- a/symphony/bdk/core/auth/auth_session.py +++ b/symphony/bdk/core/auth/auth_session.py @@ -5,10 +5,13 @@ import logging from symphony.bdk.core.auth.exception import AuthInitializationError +from symphony.bdk.core.auth.jwt_helper import extract_token_claims + logger = logging.getLogger(__name__) EXPIRATION_SAFETY_BUFFER_SECONDS = 5 +SKD_FLAG_NAME = "canUseSimplifiedKeyDelivery" class AuthSession: @@ -33,7 +36,10 @@ async def refresh(self): """ logger.debug("Authenticate") self._session_token = await self._authenticator.retrieve_session_token() - self._key_manager_token = await self._authenticator.retrieve_key_manager_token() + if await self.skd_enabled: + self._key_manager_token = "" + return + self.key_manager_token = await self._authenticator.retrieve_key_manager_token() @property async def session_token(self): @@ -71,6 +77,10 @@ async def key_manager_token(self): :return: the key manager token """ + + if await self.skd_enabled: + return "" + if self._key_manager_token is None: self._key_manager_token = await self._authenticator.retrieve_key_manager_token() return self._key_manager_token @@ -91,6 +101,15 @@ def key_manager_token(self, value): """ self._key_manager_token = value + @property + async def skd_enabled(self): + + token_data = extract_token_claims(await self.session_token) + if not token_data.get(SKD_FLAG_NAME, False): + return False + return await self._authenticator.agent_version_service.is_skd_supported() + + class OboAuthSession(AuthSession): """RSA OBO Authentication session handle to get the OBO session token from. diff --git a/symphony/bdk/core/auth/bot_authenticator.py b/symphony/bdk/core/auth/bot_authenticator.py index b8ae9abd..1b74bc64 100644 --- a/symphony/bdk/core/auth/bot_authenticator.py +++ b/symphony/bdk/core/auth/bot_authenticator.py @@ -8,6 +8,7 @@ from symphony.bdk.core.config.model.bdk_retry_config import BdkRetryConfig from symphony.bdk.core.retry import retry from symphony.bdk.core.retry.strategy import authentication_retry +from symphony.bdk.core.service.version.agent_version_service import AgentVersionService from symphony.bdk.gen.api_client import ApiClient from symphony.bdk.gen.auth_api.certificate_authentication_api import CertificateAuthenticationApi from symphony.bdk.gen.login_api.authentication_api import AuthenticationApi @@ -24,6 +25,7 @@ def __init__(self, session_auth_client: ApiClient, key_manager_auth_client: ApiC self._session_auth_client = session_auth_client self._key_manager_auth_client = key_manager_auth_client self._retry_config = retry_config + self._agent_version_service = None async def retrieve_session_token(self) -> str: """Authenticates and retrieves a new session token. @@ -59,6 +61,15 @@ async def _authenticate_and_get_token(self, api_client: ApiClient) -> str: :return: the token as a string """ + @property + def agent_version_service(self) -> Optional[AgentVersionService]: + return self._agent_version_service + + @agent_version_service.setter + def agent_version_service(self, agent_version_service: AgentVersionService): + self._agent_version_service = agent_version_service + + class BotAuthenticatorRsa(BotAuthenticator): """Bot authenticator RSA implementation. diff --git a/symphony/bdk/core/auth/jwt_helper.py b/symphony/bdk/core/auth/jwt_helper.py index 99029420..d4d7e94f 100644 --- a/symphony/bdk/core/auth/jwt_helper.py +++ b/symphony/bdk/core/auth/jwt_helper.py @@ -83,3 +83,13 @@ def _parse_public_key_from_x509_cert(certificate: str) -> str: return public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode() except ValueError as exc: raise AuthInitializationError("Unable to parse the certificate. Check certificate format.") from exc + + +def extract_token_claims(session_token): + try: + return jwt.decode(session_token, + algorithms=[JWT_ENCRYPTION_ALGORITHM], + options={"verify_signature": False} + ) + except DecodeError: + return {} diff --git a/symphony/bdk/core/service/version/__init__.py b/symphony/bdk/core/service/version/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/symphony/bdk/core/service/version/agent_version_service.py b/symphony/bdk/core/service/version/agent_version_service.py new file mode 100644 index 00000000..40b774f5 --- /dev/null +++ b/symphony/bdk/core/service/version/agent_version_service.py @@ -0,0 +1,67 @@ +import re +from datetime import datetime, timezone + +from symphony.bdk.core.auth.auth_session import AuthSession +from symphony.bdk.core.auth.jwt_helper import generate_expiration_time +from symphony.bdk.core.config.model.bdk_retry_config import BdkRetryConfig +from symphony.bdk.core.retry import retry +from symphony.bdk.gen.agent_api.signals_api import SignalsApi +from symphony.bdk.gen.exceptions import ApiException + +MIN_MAJOR_VERSION = 24 +MIN_MINOR_VERSION = 12 +VERSION_REGEXP = r"Agent-(\d+)\.(\d+)\..*" + + +class AgentVersionService: + """Service class has one purpose only. It checks if version of agents supports simplified key delivery mechanism + + """ + + def __init__(self, signals_api: SignalsApi, retry_config: BdkRetryConfig): + self._signals_api = signals_api + self._retry_config = retry_config + self._is_skd_supported = None + self._expire_at = -1 + + async def is_skd_supported(self) -> bool: + """ AgentVersionService stores cached version flag. + Caching interval is the same as in to session token caching. + Once cache is expired it calls agent info api to update version. + + :return: boolean flag if skd supported for agent + """ + if ( + self._is_skd_supported is not None + and self._expire_at + > datetime.now(timezone.utc).timestamp() + ): + return self._is_skd_supported + self._expire_at = generate_expiration_time() + self._is_skd_supported = await self._get_agent_skd_support() + return self._is_skd_supported + + + @retry + async def _get_agent_skd_support(self) -> bool: + try: + agent_info = await self._signals_api.v1_info_get() + if not agent_info or not agent_info.version: + return False + except ApiException: + return False + agent_major_version, agent_minor_version = self._parse_version(agent_info.version) + if not agent_major_version: + return False + if agent_major_version == MIN_MAJOR_VERSION: + return agent_minor_version >= MIN_MINOR_VERSION + return agent_major_version > MIN_MAJOR_VERSION + + @staticmethod + def _parse_version(version_string): + if not version_string: + return None, None + match = re.match(VERSION_REGEXP, version_string) + if match: + return int(match.group(1)), int(match.group(2)) + return None, None diff --git a/symphony/bdk/core/service_factory.py b/symphony/bdk/core/service_factory.py index 42445bee..f67cdc86 100644 --- a/symphony/bdk/core/service_factory.py +++ b/symphony/bdk/core/service_factory.py @@ -18,6 +18,7 @@ from symphony.bdk.core.service.signal.signal_service import SignalService, OboSignalService from symphony.bdk.core.service.stream.stream_service import StreamService, OboStreamService from symphony.bdk.core.service.user.user_service import UserService, OboUserService +from symphony.bdk.core.service.version.agent_version_service import AgentVersionService from symphony.bdk.gen.agent_api.attachments_api import AttachmentsApi from symphony.bdk.gen.agent_api.audit_trail_api import AuditTrailApi from symphony.bdk.gen.agent_api.datafeed_api import DatafeedApi @@ -203,6 +204,14 @@ def get_presence_service(self) -> PresenceService: self._config.retry ) + def get_agent_version_service(self) -> AgentVersionService: + """Returns a fully initialized AgentVersionService + + :return: a new AgentVersionService instance + """ + return AgentVersionService(SignalsApi(self._agent_client), + self._config.retry) + class OboServiceFactory: """Factory responsible for creating BDK service instances for OBO-enabled endpoints only: diff --git a/symphony/bdk/core/symphony_bdk.py b/symphony/bdk/core/symphony_bdk.py index 36fb0e3c..25f82675 100644 --- a/symphony/bdk/core/symphony_bdk.py +++ b/symphony/bdk/core/symphony_bdk.py @@ -105,7 +105,8 @@ def __init__(self, config): "You can however use services in OBO mode if app authentication is configured.") def _initialize_bot_services(self): - self._bot_session = AuthSession(self._authenticator_factory.get_bot_authenticator()) + bot_authenticator = self._authenticator_factory.get_bot_authenticator() + self._bot_session = AuthSession(bot_authenticator) self._service_factory = ServiceFactory(self._api_client_factory, self._bot_session, self._config) self._user_service = self._service_factory.get_user_service() self._message_service = self._service_factory.get_message_service() @@ -123,6 +124,7 @@ def _initialize_bot_services(self): self._datafeed_loop.subscribe(self._activity_registry) # initialises extension service and register decorated extensions self._extension_service = ExtensionService(self._api_client_factory, self._bot_session, self._config) + bot_authenticator.agent_version_service = self._service_factory.get_agent_version_service() @bot_service def bot_session(self) -> AuthSession: diff --git a/tests/conftest.py b/tests/conftest.py index 6164d687..e7ff4be5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ """ import datetime +from unittest.mock import patch import pytest @@ -12,13 +13,16 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509 import NameOID +from symphony.bdk.core.auth.auth_session import SKD_FLAG_NAME -@pytest.fixture(name="root_key", scope="session") # the fixture will be created only once for entire test session. + +@pytest.fixture( + name="root_key", scope="session" +) # the fixture will be created only once for entire test session. def fixture_root_key(): return rsa.generate_private_key( - public_exponent=65537, - key_size=4096, - backend=default_backend()) + public_exponent=65537, key_size=4096, backend=default_backend() + ) @pytest.fixture(name="rsa_key", scope="session") @@ -26,19 +30,33 @@ def fixture_rsa_key(root_key): return root_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption()).decode("utf-8") + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") @pytest.fixture(name="certificate", scope="session") def fixture_certificate(root_key): - subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"commonName")]) + subject = issuer = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "commonName")] + ) now = datetime.datetime.utcnow() - cert = x509.CertificateBuilder() \ - .subject_name(subject) \ - .issuer_name(issuer) \ - .public_key(root_key.public_key()) \ - .serial_number(x509.random_serial_number()) \ - .not_valid_before(now) \ - .not_valid_after(now + datetime.timedelta(days=30)) \ + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(root_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=30)) .sign(root_key, hashes.SHA512(), default_backend()) + ) return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") + + +@pytest.fixture(autouse=True) +def mock_jwt_decode_for_skd(): + with patch( + "symphony.bdk.core.auth.auth_session.extract_token_claims", + return_value={SKD_FLAG_NAME: False}, + ) as mock: + yield mock diff --git a/tests/core/auth/auth_session_test.py b/tests/core/auth/auth_session_test.py index 37f9c7c8..de6f6c47 100644 --- a/tests/core/auth/auth_session_test.py +++ b/tests/core/auth/auth_session_test.py @@ -1,13 +1,39 @@ from datetime import datetime, timezone -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from symphony.bdk.core.auth.auth_session import AuthSession, OboAuthSession, AppAuthSession +from symphony.bdk.core.auth.auth_session import ( + AuthSession, + OboAuthSession, + AppAuthSession, + SKD_FLAG_NAME, +) from symphony.bdk.core.auth.exception import AuthInitializationError from symphony.bdk.gen.login_model.token import Token from symphony.bdk.gen.login_model.extension_app_tokens import ExtensionAppTokens +from symphony.bdk.core.auth.bot_authenticator import BotAuthenticatorRsa +from symphony.bdk.core.config.model.bdk_bot_config import BdkBotConfig +from symphony.bdk.gen.api_client import ApiClient + +@pytest.fixture +def mock_authenticator(): + + config = MagicMock(spec=BdkBotConfig) + login_client = MagicMock(spec=ApiClient) + relay_client = MagicMock(spec=ApiClient) + retry_config = MagicMock() + authenticator = BotAuthenticatorRsa(config, login_client, relay_client, retry_config) + authenticator.retrieve_session_token = AsyncMock(return_value="session_token_string") + authenticator.retrieve_key_manager_token = AsyncMock(return_value="km_token_string") + authenticator.agent_version_service = AsyncMock() + return authenticator + + +@pytest.fixture +def auth_session(mock_authenticator): + return AuthSession(mock_authenticator) @pytest.mark.asyncio @@ -57,8 +83,14 @@ async def test_auth_token(): @pytest.mark.asyncio async def test_refresh_obo(): mock_obo_authenticator = AsyncMock() - mock_obo_authenticator.retrieve_obo_session_token_by_user_id.side_effect = ["session_token1", "session_token2"] - mock_obo_authenticator.retrieve_obo_session_token_by_username.side_effect = ["session_token3", "session_token4"] + mock_obo_authenticator.retrieve_obo_session_token_by_user_id.side_effect = [ + "session_token1", + "session_token2", + ] + mock_obo_authenticator.retrieve_obo_session_token_by_username.side_effect = [ + "session_token3", + "session_token4", + ] obo_session = OboAuthSession(mock_obo_authenticator, user_id=1234) @@ -95,14 +127,96 @@ async def test_app_auth_session(): expire_at = 1539636528288 ext_app_authenticator = AsyncMock() - ext_app_authenticator.authenticate_and_retrieve_tokens.return_value = \ - ExtensionAppTokens(app_id="app_id", app_token=retrieved_app_token, symphony_token=symphony_token, - expire_at=expire_at) + ext_app_authenticator.authenticate_and_retrieve_tokens.return_value = ( + ExtensionAppTokens( + app_id="app_id", + app_token=retrieved_app_token, + symphony_token=symphony_token, + expire_at=expire_at, + ) + ) session = AppAuthSession(ext_app_authenticator, input_app_token) await session.refresh() - ext_app_authenticator.authenticate_and_retrieve_tokens.assert_called_once_with(input_app_token) + ext_app_authenticator.authenticate_and_retrieve_tokens.assert_called_once_with( + input_app_token + ) assert session.app_token == retrieved_app_token assert session.symphony_token == symphony_token assert session.expire_at == expire_at + + +@pytest.mark.asyncio +async def test_skd_disabled_if_claim_is_missing(auth_session): + # Given: The token claims do not contain the SKD flag + with patch( + "symphony.bdk.core.auth.auth_session.extract_token_claims", return_value={} + ): + # When: skd_enabled is checked + is_enabled = await auth_session.skd_enabled + # Then: The result is False + assert is_enabled is False + + +@pytest.mark.asyncio +async def test_skd_disabled_if_agent_not_supported(auth_session, mock_authenticator): + # Given: The token has the SKD flag but the agent does not support it + mock_authenticator.agent_version_service.is_skd_supported.return_value = False + claims_with_skd = {SKD_FLAG_NAME: True} + with patch( + "symphony.bdk.core.auth.auth_session.extract_token_claims", + return_value=claims_with_skd, + ): + # When: skd_enabled is checked + is_enabled = await auth_session.skd_enabled + # Then: The result is False and the agent version was checked + assert is_enabled is False + mock_authenticator.agent_version_service.is_skd_supported.assert_called_once() + + +@pytest.mark.asyncio +async def test_skd_enabled_when_fully_supported(auth_session, mock_authenticator): + # Given: The token has the SKD flag AND the agent supports it + mock_authenticator.agent_version_service.is_skd_supported.return_value = True + claims_with_skd = {SKD_FLAG_NAME: True} + with patch( + "symphony.bdk.core.auth.auth_session.extract_token_claims", + return_value=claims_with_skd, + ): + # When: skd_enabled is checked + is_enabled = await auth_session.skd_enabled + # Then: The result is True and the agent version was checked + assert is_enabled is True + mock_authenticator.agent_version_service.is_skd_supported.assert_called_once() + + +@pytest.mark.asyncio +async def test_km_token_is_empty_when_skd_enabled(auth_session, mock_authenticator): + # Given: SKD is fully enabled + mock_authenticator.agent_version_service.is_skd_supported.return_value = True + claims_with_skd = {SKD_FLAG_NAME: True} + with patch( + "symphony.bdk.core.auth.auth_session.extract_token_claims", + return_value=claims_with_skd, + ): + # When: The key manager token is requested + km_token = await auth_session.key_manager_token + # Then: The token is an empty string and the real retrieval method was NOT called + assert km_token == "" + mock_authenticator.retrieve_key_manager_token.assert_not_called() + + +@pytest.mark.asyncio +async def test_km_token_is_retrieved_when_skd_disabled( + auth_session, mock_authenticator +): + # Given: SKD is disabled because the token claim is missing + with patch( + "symphony.bdk.core.auth.auth_session.extract_token_claims", return_value={} + ): + # When: The key manager token is requested + km_token = await auth_session.key_manager_token + # Then: The real token is returned and the retrieval method was called + assert km_token == "km_token_string" + mock_authenticator.retrieve_key_manager_token.assert_called_once() diff --git a/tests/core/auth/jwt_helper_test.py b/tests/core/auth/jwt_helper_test.py index 8aaa92f7..8b8f2685 100644 --- a/tests/core/auth/jwt_helper_test.py +++ b/tests/core/auth/jwt_helper_test.py @@ -5,7 +5,12 @@ from jwt import InvalidAudienceError from symphony.bdk.core.auth.exception import AuthInitializationError -from symphony.bdk.core.auth.jwt_helper import create_signed_jwt, validate_jwt, create_signed_jwt_with_claims +from symphony.bdk.core.auth.jwt_helper import ( + create_signed_jwt, + validate_jwt, + create_signed_jwt_with_claims, + extract_token_claims, +) from symphony.bdk.core.config.model.bdk_rsa_key_config import BdkRsaKeyConfig AUDIENCE = "app-id" @@ -21,7 +26,7 @@ def fixture_jwt_payload(): return { "sub": "bot-user", "exp": (datetime.datetime.now(datetime.timezone.utc).timestamp() + (5 * 60)), - "aud": AUDIENCE + "aud": AUDIENCE, } @@ -31,7 +36,7 @@ def test_create_signed_jwt_from_path(key_config, rsa_key): mock_open = mock.mock_open(read_data=rsa_key) - with mock.patch('builtins.open', mock_open): + with mock.patch("builtins.open", mock_open): assert create_signed_jwt(key_config, "test_bot") is not None mock_open.assert_called_with(private_key_path, "r") @@ -49,16 +54,14 @@ def test_validate_jwt(jwt_payload, certificate, rsa_key): assert claims == jwt_payload - def test_validate_expired_jwt(jwt_payload, certificate, rsa_key): - jwt_payload["exp"] = (datetime.datetime.now(datetime.timezone.utc).timestamp() - 10) + jwt_payload["exp"] = datetime.datetime.now(datetime.timezone.utc).timestamp() - 10 signed_jwt = create_signed_jwt_with_claims(rsa_key, jwt_payload) with pytest.raises(AuthInitializationError): validate_jwt(signed_jwt, certificate, AUDIENCE) - def test_validate_jwt_with_empty_sub(jwt_payload, certificate, rsa_key): jwt_payload["sub"] = None signed_jwt = create_signed_jwt_with_claims(rsa_key, jwt_payload) @@ -84,3 +87,33 @@ def test_validate_jwt_with_invalid_cert(jwt_payload, rsa_key): def test_validate_jwt_with_invalid_jwt(certificate): with pytest.raises(AuthInitializationError): validate_jwt("invalid jwt", certificate, AUDIENCE) + + +def test_extract_claims_from_valid_token(rsa_key): + # Given: A valid jwt token + payload = {"sub": "test-bot", "skd": True, "userId": 12345} + token = create_signed_jwt_with_claims(rsa_key, payload) + # When: extract JWT claims is called with a valid token + claims = extract_token_claims(token) + # Then: fields are extracted as expected + assert claims["sub"] == "test-bot" + assert claims["skd"] is True + assert claims["userId"] == 12345 + + +@pytest.mark.parametrize( + "invalid_token", + [ + # Given: invalid JWT to be extracted + "not-a-jwt", + "a.b.c", # malformed JWT + "a.b", # not enough segments + "", # empty string + None, # None value + ], +) +def test_extract_claims_from_invalid_token(invalid_token): + # When: extract JWT claims is called + claims = extract_token_claims(invalid_token) + # Then: empty response is returned + assert claims == {} diff --git a/tests/core/service/version/__init__.py b/tests/core/service/version/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/service/version/agent_version_test.py b/tests/core/service/version/agent_version_test.py new file mode 100644 index 00000000..ceed883d --- /dev/null +++ b/tests/core/service/version/agent_version_test.py @@ -0,0 +1,92 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from symphony.bdk.core.service.version.agent_version_service import AgentVersionService +from symphony.bdk.gen.agent_model.agent_info import AgentInfo +from symphony.bdk.gen.exceptions import ApiException + + +@pytest.fixture +def signals_api(): + return AsyncMock() + + +@pytest.fixture +def retry_config(): + return MagicMock() + + +@pytest.fixture +def service(signals_api, retry_config): + return AgentVersionService(signals_api, retry_config) + + +@pytest.mark.parametrize( + "version_string, expected_major, expected_minor", + [ + # Given: Agent version string + ("Agent-24.12.0", 24, 12), + ("Agent-25.0.0-SNAPSHOT", 25, 0), + ("Agent-100.1.23", 100, 1), + ("NotAnAgent-1.0.0", None, None), + ("Agent-24", None, None), + ("Agent-24.12", None, None), + ("some random string", None, None), + (None, None, None), + ], +) +def test_parse_version(version_string, expected_major, expected_minor): + # When: agent parser is called + major, minor = AgentVersionService._parse_version(version_string) + # Then: major and minor version are extracted correctly + assert major == expected_major + assert minor == expected_minor + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "agent_version, expected_result", + [ + ("Agent-24.12.0", True), # Exact minimum version + ("Agent-24.13.0", True), # Minor version greater + ("Agent-25.0.0", True), # Major version greater + ("Agent-24.11.0", False), # Minor version smaller + ("Agent-23.15.0", False), # Major version smaller + ("Malformed-Version-String", False), # Malformed string + ], +) +async def test_is_skd_supported_versions( + service, signals_api, agent_version, expected_result +): + """Tests the SKD support check against various agent version strings.""" + # Given: Agent version string is returned from info API + signals_api.v1_info_get.return_value = AgentInfo(version=agent_version) + # When: is_skd_supported is called + is_supported = await service.is_skd_supported() + # Then: the expected boolean result is returned + assert is_supported is expected_result + signals_api.v1_info_get.assert_called_once() + + +@pytest.mark.asyncio +async def test_is_skd_supported_api_exception(service, signals_api): + """Tests that SKD support is False when the agent API call fails.""" + # Given: The call to the agent info API will raise an exception + signals_api.v1_info_get.side_effect = ApiException(reason="Agent unavailable") + # When: is_skd_supported is called + is_supported = await service.is_skd_supported() + # Then: False is returned and exception is handled + assert is_supported is False + signals_api.v1_info_get.assert_called_once() + + +@pytest.mark.asyncio +async def test_is_skd_supported_no_version_in_response(service, signals_api): + """Tests that SKD support is False when the agent info response is missing a version.""" + # Given: The agent info API returns a response with a null version + signals_api.v1_info_get.return_value = AgentInfo(version=None) + # When: is_skd_supported is called + is_supported = await service.is_skd_supported() + # Then: False is returned + assert is_supported is False