diff --git a/openedx/core/lib/tests/test_jwt.py b/openedx/core/lib/tests/test_jwt.py index c206e78e6cc3..25330a905822 100644 --- a/openedx/core/lib/tests/test_jwt.py +++ b/openedx/core/lib/tests/test_jwt.py @@ -1,9 +1,12 @@ """ Tests for token handling """ +import datetime import unittest from time import time +import pytest +from freezegun import freeze_time from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError, MissingRequiredClaimError from openedx.core.djangolib.testing.utils import skip_unless_lms @@ -13,6 +16,7 @@ invalid_test_user_id = 120 test_timeout = 1000 test_now = int(time()) +time_snapshot = datetime.datetime.fromtimestamp(test_now, tz=datetime.UTC) test_claims = {"foo": "bar", "baz": "quux", "meaning": 42} expected_full_token = { "lms_user_id": test_user_id, @@ -24,6 +28,7 @@ @skip_unless_lms +@freeze_time(time_snapshot) class TestSign(unittest.TestCase): """ Tests for JWT creation and signing. @@ -33,7 +38,7 @@ def test_create_jwt(self): token = create_jwt(test_user_id, test_timeout, {}, test_now) decoded = unpack_and_verify(token) - self.assertEqual(expected_full_token, decoded) # noqa: PT009 + assert decoded == expected_full_token def test_create_jwt_with_claims(self): token = create_jwt(test_user_id, test_timeout, test_claims, test_now) @@ -42,20 +47,18 @@ def test_create_jwt_with_claims(self): expected_token_with_claims.update(test_claims) decoded = unpack_and_verify(token) - self.assertEqual(expected_token_with_claims, decoded) # noqa: PT009 + assert decoded == expected_token_with_claims def test_malformed_token(self): token = create_jwt(test_user_id, test_timeout, test_claims, test_now) token = token + "a" - expected_token_with_claims = expected_full_token.copy() - expected_token_with_claims.update(test_claims) - - with self.assertRaises(InvalidSignatureError): # noqa: PT027 + with pytest.raises(InvalidSignatureError): unpack_and_verify(token) @skip_unless_lms +@freeze_time(time_snapshot) class TestUnpack(unittest.TestCase): """ Tests for JWT unpacking. @@ -65,7 +68,7 @@ def test_unpack_jwt(self): token = create_jwt(test_user_id, test_timeout, {}, test_now) decoded = unpack_jwt(token, test_user_id, test_now) - self.assertEqual(expected_full_token, decoded) # noqa: PT009 + assert decoded == expected_full_token def test_unpack_jwt_with_claims(self): token = create_jwt(test_user_id, test_timeout, test_claims, test_now) @@ -75,28 +78,25 @@ def test_unpack_jwt_with_claims(self): decoded = unpack_jwt(token, test_user_id, test_now) - self.assertEqual(expected_token_with_claims, decoded) # noqa: PT009 + assert decoded == expected_token_with_claims def test_malformed_token(self): token = create_jwt(test_user_id, test_timeout, test_claims, test_now) token = token + "a" - expected_token_with_claims = expected_full_token.copy() - expected_token_with_claims.update(test_claims) - - with self.assertRaises(InvalidSignatureError): # noqa: PT027 + with pytest.raises(InvalidSignatureError): unpack_jwt(token, test_user_id, test_now) def test_unpack_token_with_invalid_user(self): token = create_jwt(invalid_test_user_id, test_timeout, {}, test_now) - with self.assertRaises(InvalidSignatureError): # noqa: PT027 + with pytest.raises(InvalidSignatureError): unpack_jwt(token, test_user_id, test_now) def test_unpack_expired_token(self): token = create_jwt(test_user_id, test_timeout, {}, test_now) - with self.assertRaises(ExpiredSignatureError): # noqa: PT027 + with pytest.raises(ExpiredSignatureError): unpack_jwt(token, test_user_id, test_now + test_timeout + 1) def test_missing_expired_lms_user_id(self): @@ -104,7 +104,7 @@ def test_missing_expired_lms_user_id(self): del payload['lms_user_id'] token = _encode_and_sign(payload) - with self.assertRaises(MissingRequiredClaimError): # noqa: PT027 + with pytest.raises(MissingRequiredClaimError): unpack_jwt(token, test_user_id, test_now) def test_missing_expired_key(self): @@ -112,5 +112,5 @@ def test_missing_expired_key(self): del payload['exp'] token = _encode_and_sign(payload) - with self.assertRaises(MissingRequiredClaimError): # noqa: PT027 + with pytest.raises(MissingRequiredClaimError): unpack_jwt(token, test_user_id, test_now)