diff --git a/apps/dot_ext/tests/test_beneficiary_demographic_scope_changes.py b/apps/dot_ext/tests/test_beneficiary_demographic_scope_changes.py index ce6cc247d..c0280ffd5 100644 --- a/apps/dot_ext/tests/test_beneficiary_demographic_scope_changes.py +++ b/apps/dot_ext/tests/test_beneficiary_demographic_scope_changes.py @@ -1,18 +1,20 @@ import json -from apps.test import BaseApiTest -from django.core.management import call_command -from django.http import HttpRequest -from django.urls import reverse +from http import HTTPStatus +from unittest import mock # from oauth2_provider.compat import parse_qs, urlparse from urllib.parse import parse_qs, urlparse + +from django.core.management import call_command +from django.http import HttpRequest +from django.urls import reverse from oauth2_provider.models import AccessToken, RefreshToken from rest_framework.test import APIClient from waffle.testutils import override_switch -from apps.authorization.models import DataAccessGrant, ArchivedDataAccessGrant -from apps.dot_ext.models import ArchivedToken, Application -from http import HTTPStatus -from unittest import mock + +from apps.authorization.models import ArchivedDataAccessGrant, DataAccessGrant +from apps.dot_ext.models import Application, ArchivedToken +from apps.test import BaseApiTest class TestBeneficiaryDemographicScopesChanges(BaseApiTest): @@ -127,7 +129,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): ) # Assert auth request was successful - self.assertEqual(status_code, 200) + self.assertEqual(status_code, HTTPStatus.OK) # Assert scope in response content self.assertEqual(response_scopes, sorted(APPLICATION_SCOPES_FULL)) @@ -138,7 +140,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_1.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # ------ TEST #2: Test refresh of token_1 refresh_request_data = { @@ -152,7 +154,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): content = json.loads(response.content.decode('utf-8')) # Assert successful - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # Assert response scopes response_scopes = sorted(content['scope'].split()) @@ -166,7 +168,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 1) @@ -197,7 +199,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert NO access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_3.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 1) @@ -215,7 +217,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Test access to userinfo end point? NO ACCESS! response = client.get('/v1/connect/userinfo') content = json.loads(response.content) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) self.assertEqual(content.get('detail', None), 'Authentication credentials were not provided.') # ------ TEST #5: Test token_1 from TEST #1 token refresh? NO ACCESS! @@ -241,7 +243,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): ) # Assert auth request was successful - self.assertEqual(status_code, 200) + self.assertEqual(status_code, HTTPStatus.OK) # Assert scope in response content self.assertEqual(response_scopes, sorted(APPLICATION_SCOPES_FULL)) @@ -252,7 +254,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_6.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # ------ TEST #7: Test token_3 from TEST #3 again. It should still have access, but no permission with status=403. @@ -262,8 +264,8 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Test access to userinfo end point? response = client.get('/v1/connect/userinfo') content = json.loads(response.content) - self.assertEqual(response.status_code, 403) - self.assertEqual(content.get('detail', None), 'You do not have permission to perform this action.') + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + self.assertEqual(content.get('detail', None), 'Authentication credentials were not provided.') # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 2) @@ -280,7 +282,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Perform partial authorization request, with out application getting an access token. response = self.client.post(reverse('oauth2_provider:authorize'), data=payload) - self.assertEqual(response.status_code, 302) + self.assertEqual(response.status_code, HTTPStatus.FOUND) # Setup token_3 in APIClient from previous step. It should be removed now? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_3.token) @@ -288,7 +290,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Test access to userinfo end point? response = client.get('/v1/connect/userinfo') content = json.loads(response.content) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) self.assertEqual(content.get('detail', None), 'Authentication credentials were not provided.') # Verify token counts expected. @@ -309,7 +311,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): ) # Assert auth request was successful - self.assertEqual(status_code, 200) + self.assertEqual(status_code, HTTPStatus.OK) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 1) @@ -323,14 +325,14 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_9.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # Beneficiary chooses the DENY button choice on consent page payload['allow'] = False # Perform partial authorization request, with out application getting an access token. response = self.client.post(reverse('oauth2_provider:authorize'), data=payload) - self.assertEqual(response.status_code, 302) + self.assertEqual(response.status_code, HTTPStatus.FOUND) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 1) @@ -346,7 +348,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # when the allow parameter is false client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_9.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # BB2-4270: Remove prior active tokens so tests below are not looking for multiple active tokens # which is an impossible state @@ -360,12 +362,12 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): payload['allow'] = True # Perform authorization request - token_10, refresh_token_10, status_code, response_scopes, access_token_scopes = self._authorize_and_request_token( - payload, application + token_10, refresh_token_10, status_code, response_scopes, access_token_scopes = ( + self._authorize_and_request_token(payload, application) ) # Assert auth request was successful - self.assertEqual(status_code, 200) + self.assertEqual(status_code, HTTPStatus.OK) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 1) @@ -379,7 +381,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_10.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, HTTPStatus.OK) # Application changes choice to require demographic scopes application.require_demographic_scopes = False @@ -387,7 +389,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Perform partial authorization request, with out application getting an access token. response = self.client.post(reverse('oauth2_provider:authorize'), data=payload) - self.assertEqual(response.status_code, 302) + self.assertEqual(response.status_code, HTTPStatus.FOUND) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 0) @@ -400,7 +402,7 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Perform partial authorization request, with out application getting an access token. response = self.client.post(reverse('oauth2_provider:authorize'), data=payload) - self.assertEqual(response.status_code, 302) + self.assertEqual(response.status_code, HTTPStatus.FOUND) # Verify token counts expected. self.assertEqual(AccessToken.objects.count(), 0) @@ -414,5 +416,5 @@ def test_bene_demo_scopes_change(self, mock_get_and_update): # Assert access to userinfo end point? client.credentials(HTTP_AUTHORIZATION='Bearer ' + token_10.token) response = client.get('/v1/connect/userinfo') - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) self.assertEqual(content.get('detail', None), 'Authentication credentials were not provided.') diff --git a/apps/dot_ext/tests/test_utils.py b/apps/dot_ext/tests/test_utils.py index c314a3dad..3c4052813 100644 --- a/apps/dot_ext/tests/test_utils.py +++ b/apps/dot_ext/tests/test_utils.py @@ -1,8 +1,16 @@ +from datetime import timedelta +from unittest.mock import MagicMock, patch + from django.test import TestCase +from django.utils import timezone -from apps.versions import VersionNotMatched from apps.dot_ext.constants import SUPPORTED_VERSION_TEST_CASES -from apps.dot_ext.utils import get_api_version_number_from_url, validate_latin_extended_string +from apps.dot_ext.utils import ( + get_api_version_number_from_url, + revoke_prior_tokens_for_user_and_app_if_they_exist, + validate_latin_extended_string, +) +from apps.versions import VersionNotMatched class TestDOTUtils(TestCase): @@ -33,3 +41,74 @@ def test_latin_extended_failure(self): for text in invalid_inputs: assert not validate_latin_extended_string(text) + + @patch('apps.dot_ext.utils.timezone') + @patch('apps.dot_ext.utils.get_refresh_token_model') + @patch('apps.dot_ext.utils.get_access_token_model') + def test_revoke_prior_tokens_for_user_and_app_if_they_exist_prior_tokens_exit( + self, mock_get_access_token, mock_get_refresh_token, mock_timezone + ): + """Confirm that if there are multiple access tokens, the prior ones will have their expires value updated. + Also, any associated refresh tokens will have their access_token_id set to null, and their revoked value + set to the current time (UTC) + """ + mock_timezone.now.return_value = timezone.now() + + new_access_token = MagicMock( + id=1, user_id=1, expires=timezone.now() + timedelta(hours=2), created=timezone.now() + ) + prior_access_token = MagicMock( + id=2, user_id=1, expires=timezone.now() + timedelta(minutes=30), created=timezone.now() - timedelta(days=10) + ) + prior_access_token_two = MagicMock( + id=3, user_id=1, expires=timezone.now() - timedelta(days=20), created=timezone.now() - timedelta(days=20) + ) + + mock_access_token_model = MagicMock() + mock_access_token_model.objects.filter.return_value.order_by.return_value = [ + new_access_token, + prior_access_token, + prior_access_token_two, + ] + mock_get_access_token.return_value = mock_access_token_model + + refresh_token = MagicMock(access_token_id=2, revoked=None) + mock_refresh_token_model = MagicMock() + mock_refresh_token_model.objects.get.return_value = refresh_token + mock_get_refresh_token.return_value = mock_refresh_token_model + + revoke_prior_tokens_for_user_and_app_if_they_exist(1, 10) + + prior_access_token.save.assert_called_once() + prior_access_token_two.save.assert_not_called() + + assert refresh_token.revoked is not None + assert refresh_token.access_token_id is None + refresh_token.save.assert_called_once() + + @patch('apps.dot_ext.utils.timezone') + @patch('apps.dot_ext.utils.get_access_token_model') + def test_revoke_prior_tokens_for_user_and_app_if_they_exist_no_associated_refresh_token( + self, mock_get_access_token, mock_timezone + ): + """Confirm that even if there is no associated refresh_token for an access_token, that the access_token + still has its expires value updated. + """ + mock_timezone.now.return_value = timezone.now() + + new_access_token = MagicMock( + id=1, user_id=1, expires=timezone.now() + timedelta(hours=2), created=timezone.now() + ) + prior_access_token = MagicMock( + id=2, user_id=1, expires=timezone.now() + timedelta(minutes=30), created=timezone.now() - timedelta(days=10) + ) + + mock_access_token_model = MagicMock() + mock_access_token_model.objects.filter.return_value.order_by.return_value = [ + new_access_token, + prior_access_token, + ] + mock_get_access_token.return_value = mock_access_token_model + + revoke_prior_tokens_for_user_and_app_if_they_exist(1, 10) + prior_access_token.save.assert_called_once() diff --git a/apps/dot_ext/utils.py b/apps/dot_ext/utils.py index 134fcc26c..b3348d4a7 100644 --- a/apps/dot_ext/utils.py +++ b/apps/dot_ext/utils.py @@ -8,7 +8,14 @@ from django.db import transaction from django.http import HttpRequest from django.http.response import JsonResponse -from oauth2_provider.models import AccessToken, RefreshToken, get_application_model +from django.utils import timezone +from oauth2_provider.models import ( + AccessToken, + RefreshToken, + get_access_token_model, + get_application_model, + get_refresh_token_model, +) from oauthlib.oauth2.rfc6749.errors import ( InvalidClientError, InvalidGrantError, @@ -326,3 +333,36 @@ def validate_latin_extended_string(text: str) -> bool: bool: if all strings are encoded less than U+017F (383) and it is not empty """ return all(ord(char) <= 383 for char in text) and bool(text) + + +def revoke_prior_tokens_for_user_and_app_if_they_exist(user_id: int, app_id: int) -> None: + """Revoke prior tokens for a user/app id pair to ensure that if a user has reauthorized + that prior tokens can't be used, in case any of those prior tokens have more scopes than + the newly created one + + Args: + user_id (int): ID for the user who just re-authorized an app they have authorized previously + app_id (int): ID for the application the user just re-authorized for + """ + AccessToken = get_access_token_model() + RefreshToken = get_refresh_token_model() + prior_access_tokens = list(AccessToken.objects.filter(user=user_id, application=app_id).order_by('-created')) + + for access_token in prior_access_tokens: + try: + refresh_token = RefreshToken.objects.get(access_token=access_token.id) + + if refresh_token.revoked is None: + refresh_token.revoked = timezone.now() + refresh_token.access_token_id = None + refresh_token.save() + + except RefreshToken.DoesNotExist: + # indicates it is an access token created via CAN flow, as it does not have an associated refresh token + # no action needed + pass + + # Only update the access token expires value if it is in the future + if access_token.expires > timezone.now(): + access_token.expires = timezone.now() + access_token.save() diff --git a/apps/dot_ext/views/authorization.py b/apps/dot_ext/views/authorization.py index 4a719daec..bd432db1b 100644 --- a/apps/dot_ext/views/authorization.py +++ b/apps/dot_ext/views/authorization.py @@ -103,6 +103,7 @@ get_api_version_number_from_url, json_response_from_oauth2_error, remove_application_user_pair_tokens_data_access, + revoke_prior_tokens_for_user_and_app_if_they_exist, validate_app_is_active, validate_latin_extended_string, ) @@ -376,6 +377,10 @@ def validate_v3_authorization_request(self): try: application_user = get_user_model().objects.get(id=self.application.user_id) + # If the v3_early_adopter does not exist in the database, a WaffleFlag object is returned, + # but the id is None. In that case, we want to return and leave it up to the v3_endpoints switch + # as to whether v3 calls can be made. If the flag does exist, then the id will not be None + # and we will check to see if the flag is active for the application if flag.id is None or flag.is_active_for_user(application_user): # Update the class variable to ensure subsequent calls to dispatch don't call this function # more times than is needed @@ -501,6 +506,9 @@ def form_valid(self, form): # Update AuthFlowUuid instance with code. update_instance_auth_flow_trace_with_code(auth_dict, code) + # Check for prior tokens to ensure they can't continue to be used + revoke_prior_tokens_for_user_and_app_if_they_exist(self.request.user.id, application.id) + return self.redirect(self.success_url, application)