Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions apps/dot_ext/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

from django import forms
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.exceptions import ValidationError
from django.forms.widgets import URLInput
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
from oauth2_provider.forms import AllowForm as DotAllowForm
from oauth2_provider.models import get_application_model

from apps.accounts.models import UserProfile
from apps.capabilities.models import ProtectedCapability
from apps.constants import HHS_SERVER_LOGNAME_FMT
from apps.dot_ext.constants import BENE_PERSONAL_INFO_SCOPES, PRINTABLE_SPECIAL_ASCII
from apps.dot_ext.models import Application, InternalApplicationLabels
from apps.dot_ext.scopes import CapabilitiesScopes
from apps.dot_ext.models import Application, InternalApplicationLabels
from apps.dot_ext.validators import validate_logo_image, validate_notags, validate_url
from django.contrib.auth.models import Group, User
from django.forms.widgets import URLInput

from apps.constants import HHS_SERVER_LOGNAME_FMT

logger = logging.getLogger(HHS_SERVER_LOGNAME_FMT.format(__name__))

Expand Down Expand Up @@ -71,9 +71,8 @@ class CustomRegisterApplicationForm(forms.ModelForm):
)

def __init__(self, user, *args, **kwargs):
agree_label = (
'Yes I have read and agree to the <a target="_blank" href="%s">API Terms of Service Agreement</a>*'
% (settings.TOS_URI)
agree_label = 'Yes I have read and agree to the <a target="_blank" href="%s">API Terms of Service Agreement</a>*' % (
settings.TOS_URI
)
super(CustomRegisterApplicationForm, self).__init__(*args, **kwargs)
self.fields['authorization_grant_type'].choices = settings.GRANT_TYPES
Expand Down Expand Up @@ -382,7 +381,6 @@ class SimpleAllowForm(DotAllowForm):
code_challenge = forms.CharField(required=False, widget=forms.HiddenInput())
code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput())
share_demographic_scopes = forms.CharField(required=False)
share_samhsa_data = forms.BooleanField(required=False)

def clean(self):
cleaned_data = super().clean()
Expand Down
26 changes: 0 additions & 26 deletions apps/dot_ext/migrations/0014_authflowtracking.py

This file was deleted.

11 changes: 0 additions & 11 deletions apps/dot_ext/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,17 +524,6 @@ class Meta:
db_table = 'oauth2_provider_accesstoken_extension'


class AuthFlowTracking(models.Model):
id = models.BigAutoField(primary_key=True)
code = models.CharField(max_length=255, null=True, unique=True, db_index=True)
include_samhsa = models.BooleanField(null=False, default=True)
created = models.DateTimeField(auto_now_add=True)
expires = models.DateTimeField()

class Meta:
db_table = 'dot_ext_auth_flow_tracking'


