Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion symphony/bdk/core/auth/auth_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import logging

from symphony.bdk.core.auth.exception import AuthInitializationError
from symphony.bdk.core.auth.jwt_helper import extract_token_claims


logger = logging.getLogger(__name__)

EXPIRATION_SAFETY_BUFFER_SECONDS = 5
SKD_FLAG_NAME = "canUseSimplifiedKeyDelivery"


class AuthSession:
Expand All @@ -33,7 +36,10 @@ async def refresh(self):
"""
logger.debug("Authenticate")
self._session_token = await self._authenticator.retrieve_session_token()
self._key_manager_token = await self._authenticator.retrieve_key_manager_token()
if await self.skd_enabled:
self._key_manager_token = ""
return
self.key_manager_token = await self._authenticator.retrieve_key_manager_token()

@property
async def session_token(self):
Expand Down Expand Up @@ -71,6 +77,10 @@ async def key_manager_token(self):

:return: the key manager token
"""

if await self.skd_enabled:
return ""

if self._key_manager_token is None:
self._key_manager_token = await self._authenticator.retrieve_key_manager_token()
return self._key_manager_token
Expand All @@ -91,6 +101,15 @@ def key_manager_token(self, value):
"""
self._key_manager_token = value

@property
async def skd_enabled(self):

token_data = extract_token_claims(await self.session_token)
if not token_data.get(SKD_FLAG_NAME, False):
return False
return await self._authenticator.agent_version_service.is_skd_supported()



class OboAuthSession(AuthSession):
"""RSA OBO Authentication session handle to get the OBO session token from.
Expand Down
11 changes: 11 additions & 0 deletions symphony/bdk/core/auth/bot_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from symphony.bdk.core.config.model.bdk_retry_config import BdkRetryConfig
from symphony.bdk.core.retry import retry
from symphony.bdk.core.retry.strategy import authentication_retry
from symphony.bdk.core.service.version.agent_version_service import AgentVersionService
from symphony.bdk.gen.api_client import ApiClient
from symphony.bdk.gen.auth_api.certificate_authentication_api import CertificateAuthenticationApi
from symphony.bdk.gen.login_api.authentication_api import AuthenticationApi
Expand All @@ -24,6 +25,7 @@ def __init__(self, session_auth_client: ApiClient, key_manager_auth_client: ApiC
self._session_auth_client = session_auth_client
self._key_manager_auth_client = key_manager_auth_client
self._retry_config = retry_config
self._agent_version_service = None

async def retrieve_session_token(self) -> str:
"""Authenticates and retrieves a new session token.
Expand Down Expand Up @@ -59,6 +61,15 @@ async def _authenticate_and_get_token(self, api_client: ApiClient) -> str:
:return: the token as a string
"""

@property
def agent_version_service(self) -> Optional[AgentVersionService]:
return self._agent_version_service

@agent_version_service.setter
def agent_version_service(self, agent_version_service: AgentVersionService):
self._agent_version_service = agent_version_service



class BotAuthenticatorRsa(BotAuthenticator):
"""Bot authenticator RSA implementation.
Expand Down
10 changes: 10 additions & 0 deletions symphony/bdk/core/auth/jwt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,13 @@ def _parse_public_key_from_x509_cert(certificate: str) -> str:
return public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode()
except ValueError as exc:
raise AuthInitializationError("Unable to parse the certificate. Check certificate format.") from exc


def extract_token_claims(session_token):
try:
return jwt.decode(session_token,
algorithms=[JWT_ENCRYPTION_ALGORITHM],
options={"verify_signature": False}
)
except DecodeError:
return {}
Empty file.
67 changes: 67 additions & 0 deletions symphony/bdk/core/service/version/agent_version_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import re
from datetime import datetime, timezone

from symphony.bdk.core.auth.auth_session import AuthSession
from symphony.bdk.core.auth.jwt_helper import generate_expiration_time
from symphony.bdk.core.config.model.bdk_retry_config import BdkRetryConfig
from symphony.bdk.core.retry import retry
from symphony.bdk.gen.agent_api.signals_api import SignalsApi
from symphony.bdk.gen.exceptions import ApiException

MIN_MAJOR_VERSION = 24
MIN_MINOR_VERSION = 12
VERSION_REGEXP = r"Agent-(\d+)\.(\d+)\..*"


class AgentVersionService:
"""Service class has one purpose only. It checks if version of agents supports simplified key delivery mechanism

"""

def __init__(self, signals_api: SignalsApi, retry_config: BdkRetryConfig):
self._signals_api = signals_api
self._retry_config = retry_config
self._is_skd_supported = None
self._expire_at = -1

async def is_skd_supported(self) -> bool:
""" AgentVersionService stores cached version flag.
Caching interval is the same as in to session token caching.
Once cache is expired it calls agent info api to update version.

:return: boolean flag if skd supported for agent
"""
if (
self._is_skd_supported is not None
and self._expire_at
> datetime.now(timezone.utc).timestamp()
):
return self._is_skd_supported
self._expire_at = generate_expiration_time()
self._is_skd_supported = await self._get_agent_skd_support()
return self._is_skd_supported


@retry
async def _get_agent_skd_support(self) -> bool:
try:
agent_info = await self._signals_api.v1_info_get()
if not agent_info or not agent_info.version:
return False
except ApiException:
return False
agent_major_version, agent_minor_version = self._parse_version(agent_info.version)
if not agent_major_version:
return False
if agent_major_version == MIN_MAJOR_VERSION:
return agent_minor_version >= MIN_MINOR_VERSION
return agent_major_version > MIN_MAJOR_VERSION

@staticmethod
def _parse_version(version_string):
if not version_string:
return None, None
match = re.match(VERSION_REGEXP, version_string)
if match:
return int(match.group(1)), int(match.group(2))
return None, None
9 changes: 9 additions & 0 deletions symphony/bdk/core/service_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from symphony.bdk.core.service.signal.signal_service import SignalService, OboSignalService
from symphony.bdk.core.service.stream.stream_service import StreamService, OboStreamService
from symphony.bdk.core.service.user.user_service import UserService, OboUserService
from symphony.bdk.core.service.version.agent_version_service import AgentVersionService
from symphony.bdk.gen.agent_api.attachments_api import AttachmentsApi
from symphony.bdk.gen.agent_api.audit_trail_api import AuditTrailApi
from symphony.bdk.gen.agent_api.datafeed_api import DatafeedApi
Expand Down Expand Up @@ -203,6 +204,14 @@ def get_presence_service(self) -> PresenceService:
self._config.retry
)

def get_agent_version_service(self) -> AgentVersionService:
"""Returns a fully initialized AgentVersionService

:return: a new AgentVersionService instance
"""
return AgentVersionService(SignalsApi(self._agent_client),
self._config.retry)


class OboServiceFactory:
"""Factory responsible for creating BDK service instances for OBO-enabled endpoints only:
Expand Down
4 changes: 3 additions & 1 deletion symphony/bdk/core/symphony_bdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def __init__(self, config):
"You can however use services in OBO mode if app authentication is configured.")

