Skip to content

Commit bfa6ca6

Browse files
committed
CAIP-79 exteneded tests
1 parent aef8a16 commit bfa6ca6

7 files changed

Lines changed: 286 additions & 38 deletions

File tree

symphony/bdk/core/auth/auth_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66

77
from symphony.bdk.core.auth.exception import AuthInitializationError
8-
from symphony.bdk.core.auth.jwt_helper import extract_session_token_claims
8+
from symphony.bdk.core.auth.jwt_helper import extract_token_claims
99

1010

1111
logger = logging.getLogger(__name__)
@@ -101,7 +101,7 @@ def key_manager_token(self, value):
101101
@property
102102
async def skd_enabled(self):
103103

104-
token_data = extract_session_token_claims(await self.session_token)
104+
token_data = extract_token_claims(await self.session_token)
105105
if not token_data.get(SKD_FLAG_NAME, False):
106106
return False
107107
return await self._authenticator.agent_version_service.is_skd_supported()

symphony/bdk/core/auth/jwt_helper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ def _parse_public_key_from_x509_cert(certificate: str) -> str:
8585
raise AuthInitializationError("Unable to parse the certificate. Check certificate format.") from exc
8686

8787

88-
def extract_session_token_claims(session_token):
89-
return jwt.decode(session_token,
90-
algorithms=["RS512"],
88+
def extract_token_claims(session_token):
89+
try:
90+
return jwt.decode(session_token,
91+
algorithms=[JWT_ENCRYPTION_ALGORITHM],
9192
options={"verify_signature": False}
92-
)
93+
)
94+
except DecodeError:
95+
return {}

