Skip to content
16 changes: 9 additions & 7 deletions apps/dot_ext/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

from django import forms
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.exceptions import ValidationError
from django.forms.widgets import URLInput
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
from oauth2_provider.forms import AllowForm as DotAllowForm
from oauth2_provider.models import get_application_model

from apps.accounts.models import UserProfile
from apps.capabilities.models import ProtectedCapability
from apps.constants import HHS_SERVER_LOGNAME_FMT
from apps.dot_ext.constants import BENE_PERSONAL_INFO_SCOPES, PRINTABLE_SPECIAL_ASCII
from apps.dot_ext.scopes import CapabilitiesScopes
from apps.dot_ext.models import Application, InternalApplicationLabels
from apps.dot_ext.scopes import CapabilitiesScopes
from apps.dot_ext.validators import validate_logo_image, validate_notags, validate_url
from django.contrib.auth.models import Group, User
from django.forms.widgets import URLInput

from apps.constants import HHS_SERVER_LOGNAME_FMT

logger = logging.getLogger(HHS_SERVER_LOGNAME_FMT.format(__name__))

Expand Down Expand Up @@ -71,8 +71,9 @@ class CustomRegisterApplicationForm(forms.ModelForm):
)

def __init__(self, user, *args, **kwargs):
agree_label = 'Yes I have read and agree to the <a target="_blank" href="%s">API Terms of Service Agreement</a>*' % (
settings.TOS_URI
agree_label = (
'Yes I have read and agree to the <a target="_blank" href="%s">API Terms of Service Agreement</a>*'
% (settings.TOS_URI)
)
super(CustomRegisterApplicationForm, self).__init__(*args, **kwargs)
self.fields['authorization_grant_type'].choices = settings.GRANT_TYPES
Expand Down Expand Up @@ -381,6 +382,7 @@ class SimpleAllowForm(DotAllowForm):
code_challenge = forms.CharField(required=False, widget=forms.HiddenInput())
code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput())
share_demographic_scopes = forms.CharField(required=False)
share_samhsa_data = forms.CharField(required=False)

def clean(self):
cleaned_data = super().clean()
Expand Down
15 changes: 2 additions & 13 deletions apps/dot_ext/signals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging

from django.db.models.signals import post_save, pre_save
from django.dispatch import Signal, receiver
from django.dispatch import Signal
from oauth2_provider.models import get_access_token_model, get_application_model

from apps.constants import HHS_SERVER_LOGNAME_FMT
from apps.dot_ext.models import AccessTokenExtension, ArchivedToken
from apps.dot_ext.models import ArchivedToken
from libs.decorators import waffle_function_switch
from libs.mail import Mailer

Expand Down Expand Up @@ -86,14 +86,3 @@ def outreach_first_api_call(sender, instance=None, **kwargs):

post_save.connect(outreach_first_application, sender=Application)
pre_save.connect(outreach_first_api_call, sender=AccessToken)


@receiver(post_save, sender=AccessToken)
Comment thread
JamesDemeryNava marked this conversation as resolved.
def create_access_token_extension(sender, instance, created, **kwargs):
# TODO: Need to update to take into account what was passed for include_samhsa
# Once the checkbox is in place on v3 permissions screen
if created:
AccessTokenExtension.objects.create(
access_token=instance,
include_samhsa=True,
)
57 changes: 56 additions & 1 deletion apps/dot_ext/tests/test_authorization_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
IDME_HIGHER_ISS,
IDME_LOWER_ISS,
)
from apps.dot_ext.models import Application
from apps.dot_ext.models import AccessTokenExtension, Application
from apps.dot_ext.utils import (
get_application_from_data,
get_application_from_meta,
Expand Down Expand Up @@ -265,6 +265,61 @@ def test_validate_environment_for_id_token(self) -> None:
result = view_instance._validate_idme_url_for_id_token_and_environment(IDME_HIGHER_ISS)
assert result

def test_retrieve_prior_include_samhsa_value_non_refresh_token_grant_type(self) -> None:
view_instance = TokenView()
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('authorization_code', None)
assert prior_include_samhsa

@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
def test_retrieve_prior_include_samhsa_value(self, mock_access_token_extension, mock_refresh_model) -> None:
"""Confirm that if the prior access_token_extension record has include_samhsa set to True, that True is returned"""
view_instance = TokenView()
mock_request = MagicMock()
mock_refresh_token = MagicMock()
Comment thread
ryan-morosa marked this conversation as resolved.
mock_request.POST = {
'grant_type': 'refresh_token',
'refresh_token': 'tkn',
}
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=True)
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
assert prior_include_samhsa

