diff --git a/apps/dot_ext/forms.py b/apps/dot_ext/forms.py index f2a9d6a00..c2d95c240 100644 --- a/apps/dot_ext/forms.py +++ b/apps/dot_ext/forms.py @@ -3,21 +3,21 @@ from django import forms from django.conf import settings +from django.contrib.auth.models import Group, User from django.core.exceptions import ValidationError +from django.forms.widgets import URLInput from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ from oauth2_provider.forms import AllowForm as DotAllowForm from oauth2_provider.models import get_application_model + from apps.accounts.models import UserProfile from apps.capabilities.models import ProtectedCapability +from apps.constants import HHS_SERVER_LOGNAME_FMT from apps.dot_ext.constants import BENE_PERSONAL_INFO_SCOPES, PRINTABLE_SPECIAL_ASCII -from apps.dot_ext.scopes import CapabilitiesScopes from apps.dot_ext.models import Application, InternalApplicationLabels +from apps.dot_ext.scopes import CapabilitiesScopes from apps.dot_ext.validators import validate_logo_image, validate_notags, validate_url -from django.contrib.auth.models import Group, User -from django.forms.widgets import URLInput - -from apps.constants import HHS_SERVER_LOGNAME_FMT logger = logging.getLogger(HHS_SERVER_LOGNAME_FMT.format(__name__)) @@ -71,8 +71,9 @@ class CustomRegisterApplicationForm(forms.ModelForm): ) def __init__(self, user, *args, **kwargs): - agree_label = 'Yes I have read and agree to the API Terms of Service Agreement*' % ( - settings.TOS_URI + agree_label = ( + 'Yes I have read and agree to the API Terms of Service Agreement*' + % (settings.TOS_URI) ) super(CustomRegisterApplicationForm, self).__init__(*args, **kwargs) self.fields['authorization_grant_type'].choices = settings.GRANT_TYPES @@ -381,6 +382,7 @@ class SimpleAllowForm(DotAllowForm): code_challenge = forms.CharField(required=False, widget=forms.HiddenInput()) code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput()) share_demographic_scopes = forms.CharField(required=False) + share_samhsa_data = forms.CharField(required=False) def clean(self): cleaned_data = super().clean() diff --git a/apps/dot_ext/signals.py b/apps/dot_ext/signals.py index 9f9695a04..5ec07d8c9 100644 --- a/apps/dot_ext/signals.py +++ b/apps/dot_ext/signals.py @@ -1,11 +1,11 @@ import logging from django.db.models.signals import post_save, pre_save -from django.dispatch import Signal, receiver +from django.dispatch import Signal from oauth2_provider.models import get_access_token_model, get_application_model from apps.constants import HHS_SERVER_LOGNAME_FMT -from apps.dot_ext.models import AccessTokenExtension, ArchivedToken +from apps.dot_ext.models import ArchivedToken from libs.decorators import waffle_function_switch from libs.mail import Mailer @@ -86,14 +86,3 @@ def outreach_first_api_call(sender, instance=None, **kwargs): post_save.connect(outreach_first_application, sender=Application) pre_save.connect(outreach_first_api_call, sender=AccessToken) - - -@receiver(post_save, sender=AccessToken) -def create_access_token_extension(sender, instance, created, **kwargs): - # TODO: Need to update to take into account what was passed for include_samhsa - # Once the checkbox is in place on v3 permissions screen - if created: - AccessTokenExtension.objects.create( - access_token=instance, - include_samhsa=True, - ) diff --git a/apps/dot_ext/tests/test_authorization_token.py b/apps/dot_ext/tests/test_authorization_token.py index db77a62c5..4a3bda4cb 100644 --- a/apps/dot_ext/tests/test_authorization_token.py +++ b/apps/dot_ext/tests/test_authorization_token.py @@ -28,7 +28,7 @@ IDME_HIGHER_ISS, IDME_LOWER_ISS, ) -from apps.dot_ext.models import Application +from apps.dot_ext.models import AccessTokenExtension, Application from apps.dot_ext.utils import ( get_application_from_data, get_application_from_meta, @@ -265,6 +265,61 @@ def test_validate_environment_for_id_token(self) -> None: result = view_instance._validate_idme_url_for_id_token_and_environment(IDME_HIGHER_ISS) assert result + def test_retrieve_prior_include_samhsa_value_non_refresh_token_grant_type(self) -> None: + view_instance = TokenView() + prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('authorization_code', None) + assert prior_include_samhsa + + @patch('apps.dot_ext.views.authorization.get_refresh_token_model') + @patch('apps.dot_ext.models.AccessTokenExtension.objects.get') + def test_retrieve_prior_include_samhsa_value(self, mock_access_token_extension, mock_refresh_model) -> None: + """Confirm that if the prior access_token_extension record has include_samhsa set to True, that True is returned""" + view_instance = TokenView() + mock_request = MagicMock() + mock_refresh_token = MagicMock() + mock_request.POST = { + 'grant_type': 'refresh_token', + 'refresh_token': 'tkn', + } + mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token + mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=True) + prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request) + assert prior_include_samhsa + + @patch('apps.dot_ext.views.authorization.get_refresh_token_model') + @patch('apps.dot_ext.models.AccessTokenExtension.objects.get') + def test_retrieve_prior_include_samhsa_value_false(self, mock_access_token_extension, mock_refresh_model) -> None: + """Confirm that if the prior access_token_extension record has include_samhsa set to False, that False is returned""" + view_instance = TokenView() + mock_request = MagicMock() + mock_refresh_token = MagicMock() + mock_request.POST = { + 'grant_type': 'refresh_token', + 'refresh_token': 'tkn', + } + mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token + mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=False) + prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request) + assert not prior_include_samhsa + + @patch('apps.dot_ext.views.authorization.get_refresh_token_model') + @patch('apps.dot_ext.models.AccessTokenExtension.objects.get') + def test_retrieve_prior_include_samhsa_value_access_token_extension_dne( + self, mock_access_token_extension, mock_refresh_model + ) -> None: + """Confirm that when there is no access_token_extension record returned, a value of True is returned""" + view_instance = TokenView() + mock_request = MagicMock() + mock_refresh_token = MagicMock() + mock_request.POST = { + 'grant_type': 'refresh_token', + 'refresh_token': 'tkn', + } + mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token + mock_access_token_extension.side_effect = AccessTokenExtension.DoesNotExist + prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request) + assert prior_include_samhsa + # we set empty GET/META/POST because get_application_from_data does not like it if a GET is missing. class TestClientIdExtraction(BaseApiTest): diff --git a/apps/dot_ext/tests/test_models.py b/apps/dot_ext/tests/test_models.py index a69bb0c71..d37e72e60 100644 --- a/apps/dot_ext/tests/test_models.py +++ b/apps/dot_ext/tests/test_models.py @@ -287,23 +287,6 @@ def test_internal_application_labels_admin(self): self.assertTrue(l5.slug in internal_labels) self.assertTrue(l11.slug not in internal_labels) - def test_access_token_extension_is_created(self) -> None: - """Ensure that when an access token is saved, a corresponding AccessTokenExtension record - is created - """ - - first_access_token = self.create_token( - 'John', 'Smith', fhir_id_v2=DEFAULT_SAMPLE_FHIR_ID_V2, fhir_id_v3=DEFAULT_SAMPLE_FHIR_ID_V3 - ) - ac = AccessToken.objects.get(token=first_access_token) - ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search' - ac.save() - access_token_extension = AccessTokenExtension.objects.get(access_token=ac) - - assert access_token_extension is not None - assert access_token_extension.access_token == ac - assert access_token_extension.include_samhsa - def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None: """Ensure that when an access token is deleted, the corresponding AccessTokenExtension record is deleted @@ -315,8 +298,11 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None: ac = AccessToken.objects.get(token=first_access_token) ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search' ac.save() - access_token_extension = AccessTokenExtension.objects.get(access_token=ac) - access_token_extension_id = access_token_extension.id + + access_token_extension = AccessTokenExtension() + access_token_extension.access_token = ac + access_token_extension.include_samhsa = True + access_token_extension.save() assert access_token_extension is not None assert access_token_extension.access_token == ac @@ -325,4 +311,4 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None: ac.delete() with self.assertRaises(AccessTokenExtension.DoesNotExist): - access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension_id) + access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension.id) diff --git a/apps/dot_ext/tests/test_utils.py b/apps/dot_ext/tests/test_utils.py index c314a3dad..bb6693f5c 100644 --- a/apps/dot_ext/tests/test_utils.py +++ b/apps/dot_ext/tests/test_utils.py @@ -1,11 +1,22 @@ +from unittest.mock import MagicMock, patch + from django.test import TestCase -from apps.versions import VersionNotMatched from apps.dot_ext.constants import SUPPORTED_VERSION_TEST_CASES -from apps.dot_ext.utils import get_api_version_number_from_url, validate_latin_extended_string +from apps.dot_ext.utils import ( + check_samhsa_cache_and_create_access_token_extension, + get_api_version_number_from_url, + validate_latin_extended_string, +) +from apps.versions import VersionNotMatched class TestDOTUtils(TestCase): + def setUp(self): + self.token = MagicMock() + self.code = 'test_code' + self.prior_include_samhsa = False + def test_get_api_version_number(self): for test in SUPPORTED_VERSION_TEST_CASES: result = get_api_version_number_from_url(test['url_path']) @@ -33,3 +44,75 @@ def test_latin_extended_failure(self): for text in invalid_inputs: assert not validate_latin_extended_string(text) + + @patch('apps.dot_ext.utils.AccessTokenExtension') + @patch('apps.dot_ext.utils.cache') + def test_check_samhsa_cache_and_create_access_token_extension_use_cached_value( + self, mock_cache, mock_access_token_extension + ): + """ + When cache has a value for the code and grant_type is NOT refresh_token, + the cached value should be used for include_samhsa. + """ + mock_cache.get.return_value = False + + check_samhsa_cache_and_create_access_token_extension( + prior_include_samhsa=False, + code=self.code, + grant_type='authorization_code', + token=self.token, + ) + + mock_access_token_extension.objects.get_or_create.assert_called_once_with( + access_token=self.token, + include_samhsa=False, + ) + mock_cache.delete.assert_called_once_with(f'include_samhsa:{self.code}') + + @patch('apps.dot_ext.utils.AccessTokenExtension') + @patch('apps.dot_ext.utils.cache') + def test_check_samhsa_cache_and_create_access_token_extension_no_cache_value( + self, mock_cache, mock_access_token_extension + ): + """ + When cache has NO value for the code and grant_type is NOT refresh_token, + include_samhsa should default to True. + """ + mock_cache.get.return_value = None + + check_samhsa_cache_and_create_access_token_extension( + prior_include_samhsa=False, + code=self.code, + grant_type='authorization_code', + token=self.token, + ) + + mock_access_token_extension.objects.get_or_create.assert_called_once_with( + access_token=self.token, + include_samhsa=True, + ) + mock_cache.delete.assert_called_once_with(f'include_samhsa:{self.code}') + + @patch('apps.dot_ext.utils.AccessTokenExtension') + @patch('apps.dot_ext.utils.cache') + def test_check_samhsa_cache_and_create_access_token_extension_refresh_token_grant( + self, mock_cache, mock_access_token_extension + ): + """ + When grant_type is 'refresh_token', prior_include_samhsa=False + should override any cache value. + """ + mock_cache.get.return_value = True + + check_samhsa_cache_and_create_access_token_extension( + prior_include_samhsa=False, + code=self.code, + grant_type='refresh_token', + token=self.token, + ) + + mock_access_token_extension.objects.get_or_create.assert_called_once_with( + access_token=self.token, + include_samhsa=False, + ) + mock_cache.delete.assert_called_once_with(f'include_samhsa:{self.code}') diff --git a/apps/dot_ext/utils.py b/apps/dot_ext/utils.py index 134fcc26c..9a7634f48 100644 --- a/apps/dot_ext/utils.py +++ b/apps/dot_ext/utils.py @@ -5,6 +5,7 @@ import jwt from django.contrib.auth import get_user_model +from django.core.cache import cache from django.db import transaction from django.http import HttpRequest from django.http.response import JsonResponse @@ -23,7 +24,7 @@ HHS_SERVER_LOGNAME_FMT, ) from apps.dot_ext.constants import APPLICATION_THIRTEEN_MONTH_DATA_ACCESS_NOT_FOUND_MESG -from apps.dot_ext.models import Application +from apps.dot_ext.models import AccessTokenExtension, Application from apps.versions import VersionNotMatched, Versions User = get_user_model() @@ -326,3 +327,31 @@ def validate_latin_extended_string(text: str) -> bool: bool: if all strings are encoded less than U+017F (383) and it is not empty """ return all(ord(char) <= 383 for char in text) and bool(text) + + +def check_samhsa_cache_and_create_access_token_extension( + prior_include_samhsa: bool, code: str, grant_type: str, token: AccessToken +) -> None: + """Retrieve a value from the cache, if available, for the code being used in the authorization or refresh request + + Args: + prior_include_samhsa (bool): The value the prior access_token_extension record had for include_samhsa + code (str): The code for the auth or refresh request, used to retrieve cached value + grant_type (str): Grant type of the call to TokenView.post + token (AccessToken): The access token that was generated + """ + include_samhsa = True + + # This was evaluating even if the cache had False for the value, but modifying the conditional like this + # allowed for better unit test coverage + if cache.get(f'include_samhsa:{code}') is not None and cache.get(f'include_samhsa:{code}') != '': + include_samhsa = cache.get(f'include_samhsa:{code}') + + if grant_type == 'refresh_token': + include_samhsa = prior_include_samhsa + + AccessTokenExtension.objects.get_or_create( + access_token=token, + include_samhsa=include_samhsa, + ) + cache.delete(f'include_samhsa:{code}') diff --git a/apps/dot_ext/views/authorization.py b/apps/dot_ext/views/authorization.py index 4a719daec..169085857 100644 --- a/apps/dot_ext/views/authorization.py +++ b/apps/dot_ext/views/authorization.py @@ -95,11 +95,12 @@ set_session_auth_flow_trace_value, update_instance_auth_flow_trace_with_code, ) -from apps.dot_ext.models import Application, Approval +from apps.dot_ext.models import AccessTokenExtension, Application, Approval from apps.dot_ext.parser import normalize_address from apps.dot_ext.scopes import CapabilitiesScopes from apps.dot_ext.signals import beneficiary_authorized_application from apps.dot_ext.utils import ( + check_samhsa_cache_and_create_access_token_extension, get_api_version_number_from_url, json_response_from_oauth2_error, remove_application_user_pair_tokens_data_access, @@ -492,6 +493,9 @@ def form_valid(self, form): url_query = parse_qs(urlparse(self.success_url).query) code = url_query.get('code', [None])[0] + share_samhsa_data = form.cleaned_data.get('share_samhsa_data') + cache.add(f'include_samhsa:{code}', share_samhsa_data, timeout=300) + # Get auth flow trace session values dict. auth_dict = get_session_auth_flow_trace(self.request) @@ -1024,6 +1028,21 @@ def _parse_ial_into_parameter(self, payload: dict) -> dict: return id_match_payload.model_dump(mode='json', exclude_none=True) + def _retrieve_prior_include_samhsa_value(self, grant_type: str, request: HttpRequest) -> bool: + prior_include_samhsa = True + if grant_type == 'refresh_token': + refresh_token_str = request.POST.get('refresh_token') + refresh_token = get_refresh_token_model().objects.get(token=refresh_token_str) + try: + prior_access_token_extension = AccessTokenExtension.objects.get( + access_token_id=refresh_token.access_token_id + ) + prior_include_samhsa = prior_access_token_extension.include_samhsa + except AccessTokenExtension.DoesNotExist: + # this case indicates it was an access token created before the access token extension was added + log.info(f'No access token extension for access token id: {refresh_token.access_token_id}') + return prior_include_samhsa + @method_decorator(sensitive_post_parameters('password')) def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: version = get_api_version_number_from_url(self.request.path_info) @@ -1138,6 +1157,8 @@ def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: status=HTTPStatus.FORBIDDEN, ) + prior_include_samhsa = self._retrieve_prior_include_samhsa_value(grant_type, request) + url, headers, body, status = self.create_token_response(request) # retrieve the access token, update user_id with the user.id sourced above @@ -1147,6 +1168,9 @@ def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: if access_token: token = get_access_token_model().objects.get(token=access_token) + code = request.POST.get('code', None) + check_samhsa_cache_and_create_access_token_extension(prior_include_samhsa, code, grant_type, token) + if grant_type == CLIENT_CREDENTIALS: token.user_id = user.id token.save() diff --git a/apps/fhir/bluebutton/tests/test_fhir_resources_read_search_w_validation.py b/apps/fhir/bluebutton/tests/test_fhir_resources_read_search_w_validation.py index 964733efe..a0068bb61 100755 --- a/apps/fhir/bluebutton/tests/test_fhir_resources_read_search_w_validation.py +++ b/apps/fhir/bluebutton/tests/test_fhir_resources_read_search_w_validation.py @@ -792,9 +792,11 @@ def test_v12_include_samhsa_false_fails(self): ac = self.create_token( 'John', 'Smith', fhir_id_v2=DEFAULT_SAMPLE_FHIR_ID_V2, fhir_id_v3=DEFAULT_SAMPLE_FHIR_ID_V3 ) - extension = AccessToken.objects.get(token=ac).accesstokenextension - extension.include_samhsa = False - extension.save() + ac_record = AccessToken.objects.get(token=ac) + access_token_extension = AccessTokenExtension() + access_token_extension.access_token = ac_record + access_token_extension.include_samhsa = False + access_token_extension.save() for version in [Versions.V1, Versions.V2]: response = self.client.get(reverse(SEARCH_EOB_URLS[version]), Authorization=f'Bearer {ac}') @@ -809,9 +811,11 @@ def test_v12_include_samhsa_true_succeeds(self): ac = self.create_token( 'John', 'Smith', fhir_id_v2=DEFAULT_SAMPLE_FHIR_ID_V2, fhir_id_v3=DEFAULT_SAMPLE_FHIR_ID_V3 ) - extension = AccessToken.objects.get(token=ac).accesstokenextension - extension.include_samhsa = True - extension.save() + ac_record = AccessToken.objects.get(token=ac) + access_token_extension = AccessTokenExtension() + access_token_extension.access_token = ac_record + access_token_extension.include_samhsa = True + access_token_extension.save() for version in [Versions.V1, Versions.V2]: response = self.client.get(reverse(SEARCH_EOB_URLS[version]), Authorization=f'Bearer {ac}') diff --git a/templates/design_system/authorize_v3.html b/templates/design_system/authorize_v3.html index 962c4e9c3..5cb6afa20 100644 --- a/templates/design_system/authorize_v3.html +++ b/templates/design_system/authorize_v3.html @@ -109,7 +109,38 @@
These include claims related to alcohol and/or substance use disorder treatment.
+You may or may not have these claims - not all Medicare enrollees do.
++
+ +