Skip to content

Commit e1b0b7b

Browse files
BB2-4675: Add SAMHSA checkbox to v3 permissions screen (#1607)
* BB2-4675: Add SAMHSA checkbox to v3 permissions screen * Use cache.add instead of cache.set, modify default for code * Address PR feedback - add docstrings * Add a new table to track user SAMHSA preferences rather than use caching * Address PR feedback * Remove commented out line * Address copilot feedback * Convert share_samhsa_data to a BooleanField to make the code less confusing
1 parent 004ffd7 commit e1b0b7b

11 files changed

Lines changed: 366 additions & 53 deletions

File tree

apps/dot_ext/forms.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33

44
from django import forms
55
from django.conf import settings
6+
from django.contrib.auth.models import Group, User
67
from django.core.exceptions import ValidationError
8+
from django.forms.widgets import URLInput
79
from django.utils.safestring import mark_safe
810
from django.utils.translation import gettext_lazy as _
911
from oauth2_provider.forms import AllowForm as DotAllowForm
1012
from oauth2_provider.models import get_application_model
13+
1114
from apps.accounts.models import UserProfile
1215
from apps.capabilities.models import ProtectedCapability
16+
from apps.constants import HHS_SERVER_LOGNAME_FMT
1317
from apps.dot_ext.constants import BENE_PERSONAL_INFO_SCOPES, PRINTABLE_SPECIAL_ASCII
14-
from apps.dot_ext.scopes import CapabilitiesScopes
1518
from apps.dot_ext.models import Application, InternalApplicationLabels
19+
from apps.dot_ext.scopes import CapabilitiesScopes
1620
from apps.dot_ext.validators import validate_logo_image, validate_notags, validate_url
17-
from django.contrib.auth.models import Group, User
18-
from django.forms.widgets import URLInput
19-
20-
from apps.constants import HHS_SERVER_LOGNAME_FMT
2121

2222
logger = logging.getLogger(HHS_SERVER_LOGNAME_FMT.format(__name__))
2323

@@ -71,8 +71,9 @@ class CustomRegisterApplicationForm(forms.ModelForm):
7171
)
7272

7373
def __init__(self, user, *args, **kwargs):
74-
agree_label = 'Yes I have read and agree to the <a target="_blank" href="%s">API Terms of Service Agreement</a>*' % (
75-
settings.TOS_URI
74+
agree_label = (
75+
'Yes I have read and agree to the <a target="_blank" href="%s">API Terms of Service Agreement</a>*'
76+
% (settings.TOS_URI)
7677
)
7778
super(CustomRegisterApplicationForm, self).__init__(*args, **kwargs)
7879
self.fields['authorization_grant_type'].choices = settings.GRANT_TYPES
@@ -381,6 +382,7 @@ class SimpleAllowForm(DotAllowForm):
381382
code_challenge = forms.CharField(required=False, widget=forms.HiddenInput())
382383
code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput())
383384
share_demographic_scopes = forms.CharField(required=False)
385+
share_samhsa_data = forms.BooleanField(required=False)
384386

385387
def clean(self):
386388
cleaned_data = super().clean()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Generated by Django 6.0.2 on 2026-06-01 18:52
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('dot_ext', '0013_delete_expiresin'),
10+
]
11+
12+
operations = [
13+
migrations.CreateModel(
14+
name='AuthFlowTracking',
15+
fields=[
16+
('id', models.BigAutoField(primary_key=True, serialize=False)),
17+
('code', models.CharField(db_index=True, max_length=255, null=True, unique=True)),
18+
('include_samhsa', models.BooleanField(default=True)),
19+
('created', models.DateTimeField(auto_now_add=True)),
20+
('expires', models.DateTimeField()),
21+
],
22+
options={
23+
'db_table': 'dot_ext_auth_flow_tracking',
24+
},
25+
),
26+
]

apps/dot_ext/models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,17 @@ class Meta:
524524
db_table = 'oauth2_provider_accesstoken_extension'
525525

526526

