Skip to content

Commit f633eb1

Browse files
authored
Merge pull request #96 from NHSDigital/feat/acquire-managed-identity-token
Add ability to present Managed Identity derived bearer token
2 parents 219b95d + e1255a9 commit f633eb1

4 files changed

Lines changed: 129 additions & 12 deletions

File tree

src/environment.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
from enum import Enum
3+
4+
5+
class Envs(Enum):
6+
DEVELOPMENT = "dev"
7+
REVIEW = "review"
8+
PREPROD = "preprod"
9+
PRODUCTION = "prod"
10+
11+
12+
class Environment:
13+
@property
14+
def development(self) -> bool:
15+
return self.environment == Envs.DEVELOPMENT.value
16+
17+
@property
18+
def production(self) -> bool:
19+
return self.environment == Envs.PRODUCTION.value
20+
21+
@property
22+
def review(self) -> bool:
23+
return self.environment == Envs.REVIEW.value
24+
25+
@property
26+
def preprod(self) -> bool:
27+
return self.environment == Envs.PREPROD.value
28+
29+
@property
30+
def environment(self) -> str:
31+
env = os.getenv("ENVIRONMENT")
32+
if not env or env.lower() not in (e.value for e in Envs):
33+
return Envs.DEVELOPMENT.value
34+
else:
35+
return os.getenv("ENVIRONMENT", Envs.DEVELOPMENT.value).lower()

src/services/dicom/dicom_uploader.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from typing import Optional
1111

1212
import requests
13+
from azure.identity import ManagedIdentityCredential
14+
15+
from environment import Environment
1316

1417
logger = logging.getLogger(__name__)
1518