def _initialize_bot_services(self):
self._bot_session = AuthSession(self._authenticator_factory.get_bot_authenticator())
bot_authenticator = self._authenticator_factory.get_bot_authenticator()
self._bot_session = AuthSession(bot_authenticator)
self._service_factory = ServiceFactory(self._api_client_factory, self._bot_session, self._config)
self._user_service = self._service_factory.get_user_service()
self._message_service = self._service_factory.get_message_service()
Expand All @@ -123,6 +124,7 @@ def _initialize_bot_services(self):
self._datafeed_loop.subscribe(self._activity_registry)
# initialises extension service and register decorated extensions
self._extension_service = ExtensionService(self._api_client_factory, self._bot_session, self._config)
bot_authenticator.agent_version_service = self._service_factory.get_agent_version_service()

@bot_service
def bot_session(self) -> AuthSession:
Expand Down
44 changes: 31 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import datetime
from unittest.mock import patch

import pytest

Expand All @@ -12,33 +13,50 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509 import NameOID

from symphony.bdk.core.auth.auth_session import SKD_FLAG_NAME

@pytest.fixture(name="root_key", scope="session") # the fixture will be created only once for entire test session.

@pytest.fixture(
name="root_key", scope="session"
) # the fixture will be created only once for entire test session.
def fixture_root_key():
return rsa.generate_private_key(
public_exponent=65537,
key_size=4096,
backend=default_backend())
public_exponent=65537, key_size=4096, backend=default_backend()
)


@pytest.fixture(name="rsa_key", scope="session")
def fixture_rsa_key(root_key):
return root_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()).decode("utf-8")
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")


@pytest.fixture(name="certificate", scope="session")
def fixture_certificate(root_key):
subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"commonName")])
subject = issuer = x509.Name(
[x509.NameAttribute(NameOID.COMMON_NAME, "commonName")]
)
now = datetime.datetime.utcnow()
cert = x509.CertificateBuilder() \
.subject_name(subject) \
.issuer_name(issuer) \
.public_key(root_key.public_key()) \
.serial_number(x509.random_serial_number()) \
.not_valid_before(now) \
.not_valid_after(now + datetime.timedelta(days=30)) \
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(root_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=30))
.sign(root_key, hashes.SHA512(), default_backend())
)
return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8")


@pytest.fixture(autouse=True)
def mock_jwt_decode_for_skd():
with patch(
"symphony.bdk.core.auth.auth_session.extract_token_claims",
return_value={SKD_FLAG_NAME: False},
) as mock:
yield mock
Loading
Loading