527+
class AuthFlowTracking(models.Model):
528+
id = models.BigAutoField(primary_key=True)
529+
code = models.CharField(max_length=255, null=True, unique=True, db_index=True)
530+
include_samhsa = models.BooleanField(null=False, default=True)
531+
created = models.DateTimeField(auto_now_add=True)
532+
expires = models.DateTimeField()
533+
534+
class Meta:
535+
db_table = 'dot_ext_auth_flow_tracking'
536+
537+
527538
def get_application_counts():
528539
"""
529540
Get the active and inactive counts of applications.

apps/dot_ext/signals.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
22

33
from django.db.models.signals import post_save, pre_save
4-
from django.dispatch import Signal, receiver
4+
from django.dispatch import Signal
55
from oauth2_provider.models import get_access_token_model, get_application_model
66

77
from apps.constants import HHS_SERVER_LOGNAME_FMT
8-
from apps.dot_ext.models import AccessTokenExtension, ArchivedToken
8+
from apps.dot_ext.models import ArchivedToken
99
from libs.decorators import waffle_function_switch
1010
from libs.mail import Mailer
1111

@@ -86,14 +86,3 @@ def outreach_first_api_call(sender, instance=None, **kwargs):
8686

8787
post_save.connect(outreach_first_application, sender=Application)
8888
pre_save.connect(outreach_first_api_call, sender=AccessToken)
89-
90-
91-
@receiver(post_save, sender=AccessToken)
92-
def create_access_token_extension(sender, instance, created, **kwargs):
93-
# TODO: Need to update to take into account what was passed for include_samhsa
94-
# Once the checkbox is in place on v3 permissions screen
95-
if created:
96-
AccessTokenExtension.objects.create(
97-
access_token=instance,
98-
include_samhsa=True,
99-
)

apps/dot_ext/tests/test_authorization_token.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
IDME_HIGHER_ISS,
2929
IDME_LOWER_ISS,
3030
)
31-
from apps.dot_ext.models import Application
31+
from apps.dot_ext.models import AccessTokenExtension, Application
3232
from apps.dot_ext.utils import (
3333
get_application_from_data,
3434
get_application_from_meta,
@@ -265,6 +265,65 @@ def test_validate_environment_for_id_token(self) -> None:
265265
result = view_instance._validate_idme_url_for_id_token_and_environment(IDME_HIGHER_ISS)
266266
assert result
267267

268+
def test_retrieve_prior_include_samhsa_value_non_refresh_token_grant_type(self) -> None:
269+
"""The _retrieve_prior_include_samhsa_value will always return a value of True if the grant_type is
270+
authorization_code, as we only attempt to retrieve the prior include samhsa value if the grant_type is
271+
refresh_token
272+
"""
273+
view_instance = TokenView()
274+
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('authorization_code', None)
275+
assert prior_include_samhsa
276+
277+
@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
278+
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
279+
def test_retrieve_prior_include_samhsa_value(self, mock_access_token_extension, mock_refresh_model) -> None:
280+
"""Confirm that if the prior access_token_extension record has include_samhsa set to True, that True is returned"""
281+
view_instance = TokenView()
282+
mock_request = MagicMock()
283+
mock_refresh_token = MagicMock()
284+
mock_request.POST = {
285+
'grant_type': 'refresh_token',
286+
'refresh_token': 'tkn',
287+
}
288+
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
289+
mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=True)
290+
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
291+
assert prior_include_samhsa
292+
293+
@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
294+
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
295+
def test_retrieve_prior_include_samhsa_value_false(self, mock_access_token_extension, mock_refresh_model) -> None:
296+
"""Confirm that if the prior access_token_extension record has include_samhsa set to False, that False is returned"""
297+
view_instance = TokenView()
298+
mock_request = MagicMock()
299+
mock_refresh_token = MagicMock()
300+
mock_request.POST = {
301+
'grant_type': 'refresh_token',
302+
'refresh_token': 'tkn',
303+
}
304+
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
305+
mock_access_token_extension.return_value = AccessTokenExtension(include_samhsa=False)
306+
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
307+
assert not prior_include_samhsa
308+
309+
@patch('apps.dot_ext.views.authorization.get_refresh_token_model')
310+
@patch('apps.dot_ext.models.AccessTokenExtension.objects.get')
311+
def test_retrieve_prior_include_samhsa_value_access_token_extension_dne(
312+
self, mock_access_token_extension, mock_refresh_model
313+
) -> None:
314+
"""Confirm that when there is no access_token_extension record returned, a value of True is returned"""
315+
view_instance = TokenView()
316+
mock_request = MagicMock()
317+
mock_refresh_token = MagicMock()
318+
mock_request.POST = {
319+
'grant_type': 'refresh_token',
320+
'refresh_token': 'tkn',
321+
}
322+
mock_refresh_model.return_value.objects.get.return_value = mock_refresh_token
323+
mock_access_token_extension.side_effect = AccessTokenExtension.DoesNotExist
324+
prior_include_samhsa = view_instance._retrieve_prior_include_samhsa_value('refresh_token', mock_request)
325+
assert prior_include_samhsa
326+
268327

269328
# we set empty GET/META/POST because get_application_from_data does not like it if a GET is missing.
270329
class TestClientIdExtraction(BaseApiTest):

apps/dot_ext/tests/test_models.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,6 @@ def test_internal_application_labels_admin(self):
287287
self.assertTrue(l5.slug in internal_labels)
288288
self.assertTrue(l11.slug not in internal_labels)
289289

290-
def test_access_token_extension_is_created(self) -> None:
291-
"""Ensure that when an access token is saved, a corresponding AccessTokenExtension record
292-
is created
293-
"""
294-
295-
first_access_token = self.create_token(
296-
'John', 'Smith', fhir_id_v2=DEFAULT_SAMPLE_FHIR_ID_V2, fhir_id_v3=DEFAULT_SAMPLE_FHIR_ID_V3
297-
)
298-
ac = AccessToken.objects.get(token=first_access_token)
299-
ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search'
300-
ac.save()
301-
access_token_extension = AccessTokenExtension.objects.get(access_token=ac)
302-
303-
assert access_token_extension is not None
304-
assert access_token_extension.access_token == ac
305-
assert access_token_extension.include_samhsa
306-
307290
def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
308291
"""Ensure that when an access token is deleted, the corresponding AccessTokenExtension record
309292
is deleted
@@ -315,8 +298,11 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
315298
ac = AccessToken.objects.get(token=first_access_token)
316299
ac.scope = 'patient/Coverage.search patient/Patient.search patient/ExplanationOfBenefit.search'
317300
ac.save()
318-
access_token_extension = AccessTokenExtension.objects.get(access_token=ac)
319-
access_token_extension_id = access_token_extension.id
301+
302+
access_token_extension = AccessTokenExtension()
303+
access_token_extension.access_token = ac
304+
access_token_extension.include_samhsa = True
305+
access_token_extension.save()
320306