@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
def test_retrieve_prior_include_samhsa_value_false(self, mock_access_token_extension, mock_refresh_model) -> None:
Comment thread
ryan-morosa marked this conversation as resolved.
"""Confirm that if the prior access_token_extension record has include_samhsa set to False, that False is returned"""
view_instance = TokenView()
mock_request = MagicMock()
mock_refresh_token = MagicMock()
mock_request.POST = {
'grant_type': 'refresh_token',
'refresh_token': 'tkn',
}
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=False)
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
assert not prior_include_samhsa

@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
def test_retrieve_prior_include_samhsa_value_access_token_extension_dne(
self, mock_access_token_extension, mock_refresh_model
) -> None:
"""Confirm that when there is no access_token_extension record returned, a value of True is returned"""
view_instance = TokenView()
mock_request = MagicMock()
mock_refresh_token = MagicMock()
mock_request.POST = {
'grant_type': 'refresh_token',
'refresh_token': 'tkn',
}
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
mock_access_token_extension.side_effect = AccessTokenExtension.DoesNotExist
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
assert prior_include_samhsa


# we set empty GET/META/POST because get_application_from_data does not like it if a GET is missing.
class TestClientIdExtraction(BaseApiTest):
Expand Down
26 changes: 6 additions & 20 deletions apps/dot_ext/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,23 +287,6 @@ def test_internal_application_labels_admin(self):
self.assertTrue(l5.slug in internal_labels)
self.assertTrue(l11.slug not in internal_labels)

def test_access_token_extension_is_created(self) -> None:
Comment thread
ryan-morosa marked this conversation as resolved.
"""Ensure that when an access token is saved, a corresponding AccessTokenExtension record
is created
"""

first_access_token = self.create_token(
'John', 'Smith', fhir_id_v2=DEFAULT_SAMPLE_FHIR_ID_V2, fhir_id_v3=DEFAULT_SAMPLE_FHIR_ID_V3
)
ac = AccessToken.objects.get(token=first_access_token)
ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search'
ac.save()
access_token_extension = AccessTokenExtension.objects.get(access_token=ac)

assert access_token_extension is not None
assert access_token_extension.access_token == ac
assert access_token_extension.include_samhsa