def get_application_counts():
"""
Get the active and inactive counts of applications.
Expand Down
15 changes: 13 additions & 2 deletions apps/dot_ext/signals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging

from django.db.models.signals import post_save, pre_save
from django.dispatch import Signal
from django.dispatch import Signal, receiver
from oauth2_provider.models import get_access_token_model, get_application_model

from apps.constants import HHS_SERVER_LOGNAME_FMT
from apps.dot_ext.models import ArchivedToken
from apps.dot_ext.models import AccessTokenExtension, ArchivedToken
from libs.decorators import waffle_function_switch
from libs.mail import Mailer

Expand Down Expand Up @@ -86,3 +86,14 @@ def outreach_first_api_call(sender, instance=None, **kwargs):

post_save.connect(outreach_first_application, sender=Application)
pre_save.connect(outreach_first_api_call, sender=AccessToken)


@receiver(post_save, sender=AccessToken)
def create_access_token_extension(sender, instance, created, **kwargs):
# TODO: Need to update to take into account what was passed for include_samhsa
# Once the checkbox is in place on v3 permissions screen
if created:
AccessTokenExtension.objects.create(
access_token=instance,
include_samhsa=True,
)
61 changes: 1 addition & 60 deletions apps/dot_ext/tests/test_authorization_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
IDME_HIGHER_ISS,
IDME_LOWER_ISS,
)
from apps.dot_ext.models import AccessTokenExtension, Application
from apps.dot_ext.models import Application
from apps.dot_ext.utils import (
get_application_from_data,
get_application_from_meta,
Expand Down Expand Up @@ -265,65 +265,6 @@ def test_validate_environment_for_id_token(self) -> None:
result = view_instance._validate_idme_url_for_id_token_and_environment(IDME_HIGHER_ISS)
assert result

def test_retrieve_prior_include_samhsa_value_non_refresh_token_grant_type(self) -> None:
"""The _retrieve_prior_include_samhsa_value will always return a value of True if the grant_type is
authorization_code, as we only attempt to retrieve the prior include samhsa value if the grant_type is
refresh_token
"""
view_instance = TokenView()
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('authorization_code', None)
assert prior_include_samhsa

@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
def test_retrieve_prior_include_samhsa_value(self, mock_access_token_extension, mock_refresh_model) -> None:
"""Confirm that if the prior access_token_extension record has include_samhsa set to True, that True is returned"""
view_instance = TokenView()
mock_request = MagicMock()
mock_refresh_token = MagicMock()
mock_request.POST = {
'grant_type': 'refresh_token',
'refresh_token': 'tkn',
}
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=True)
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
assert prior_include_samhsa

@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
def test_retrieve_prior_include_samhsa_value_false(self, mock_access_token_extension, mock_refresh_model) -> None:
"""Confirm that if the prior access_token_extension record has include_samhsa set to False, that False is returned"""
view_instance = TokenView()
mock_request = MagicMock()
mock_refresh_token = MagicMock()
mock_request.POST = {
'grant_type': 'refresh_token',
'refresh_token': 'tkn',
}
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=False)
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
assert not prior_include_samhsa

@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
def test_retrieve_prior_include_samhsa_value_access_token_extension_dne(
self, mock_access_token_extension, mock_refresh_model
) -> None:
"""Confirm that when there is no access_token_extension record returned, a value of True is returned"""
view_instance = TokenView()
mock_request = MagicMock()
mock_refresh_token = MagicMock()
mock_request.POST = {
'grant_type': 'refresh_token',
'refresh_token': 'tkn',
}
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
mock_access_token_extension.side_effect = AccessTokenExtension.DoesNotExist
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
assert prior_include_samhsa


# we set empty GET/META/POST because get_application_from_data does not like it if a GET is missing.
class TestClientIdExtraction(BaseApiTest):
Expand Down
26 changes: 20 additions & 6 deletions apps/dot_ext/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,23 @@ def test_internal_application_labels_admin(self):
self.assertTrue(l5.slug in internal_labels)
self.assertTrue(l11.slug not in internal_labels)

def test_access_token_extension_is_created(self) -> None:
"""Ensure that when an access token is saved, a corresponding AccessTokenExtension record
is created
"""

first_access_token = self.create_token(
'John', 'Smith', fhir_id_v2=DEFAULT_SAMPLE_FHIR_ID_V2, fhir_id_v3=DEFAULT_SAMPLE_FHIR_ID_V3
)
ac = AccessToken.objects.get(token=first_access_token)
ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search'
ac.save()
access_token_extension = AccessTokenExtension.objects.get(access_token=ac)

assert access_token_extension is not None
assert access_token_extension.access_token == ac
assert access_token_extension.include_samhsa

def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
"""Ensure that when an access token is deleted, the corresponding AccessTokenExtension record
is deleted
Expand All @@ -298,11 +315,8 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
ac = AccessToken.objects.get(token=first_access_token)
ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search'
ac.save()

access_token_extension = AccessTokenExtension()
access_token_extension.access_token = ac
access_token_extension.include_samhsa = True
access_token_extension.save()
access_token_extension = AccessTokenExtension.objects.get(access_token=ac)
access_token_extension_id = access_token_extension.id

assert access_token_extension is not None
assert access_token_extension.access_token == ac
Expand All @@ -311,4 +325,4 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
ac.delete()