@@ -20,11 +23,6 @@ def __init__(self, api_endpoint: str | None = None, timeout: int = 30, verify_ss
2023
self.timeout = timeout
2124
self.verify_ssl = verify_ssl
2225

23-
def headers(self) -> dict:
24-
return {
25-
"Authorization": f"Bearer {os.getenv('CLOUD_API_TOKEN', '')}",
26-
}
27-
2826
def upload_dicom(self, sop_instance_uid: str, dicom_stream: io.BufferedReader, action_id: Optional[str]) -> bool:
2927
if not action_id:
3028
logger.error(f"No action_id for {sop_instance_uid}, upload will be rejected by server")
@@ -42,7 +40,7 @@ def upload_dicom(self, sop_instance_uid: str, dicom_stream: io.BufferedReader, a
4240
files=files,
4341
timeout=self.timeout,
4442
verify=self.verify_ssl,
45-
headers=self.headers(),
43+
headers=self.headers,
4644
)
4745

4846
if response.status_code == 201:
@@ -60,3 +58,17 @@ def upload_dicom(self, sop_instance_uid: str, dicom_stream: io.BufferedReader, a
6058
except requests.exceptions.RequestException as e:
6159
logger.error(f"Upload error for {sop_instance_uid}: {e}", exc_info=True)
6260
return False
61+
62+
@property
63+
def headers(self) -> dict:
64+
return {
65+
"Authorization": f"Bearer {self.access_token}",
66+
}
67+
68+
@property
69+
def access_token(self) -> str | None:
70+
resource = os.getenv("CLOUD_API_RESOURCE", "")
71+
if resource or Environment().production:
72+
return ManagedIdentityCredential().get_token(resource).token
73+
else:
74+
return os.getenv("CLOUD_API_TOKEN", "")

tests/services/dicom/test_dicom_uploader.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from services.dicom.dicom_uploader import DICOMUploader
1010

1111

12+
@patch("services.dicom.dicom_uploader.requests.put")
1213
class TestDICOMUploader:
1314
@pytest.fixture
1415
def dicom_file(self):
@@ -17,7 +18,6 @@ def dicom_file(self):
1718
tf.close()
1819
yield tf.name
1920

20-
@patch("services.dicom.dicom_uploader.requests.put")
2121
def test_upload_success(self, mock_put, dicom_file):
2222
mock_response = Mock()
2323
mock_response.status_code = 201
@@ -39,7 +39,7 @@ def test_upload_success(self, mock_put, dicom_file):
3939
files=mock_put.call_args[1]["files"],
4040
timeout=30,
4141
verify=True,
42-
headers=uploader.headers(),
42+
headers=uploader.headers,
4343
)
4444

4545
call_kwargs = mock_put.call_args[1]
@@ -49,14 +49,13 @@ def test_upload_success(self, mock_put, dicom_file):
4949
assert isinstance(file_tuple[1], io.BufferedReader)
5050
assert file_tuple[1].read() == open(dicom_file, "rb").read()
5151

52-
def test_upload_without_action_id(self, dicom_file):
52+
def test_upload_without_action_id(self, _, dicom_file):
5353
"""Upload without action_id does not make request."""
5454
uploader = DICOMUploader()
5555
result = uploader.upload_dicom(sop_instance_uid="1.2.3", dicom_stream=open(dicom_file, "rb"), action_id=None)
5656

5757
assert result is False
5858

59-
@patch("services.dicom.dicom_uploader.requests.put")
6059
def test_upload_failure_status_code(self, mock_put, dicom_file):
6160
mock_response = Mock()
6261
mock_response.status_code = 500
@@ -68,7 +67,6 @@ def test_upload_failure_status_code(self, mock_put, dicom_file):
6867

6968
assert result is False
7069

71-
@patch("services.dicom.dicom_uploader.requests.put")
7270
def test_upload_timeout(self, mock_put, dicom_file):
7371
mock_put.side_effect = requests.exceptions.Timeout()
7472

@@ -77,11 +75,43 @@ def test_upload_timeout(self, mock_put, dicom_file):
7775

7876
assert result is False
7977

80-
@patch("services.dicom.dicom_uploader.requests.put")
8178
def test_upload_network_error(self, mock_put, dicom_file):
8279
mock_put.side_effect = requests.exceptions.ConnectionError()
8380

8481
uploader = DICOMUploader()
8582
result = uploader.upload_dicom("1.2.3", open(dicom_file, "rb"), None)
8683

8784
assert result is False
85+
86+
def test_upload_headers_with_managed_identity_access_token(self, _, monkeypatch):
87+
"""Test that headers include access token from ManagedIdentityCredential."""
88+
monkeypatch.setenv("CLOUD_API_RESOURCE", "https://example.com/.default")
89+
with patch("services.dicom.dicom_uploader.ManagedIdentityCredential") as mock_credential:
90+
mock_credential_instance = Mock()
91+
mock_credential_instance.get_token.return_value.token = "fake_access_token"
92+
mock_credential.return_value = mock_credential_instance
93+
94+
assert DICOMUploader().headers == {"Authorization": "Bearer fake_access_token"}
95+
96+
def test_upload_headers_without_managed_identity_resource(self, _, monkeypatch):
97+
"""Test that headers include CLOUD_API_TOKEN if CLOUD_API_RESOURCE is not set."""
98+
monkeypatch.setenv("CLOUD_API_TOKEN", "env_access_token")
99+
100+
assert DICOMUploader().headers == {"Authorization": "Bearer env_access_token"}
101+
102+
def test_upload_headers_in_production_with_no_cloud_api_resource(self, _, monkeypatch):
103+
"""Test that headers include access token from ManagedIdentityCredential in production even if CLOUD_API_RESOURCE is not set."""
104+
monkeypatch.setenv("ENVIRONMENT", "prod")
105+
with patch("services.dicom.dicom_uploader.ManagedIdentityCredential") as mock_credential:
106+
mock_credential_instance = Mock()
107+
mock_credential_instance.get_token.return_value.token = "prod_access_token"
108+
mock_credential.return_value = mock_credential_instance
109+
110+
assert DICOMUploader().headers == {"Authorization": "Bearer prod_access_token"}
111+
assert mock_credential_instance.get_token.call_args[0][0] == ""
112+
113+
def test_upload_headers_without_any_token(self, _, monkeypatch):
114+
"""Test that headers include empty token if neither CLOUD_API_RESOURCE nor CLOUD_API_TOKEN is set."""
115+
monkeypatch.delenv("CLOUD_API_RESOURCE", raising=False)
116+
monkeypatch.delenv("CLOUD_API_TOKEN", raising=False)
117+
assert DICOMUploader().headers == {"Authorization": "Bearer "}

tests/test_environment.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from environment import Environment
2+
3+
4+
class TestEnvironment:
5+
def test_environment(self, monkeypatch):
6+
7+
env = Environment()
8+
9+
# Default should be development
10+
assert env.development
11+
assert not env.production
12+
assert not env.review
13+
assert not env.preprod
14+
15+
# Test production environment
16+
monkeypatch.setenv("ENVIRONMENT", "prod")
17+
assert env.production
18+
assert not env.development
19+
assert not env.review
20+
assert not env.preprod
21+
22+
# Test review environment
23+
monkeypatch.setenv("ENVIRONMENT", "review")
24+
assert env.review
25+
assert not env.development
26+
assert not env.production
27+
assert not env.preprod
28+
29+
# Test preprod environment
30+
monkeypatch.setenv("ENVIRONMENT", "preprod")
31+
assert env.preprod
32+
assert not env.development
33+
assert not env.production
34+
35+
# Test unknown environment defaults to development
36+
monkeypatch.setenv("ENVIRONMENT", "unknown")
37+
assert env.development
38+
assert not env.production
39+
assert not env.review
40+
assert not env.preprod

0 commit comments

Comments
 (0)