Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions openedx/core/lib/tests/test_jwt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""
Tests for token handling
"""
import datetime
import unittest
from time import time

import pytest
from freezegun import freeze_time
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError, MissingRequiredClaimError

from openedx.core.djangolib.testing.utils import skip_unless_lms
Expand All @@ -13,6 +16,7 @@
invalid_test_user_id = 120
test_timeout = 1000
test_now = int(time())
time_snapshot = datetime.datetime.fromtimestamp(test_now, tz=datetime.UTC)
test_claims = {"foo": "bar", "baz": "quux", "meaning": 42}
expected_full_token = {
"lms_user_id": test_user_id,
Expand All @@ -24,6 +28,7 @@


@skip_unless_lms
@freeze_time(time_snapshot)
class TestSign(unittest.TestCase):
"""
Tests for JWT creation and signing.
Expand All @@ -33,7 +38,7 @@ def test_create_jwt(self):
token = create_jwt(test_user_id, test_timeout, {}, test_now)

decoded = unpack_and_verify(token)
self.assertEqual(expected_full_token, decoded) # noqa: PT009
assert decoded == expected_full_token

def test_create_jwt_with_claims(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
Expand All @@ -42,20 +47,18 @@ def test_create_jwt_with_claims(self):
expected_token_with_claims.update(test_claims)

decoded = unpack_and_verify(token)
self.assertEqual(expected_token_with_claims, decoded) # noqa: PT009
assert decoded == expected_token_with_claims

def test_malformed_token(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
token = token + "a"

expected_token_with_claims = expected_full_token.copy()
expected_token_with_claims.update(test_claims)

with self.assertRaises(InvalidSignatureError): # noqa: PT027
with pytest.raises(InvalidSignatureError):
unpack_and_verify(token)


@skip_unless_lms
@freeze_time(time_snapshot)
class TestUnpack(unittest.TestCase):
"""
Tests for JWT unpacking.
Expand All @@ -65,7 +68,7 @@ def test_unpack_jwt(self):
token = create_jwt(test_user_id, test_timeout, {}, test_now)
decoded = unpack_jwt(token, test_user_id, test_now)

self.assertEqual(expected_full_token, decoded) # noqa: PT009
assert decoded == expected_full_token

def test_unpack_jwt_with_claims(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
Expand All @@ -75,42 +78,39 @@ def test_unpack_jwt_with_claims(self):

decoded = unpack_jwt(token, test_user_id, test_now)

self.assertEqual(expected_token_with_claims, decoded) # noqa: PT009
assert decoded == expected_token_with_claims

def test_malformed_token(self):
token = create_jwt(test_user_id, test_timeout, test_claims, test_now)
token = token + "a"

expected_token_with_claims = expected_full_token.copy()
expected_token_with_claims.update(test_claims)

with self.assertRaises(InvalidSignatureError): # noqa: PT027
with pytest.raises(InvalidSignatureError):
unpack_jwt(token, test_user_id, test_now)

def test_unpack_token_with_invalid_user(self):
token = create_jwt(invalid_test_user_id, test_timeout, {}, test_now)

with self.assertRaises(InvalidSignatureError): # noqa: PT027
with pytest.raises(InvalidSignatureError):
unpack_jwt(token, test_user_id, test_now)

def test_unpack_expired_token(self):
token = create_jwt(test_user_id, test_timeout, {}, test_now)

with self.assertRaises(ExpiredSignatureError): # noqa: PT027
with pytest.raises(ExpiredSignatureError):
unpack_jwt(token, test_user_id, test_now + test_timeout + 1)

def test_missing_expired_lms_user_id(self):
payload = expected_full_token.copy()
del payload['lms_user_id']
token = _encode_and_sign(payload)

with self.assertRaises(MissingRequiredClaimError): # noqa: PT027
with pytest.raises(MissingRequiredClaimError):
unpack_jwt(token, test_user_id, test_now)

def test_missing_expired_key(self):
payload = expected_full_token.copy()
del payload['exp']
token = _encode_and_sign(payload)

with self.assertRaises(MissingRequiredClaimError): # noqa: PT027
with pytest.raises(MissingRequiredClaimError):
unpack_jwt(token, test_user_id, test_now)
Loading