symphony/bdk/core/service/version/agent_version_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ async def is_skd_supported(self) -> bool:
2727
"""
2828
try:
2929
agent_info = await self._signals_api.v1_info_get()
30+
if not agent_info or not agent_info.version:
31+
return False
3032
except ApiException:
3133
return False
3234
agent_major_version, agent_minor_version = self._parse_version(agent_info.version)
@@ -38,8 +40,9 @@ async def is_skd_supported(self) -> bool:
3840

3941
@staticmethod
4042
def _parse_version(version_string):
41-
match = re.match(r"Agent-(\d+)\.(\d+)\..*", version_string)
43+
if not version_string:
44+
return None, None
45+
match = re.match(VERSION_REGEXP, version_string)
4246
if match:
4347
return int(match.group(1)), int(match.group(2))
44-
4548
return None, None

tests/conftest.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,47 @@
1616
from symphony.bdk.core.auth.auth_session import SKD_FLAG_NAME
1717

1818

19-
@pytest.fixture(name="root_key", scope="session") # the fixture will be created only once for entire test session.
19+
@pytest.fixture(
20+
name="root_key", scope="session"
21+
) # the fixture will be created only once for entire test session.
2022
def fixture_root_key():
2123
return rsa.generate_private_key(
22-
public_exponent=65537,
23-
key_size=4096,
24-
backend=default_backend())
24+
public_exponent=65537, key_size=4096, backend=default_backend()
25+
)
2526

2627

2728
@pytest.fixture(name="rsa_key", scope="session")
2829
def fixture_rsa_key(root_key):
2930
return root_key.private_bytes(
3031
encoding=serialization.Encoding.PEM,
3132
format=serialization.PrivateFormat.PKCS8,
32-
encryption_algorithm=serialization.NoEncryption()).decode("utf-8")
33+
encryption_algorithm=serialization.NoEncryption(),
34+
).decode("utf-8")
3335

3436

3537
@pytest.fixture(name="certificate", scope="session")
3638
def fixture_certificate(root_key):
37-
subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"commonName")])
39+
subject = issuer = x509.Name(
40+
[x509.NameAttribute(NameOID.COMMON_NAME, "commonName")]
41+
)
3842
now = datetime.datetime.utcnow()
39-
cert = x509.CertificateBuilder() \
40-
.subject_name(subject) \
41-
.issuer_name(issuer) \
42-
.public_key(root_key.public_key()) \
43-
.serial_number(x509.random_serial_number()) \
44-
.not_valid_before(now) \
45-
.not_valid_after(now + datetime.timedelta(days=30)) \
43+
cert = (
44+
x509.CertificateBuilder()
45+
.subject_name(subject)
46+
.issuer_name(issuer)
47+
.public_key(root_key.public_key())
48+
.serial_number(x509.random_serial_number())
49+
.not_valid_before(now)
50+
.not_valid_after(now + datetime.timedelta(days=30))
4651
.sign(root_key, hashes.SHA512(), default_backend())
52+
)
4753
return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8")
4854

4955

5056
@pytest.fixture(autouse=True)
5157
def mock_jwt_decode_for_skd():
52-
with patch("symphony.bdk.core.auth.auth_session.extract_session_token_claims",
53-
return_value={SKD_FLAG_NAME: False}) as mock:
54-
yield mock
58+
with patch(
59+
"symphony.bdk.core.auth.auth_session.extract_token_claims",
60+
return_value={SKD_FLAG_NAME: False},
61+
) as mock:
62+
yield mock

tests/core/auth/auth_session_test.py

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,36 @@
11
from datetime import datetime, timezone
22

3-
from unittest.mock import AsyncMock
3+
from unittest.mock import AsyncMock, MagicMock, patch
44

55
import pytest
66

7-
from symphony.bdk.core.auth.auth_session import AuthSession, OboAuthSession, AppAuthSession
7+
from symphony.bdk.core.auth.auth_session import (
8+
AuthSession,
9+
OboAuthSession,
10+
AppAuthSession,
11+
SKD_FLAG_NAME,
12+
)
813
from symphony.bdk.core.auth.exception import AuthInitializationError
914
from symphony.bdk.gen.login_model.token import Token
1015
from symphony.bdk.gen.login_model.extension_app_tokens import ExtensionAppTokens
1116

1217

18+
@pytest.fixture
19+
def mock_authenticator():
20+
authenticator = MagicMock()
21+
authenticator.agent_version_service = AsyncMock()
22+
authenticator.retrieve_session_token = AsyncMock(
23+
return_value="session_token_string"
24+
)
25+
authenticator.retrieve_key_manager_token = AsyncMock(return_value="km_token_string")
26+
return authenticator
27+
28+
29+
@pytest.fixture
30+
def auth_session(mock_authenticator):
31+
return AuthSession(mock_authenticator)
32+
33+
1334
@pytest.mark.asyncio
1435
async def test_refresh():
1536
mock_bot_authenticator = AsyncMock()
@@ -57,8 +78,14 @@ async def test_auth_token():
5778
@pytest.mark.asyncio
5879
async def test_refresh_obo():
5980
mock_obo_authenticator = AsyncMock()
60-
mock_obo_authenticator.retrieve_obo_session_token_by_user_id.side_effect = ["session_token1", "session_token2"]
61-
mock_obo_authenticator.retrieve_obo_session_token_by_username.side_effect = ["session_token3", "session_token4"]
81+
mock_obo_authenticator.retrieve_obo_session_token_by_user_id.side_effect = [
82+
"session_token1",
83+
"session_token2",
84+
]
85+
mock_obo_authenticator.retrieve_obo_session_token_by_username.side_effect = [
86+
"session_token3",
87+
"session_token4",
88+
]
6289

6390
obo_session = OboAuthSession(mock_obo_authenticator, user_id=1234)
6491

@@ -95,14 +122,96 @@ async def test_app_auth_session():
95122
expire_at = 1539636528288
96123

97124
ext_app_authenticator = AsyncMock()
98-
ext_app_authenticator.authenticate_and_retrieve_tokens.return_value = \
99-
ExtensionAppTokens(app_id="app_id", app_token=retrieved_app_token, symphony_token=symphony_token,
100-
expire_at=expire_at)
125+
ext_app_authenticator.authenticate_and_retrieve_tokens.return_value = (
126+
ExtensionAppTokens(
127+
app_id="app_id",
128+
app_token=retrieved_app_token,
129+
symphony_token=symphony_token,
130+
expire_at=expire_at,
131+
)
132+
)
101133

102134
session = AppAuthSession(ext_app_authenticator, input_app_token)
103135
await session.refresh()
104136

105-
ext_app_authenticator.authenticate_and_retrieve_tokens.assert_called_once_with(input_app_token)
137+
ext_app_authenticator.authenticate_and_retrieve_tokens.assert_called_once_with(
138+
input_app_token
139+
)
106140
assert session.app_token == retrieved_app_token
107141
assert session.symphony_token == symphony_token
108142
assert session.expire_at == expire_at
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_skd_disabled_if_claim_is_missing(auth_session):
147+
# Given: The token claims do not contain the SKD flag
148+
with patch(
149+
"symphony.bdk.core.auth.auth_session.extract_token_claims", return_value={}
150+
):
151+
# When: skd_enabled is checked
152+
is_enabled = await auth_session.skd_enabled
153+
# Then: The result is False
154+
assert is_enabled is False
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_skd_disabled_if_agent_not_supported(auth_session, mock_authenticator):
159+
# Given: The token has the SKD flag but the agent does not support it
160+
mock_authenticator.agent_version_service.is_skd_supported.return_value = False
161+
claims_with_skd = {SKD_FLAG_NAME: True}
162+
with patch(
163+
"symphony.bdk.core.auth.auth_session.extract_token_claims",
164+
return_value=claims_with_skd,
165+
):
166+
# When: skd_enabled is checked
167+
is_enabled = await auth_session.skd_enabled
168+
# Then: The result is False and the agent version was checked
169+
assert is_enabled is False
170+
mock_authenticator.agent_version_service.is_skd_supported.assert_called_once()
171+
172+
173+
@pytest.mark.asyncio
174+
async def test_skd_enabled_when_fully_supported(auth_session, mock_authenticator):
175+
# Given: The token has the SKD flag AND the agent supports it
176+
mock_authenticator.agent_version_service.is_skd_supported.return_value = True
177+
claims_with_skd = {SKD_FLAG_NAME: True}
178+
with patch(
179+
"symphony.bdk.core.auth.auth_session.extract_token_claims",
180+
return_value=claims_with_skd,
181+
):
182+
# When: skd_enabled is checked
183+
is_enabled = await auth_session.skd_enabled
184+
# Then: The result is True and the agent version was checked
185+
assert is_enabled is True
186+
mock_authenticator.agent_version_service.is_skd_supported.assert_called_once()
187+
188+
189+
@pytest.mark.asyncio
190+
async def test_km_token_is_empty_when_skd_enabled(auth_session, mock_authenticator):
191+
# Given: SKD is fully enabled
192+
mock_authenticator.agent_version_service.is_skd_supported.return_value = True
193+
claims_with_skd = {SKD_FLAG_NAME: True}
194+
with patch(
195+
"symphony.bdk.core.auth.auth_session.extract_token_claims",
196+
return_value=claims_with_skd,
197+
):
198+
# When: The key manager token is requested
199+
km_token = await auth_session.key_manager_token
200+
# Then: The token is an empty string and the real retrieval method was NOT called
201+
assert km_token == ""
202+
mock_authenticator.retrieve_key_manager_token.assert_not_called()
203+
204+
205+
@pytest.mark.asyncio
206+
async def test_km_token_is_retrieved_when_skd_disabled(
207+
auth_session, mock_authenticator
208+
):
209+
# Given: SKD is disabled because the token claim is missing
210+
with patch(
211+
"symphony.bdk.core.auth.auth_session.extract_token_claims", return_value={}
212+
):
213+
# When: The key manager token is requested
214+
km_token = await auth_session.key_manager_token
215+
# Then: The real token is returned and the retrieval method was called
216+
assert km_token == "km_token_string"
217+
mock_authenticator.retrieve_key_manager_token.assert_called_once()

tests/core/auth/jwt_helper_test.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
from jwt import InvalidAudienceError
66

77
from symphony.bdk.core.auth.exception import AuthInitializationError
8-
from symphony.bdk.core.auth.jwt_helper import create_signed_jwt, validate_jwt, create_signed_jwt_with_claims
8+
from symphony.bdk.core.auth.jwt_helper import (
9+
create_signed_jwt,
10+
validate_jwt,
11+
create_signed_jwt_with_claims,
12+
extract_token_claims,
13+
)
914
from symphony.bdk.core.config.model.bdk_rsa_key_config import BdkRsaKeyConfig
1015

1116
AUDIENCE = "app-id"
@@ -21,7 +26,7 @@ def fixture_jwt_payload():
2126
return {
2227
"sub": "bot-user",
2328
"exp": (datetime.datetime.now(datetime.timezone.utc).timestamp() + (5 * 60)),
24-
"aud": AUDIENCE
29+
"aud": AUDIENCE,
2530
}
2631

2732

@@ -31,7 +36,7 @@ def test_create_signed_jwt_from_path(key_config, rsa_key):
3136

3237
mock_open = mock.mock_open(read_data=rsa_key)
3338

34-
with mock.patch('builtins.open', mock_open):
39+
with mock.patch("builtins.open", mock_open):
3540
assert create_signed_jwt(key_config, "test_bot") is not None
3641
mock_open.assert_called_with(private_key_path, "r")
3742

@@ -49,16 +54,14 @@ def test_validate_jwt(jwt_payload, certificate, rsa_key):
4954
assert claims == jwt_payload
5055

5156

52-
5357
def test_validate_expired_jwt(jwt_payload, certificate, rsa_key):
54-
jwt_payload["exp"] = (datetime.datetime.now(datetime.timezone.utc).timestamp() - 10)
58+
jwt_payload["exp"] = datetime.datetime.now(datetime.timezone.utc).timestamp() - 10
5559
signed_jwt = create_signed_jwt_with_claims(rsa_key, jwt_payload)
5660

5761
with pytest.raises(AuthInitializationError):
5862
validate_jwt(signed_jwt, certificate, AUDIENCE)
5963

6064

61-
6265
def test_validate_jwt_with_empty_sub(jwt_payload, certificate, rsa_key):
6366
jwt_payload["sub"] = None
6467
signed_jwt = create_signed_jwt_with_claims(rsa_key, jwt_payload)
@@ -84,3 +87,33 @@ def test_validate_jwt_with_invalid_cert(jwt_payload, rsa_key):
8487
def test_validate_jwt_with_invalid_jwt(certificate):
8588
with pytest.raises(AuthInitializationError):
8689
validate_jwt("invalid jwt", certificate, AUDIENCE)
90+
91+
92+
def test_extract_claims_from_valid_token(rsa_key):
93+
# Given: A valid jwt token
94+
payload = {"sub": "test-bot", "skd": True, "userId": 12345}
95+
token = create_signed_jwt_with_claims(rsa_key, payload)
96+
# When: extract JWT claims is called with a valid token
97+
claims = extract_token_claims(token)
98+
# Then: fields are extracted as expected
99+
assert claims["sub"] == "test-bot"
100+
assert claims["skd"] is True
101+
assert claims["userId"] == 12345
102+
103+
104+
@pytest.mark.parametrize(
105+
"invalid_token",
106+
[
107+
# Given: invalid JWT to be extracted
108+
"not-a-jwt",
109+
"a.b.c", # malformed JWT
110+
"a.b", # not enough segments
111+
"", # empty string
112+
None, # None value
113+
],
114+
)
115+
def test_extract_claims_from_invalid_token(invalid_token):
116+
# When: extract JWT claims is called
117+
claims = extract_token_claims(invalid_token)
118+
# Then: empty response is returned
119+
assert claims == {}

0 commit comments

Comments
 (0)