with self.assertRaises(AccessTokenExtension.DoesNotExist):
access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension.id)
access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension_id)
114 changes: 0 additions & 114 deletions apps/dot_ext/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch

from django.test import TestCase

from apps.dot_ext.constants import SUPPORTED_VERSION_TEST_CASES
from apps.dot_ext.models import AuthFlowTracking
from apps.dot_ext.utils import (
check_auth_tracking_and_create_access_token_extension,
get_api_version_number_from_url,
remove_application_user_pair_tokens_data_access,
validate_latin_extended_string,
Expand All @@ -15,11 +12,6 @@


class TestDOTUtils(TestCase):
def setUp(self):
self.token = MagicMock()
self.code = 'test_code'
self.prior_include_samhsa = False

def test_get_api_version_number(self):
for test in SUPPORTED_VERSION_TEST_CASES:
result = get_api_version_number_from_url(test['url_path'])
Expand Down Expand Up @@ -48,112 +40,6 @@ def test_latin_extended_failure(self):
for text in invalid_inputs:
assert not validate_latin_extended_string(text)

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
def test_check_auth_tracking_and_create_access_token_extension_use_database_value(
self, mock_auth_flow_tracking, mock_access_token_extension
):
"""
When dot_ext_auth_flow_tracking has a record for the code and grant_type is NOT refresh_token,
the dot_ext_auth_flow_tracking.include_samhsa value should be used for include_samhsa.
"""
tracking_object = AuthFlowTracking.objects.create(
code=self.code,
include_samhsa=False,
expires=datetime.now(UTC),
)
mock_auth_flow_tracking.return_value = tracking_object

check_auth_tracking_and_create_access_token_extension(
prior_include_samhsa=False,
code=self.code,
grant_type='authorization_code',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=False,
)

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
def test_check_auth_tracking_and_create_access_token_extension_use_database_value_true(
self, mock_auth_flow_tracking, mock_access_token_extension
):
"""
When dot_ext_auth_flow_tracking has a record for the code and grant_type is NOT refresh_token,
the dot_ext_auth_flow_tracking.include_samhsa value should be used for include_samhsa.
"""
tracking_object = AuthFlowTracking.objects.create(
code=self.code,
include_samhsa=True,
expires=datetime.now(UTC),
)
mock_auth_flow_tracking.return_value = tracking_object

check_auth_tracking_and_create_access_token_extension(
prior_include_samhsa=True,
code=self.code,
grant_type='authorization_code',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=True,
)

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
def test_check_auth_tracking_and_create_access_token_extension_no_database_value(
self, mock_auth_flow_tracking, mock_access_token_extension
):
"""
When there is no dot_ext_auth_flow_tracking record, just use the default of True
"""
mock_auth_flow_tracking.side_effect = AuthFlowTracking.DoesNotExist

check_auth_tracking_and_create_access_token_extension(
prior_include_samhsa=True,
code=self.code,
grant_type='authorization_code',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=True,
)

@patch('apps.dot_ext.utils.AccessTokenExtension')
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
def test_check_auth_tracking_and_create_access_token_extension_refresh_token_grant(
self, mock_auth_flow_tracking, mock_access_token_extension
):
"""
When grant_type is 'refresh_token', prior_include_samhsa=False
should override any dot_ext_auth_flow_tracking record include_samhsa value.
"""
tracking_object = AuthFlowTracking.objects.create(
code=self.code,
include_samhsa=True,
expires=datetime.now(UTC),
)
mock_auth_flow_tracking.return_value = tracking_object

check_auth_tracking_and_create_access_token_extension(
prior_include_samhsa=False,
code=self.code,
grant_type='refresh_token',
token=self.token,
)

mock_access_token_extension.objects.get_or_create.assert_called_once_with(
access_token=self.token,
include_samhsa=False,
)

@patch('apps.dot_ext.utils.AccessToken')
@patch('apps.dot_ext.utils.DataAccessGrant')
@patch('apps.dot_ext.utils.RefreshToken')
Expand Down
Loading
Loading