def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
"""Ensure that when an access token is deleted, the corresponding AccessTokenExtension record
is deleted
Expand All @@ -315,8 +298,11 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
ac = AccessToken.objects.get(token=first_access_token)
ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search'
ac.save()
access_token_extension = AccessTokenExtension.objects.get(access_token=ac)
access_token_extension_id = access_token_extension.id

access_token_extension = AccessTokenExtension()
access_token_extension.access_token = ac
access_token_extension.include_samhsa = True
access_token_extension.save()

assert access_token_extension is not None
assert access_token_extension.access_token == ac
Expand All @@ -325,4 +311,4 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
ac.delete()

with self.assertRaises(AccessTokenExtension.DoesNotExist):
access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension_id)
access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension.id)
87 changes: 85 additions & 2 deletions apps/dot_ext/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from unittest.mock import MagicMock, patch

from django.test import TestCase

from apps.versions import VersionNotMatched
from apps.dot_ext.constants import SUPPORTED_VERSION_TEST_CASES
from apps.dot_ext.utils import get_api_version_number_from_url, validate_latin_extended_string
from apps.dot_ext.utils import (
check_samhsa_cache_and_create_access_token_extension,
get_api_version_number_from_url,
validate_latin_extended_string,
)
from apps.versions import VersionNotMatched


class TestDOTUtils(TestCase):
def setUp(self):
self.token = MagicMock()
self.code = 'test_code'
self.prior_include_samhsa = False

def test_get_api_version_number(self):
for test in SUPPORTED_VERSION_TEST_CASES:
result = get_api_version_number_from_url(test['url_path'])
Expand Down Expand Up @@ -33,3 +44,75 @@ def test_latin_extended_failure(self):

for text in invalid_inputs:
assert not validate_latin_extended_string(text)

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.cache')
def test_check_samhsa_cache_and_create_access_token_extension_use_cached_value(
self, mock_cache, mock_access_token_extension
):
"""
When cache has a value for the code and grant_type is NOT refresh_token,
the cached value should be used for include_samhsa.
"""
mock_cache.get.return_value = False

check_samhsa_cache_and_create_access_token_extension(
prior_include_samhsa=False,
code=self.code,
grant_type='authorization_code',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=False,
)
mock_cache.delete.assert_called_once_with(f'include_samhsa:{self.code}')

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.cache')
def test_check_samhsa_cache_and_create_access_token_extension_no_cache_value(
self, mock_cache, mock_access_token_extension
):
"""
When cache has NO value for the code and grant_type is NOT refresh_token,
include_samhsa should default to True.
"""
mock_cache.get.return_value = None

check_samhsa_cache_and_create_access_token_extension(
prior_include_samhsa=False,
code=self.code,
grant_type='authorization_code',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=True,
)
mock_cache.delete.assert_called_once_with(f'include_samhsa:{self.code}')

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.cache')
def test_check_samhsa_cache_and_create_access_token_extension_refresh_token_grant(
self, mock_cache, mock_access_token_extension
):
"""
When grant_type is 'refresh_token', prior_include_samhsa=False
should override any cache value.
"""
mock_cache.get.return_value = True

check_samhsa_cache_and_create_access_token_extension(
prior_include_samhsa=False,
code=self.code,
grant_type='refresh_token',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=False,
)
mock_cache.delete.assert_called_once_with(f'include_samhsa:{self.code}')
31 changes: 30 additions & 1 deletion apps/dot_ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jwt
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.db import transaction
from django.http import HttpRequest
from django.http.response import JsonResponse
Expand All @@ -23,7 +24,7 @@
HHS_SERVER_LOGNAME_FMT,
)
from apps.dot_ext.constants import APPLICATION_THIRTEEN_MONTH_DATA_ACCESS_NOT_FOUND_MESG
from apps.dot_ext.models import Application
from apps.dot_ext.models import AccessTokenExtension, Application
from apps.versions import VersionNotMatched, Versions

User = get_user_model()
Expand Down Expand Up @@ -326,3 +327,31 @@ def validate_latin_extended_string(text: str) -> bool:
bool: if all strings are encoded less than U+017F (383) and it is not empty
"""
return all(ord(char) <= 383 for char in text) and bool(text)


def check_samhsa_cache_and_create_access_token_extension(
Comment thread
ryan-morosa marked this conversation as resolved.
prior_include_samhsa: bool, code: str, grant_type: str, token: AccessToken
) -> None:
"""Retrieve a value from the cache, if available, for the code being used in the authorization or refresh request

Args:
prior_include_samhsa (bool): The value the prior access_token_extension record had for include_samhsa
code (str): The code for the auth or refresh request, used to retrieve cached value
grant_type (str): Grant type of the call to TokenView.post
token (AccessToken): The access token that was generated
"""
include_samhsa = True

# This was evaluating even if the cache had False for the value, but modifying the conditional like this
# allowed for better unit test coverage
if cache.get(f'include_samhsa:{code}') is not None and cache.get(f'include_samhsa:{code}') != '':
Comment thread
ryan-morosa marked this conversation as resolved.
include_samhsa = cache.get(f'include_samhsa:{code}')

if grant_type == 'refresh_token':
include_samhsa = prior_include_samhsa

AccessTokenExtension.objects.get_or_create(
access_token=token,
include_samhsa=include_samhsa,
)
cache.delete(f'include_samhsa:{code}')
Loading
Loading