321307
assert access_token_extension is not None
322308
assert access_token_extension.access_token == ac
@@ -325,4 +311,4 @@ def test_access_token_extension_is_deleted_when_token_is_deleted(self) -> None:
325311
ac.delete()
326312

327313
with self.assertRaises(AccessTokenExtension.DoesNotExist):
328-
access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension_id)
314+
access_token_extension = AccessTokenExtension.objects.get(id=access_token_extension.id)

apps/dot_ext/tests/test_utils.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from datetime import UTC, datetime
12
from unittest.mock import MagicMock, patch
23

34
from django.test import TestCase
45

56
from apps.dot_ext.constants import SUPPORTED_VERSION_TEST_CASES
7+
from apps.dot_ext.models import AuthFlowTracking
68
from apps.dot_ext.utils import (
9+
check_auth_tracking_and_create_access_token_extension,
710
get_api_version_number_from_url,
811
remove_application_user_pair_tokens_data_access,
912
validate_latin_extended_string,
@@ -12,6 +15,11 @@
1215

1316

1417
class TestDOTUtils(TestCase):
18+
def setUp(self):
19+
self.token = MagicMock()
20+
self.code = 'test_code'
21+
self.prior_include_samhsa = False
22+
1523
def test_get_api_version_number(self):
1624
for test in SUPPORTED_VERSION_TEST_CASES:
1725
result = get_api_version_number_from_url(test['url_path'])
@@ -40,6 +48,112 @@ def test_latin_extended_failure(self):
4048
for text in invalid_inputs:
4149
assert not validate_latin_extended_string(text)
4250

51+
@patch('apps.dot_ext.utils.AccessTokenExtension')
52+
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
53+
def test_check_auth_tracking_and_create_access_token_extension_use_database_value(
54+
self, mock_auth_flow_tracking, mock_access_token_extension
55+
):
56+
"""
57+
When dot_ext_auth_flow_tracking has a record for the code and grant_type is NOT refresh_token,
58+
the dot_ext_auth_flow_tracking.include_samhsa value should be used for include_samhsa.
59+
"""
60+
tracking_object = AuthFlowTracking.objects.create(
61+
code=self.code,
62+
include_samhsa=False,
63+
expires=datetime.now(UTC),
64+
)
65+
mock_auth_flow_tracking.return_value = tracking_object
66+
67+
check_auth_tracking_and_create_access_token_extension(
68+
prior_include_samhsa=False,
69+
code=self.code,
70+
grant_type='authorization_code',
71+
token=self.token,
72+
)
73+
74+
mock_access_token_extension.objects.get_or_create.assert_called_once_with(
75+
access_token=self.token,
76+
include_samhsa=False,
77+
)
78+
79+
@patch('apps.dot_ext.utils.AccessTokenExtension')
80+
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
81+
def test_check_auth_tracking_and_create_access_token_extension_use_database_value_true(
82+
self, mock_auth_flow_tracking, mock_access_token_extension
83+
):
84+
"""
85+
When dot_ext_auth_flow_tracking has a record for the code and grant_type is NOT refresh_token,
86+
the dot_ext_auth_flow_tracking.include_samhsa value should be used for include_samhsa.
87+
"""
88+
tracking_object = AuthFlowTracking.objects.create(
89+
code=self.code,
90+
include_samhsa=True,
91+
expires=datetime.now(UTC),
92+
)
93+
mock_auth_flow_tracking.return_value = tracking_object
94+
95+
check_auth_tracking_and_create_access_token_extension(
96+
prior_include_samhsa=True,
97+
code=self.code,
98+
grant_type='authorization_code',
99+
token=self.token,
100+
)
101+
102+
mock_access_token_extension.objects.get_or_create.assert_called_once_with(
103+
access_token=self.token,
104+
include_samhsa=True,
105+
)
106+
107+
@patch('apps.dot_ext.utils.AccessTokenExtension')
108+
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
109+
def test_check_auth_tracking_and_create_access_token_extension_no_database_value(
110+
self, mock_auth_flow_tracking, mock_access_token_extension
111+
):
112+
"""
113+
When there is no dot_ext_auth_flow_tracking record, just use the default of True
114+
"""
115+
mock_auth_flow_tracking.side_effect = AuthFlowTracking.DoesNotExist
116+
117+
check_auth_tracking_and_create_access_token_extension(
118+
prior_include_samhsa=True,
119+
code=self.code,
120+
grant_type='authorization_code',
121+
token=self.token,
122+
)
123+
124+
mock_access_token_extension.objects.get_or_create.assert_called_once_with(
125+
access_token=self.token,
126+
include_samhsa=True,
127+
)
128+
129+
@patch('apps.dot_ext.utils.AccessTokenExtension')
130+
@patch('apps.dot_ext.utils.AuthFlowTracking.objects.get')
131+
def test_check_auth_tracking_and_create_access_token_extension_refresh_token_grant(
132+
self, mock_auth_flow_tracking, mock_access_token_extension
133+
):
134+
"""
135+
When grant_type is 'refresh_token', prior_include_samhsa=False
136+
should override any dot_ext_auth_flow_tracking record include_samhsa value.
137+
"""
138+
tracking_object = AuthFlowTracking.objects.create(
139+
code=self.code,
140+
include_samhsa=True,
141+
expires=datetime.now(UTC),
142+
)
143+
mock_auth_flow_tracking.return_value = tracking_object
144+
145+
check_auth_tracking_and_create_access_token_extension(
146+
prior_include_samhsa=False,
147+
code=self.code,
148+
grant_type='refresh_token',
149+
token=self.token,
150+
)
151+
152+
mock_access_token_extension.objects.get_or_create.assert_called_once_with(
153+
access_token=self.token,
154+
include_samhsa=False,
155+
)
156+
43157
@patch('apps.dot_ext.utils.AccessToken')
44158
@patch('apps.dot_ext.utils.DataAccessGrant')
45159
@patch('apps.dot_ext.utils.RefreshToken')

0 commit comments

Comments
 (0)