Skip to content

Commit 1704dd8

Browse files
committed
add unit tests to mtls_helper
1 parent 41282b8 commit 1704dd8

4 files changed

Lines changed: 289 additions & 5 deletions

File tree

packages/google-auth/google/auth/transport/_mtls_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,15 @@ def _memfd_cert_key_paths(
219219
if cert_bytes is not None:
220220
# MFD_CLOEXEC prevents FD leaks to spawned subprocesses.
221221
fd_cert = os.memfd_create("mtls_cert", os.MFD_CLOEXEC)
222+
cleanup_fds.append(fd_cert)
222223
os.write(fd_cert, cert_bytes)
223224
cert_path = f"/proc/self/fd/{fd_cert}"
224-
cleanup_fds.append(fd_cert)
225225

226226
if key_bytes is not None:
227227
fd_key = os.memfd_create("mtls_key", os.MFD_CLOEXEC)
228+
cleanup_fds.append(fd_key)
228229
os.write(fd_key, key_bytes)
229230
key_path = f"/proc/self/fd/{fd_key}"
230-
cleanup_fds.append(fd_key)
231231

232232
yield cert_path, key_path
233233
finally:

packages/google-auth/google/auth/transport/requests.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,14 @@ def __init__(self, cert, key):
224224
):
225225
try:
226226
ctx_poolmanager.load_cert_chain(
227-
certfile=cert_path, keyfile=key_path, password=passphrase
227+
certfile=cert_path,
228+
keyfile=key_path,
229+
password=passphrase or "",
228230
)
229231
ctx_proxymanager.load_cert_chain(
230-
certfile=cert_path, keyfile=key_path, password=passphrase
232+
certfile=cert_path,
233+
keyfile=key_path,
234+
password=passphrase or "",
231235
)
232236
except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc:
233237
raise exceptions.MutualTLSChannelError(

packages/google-auth/google/auth/transport/urllib3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ def _make_mutual_tls_http(cert, key):
189189
):
190190
try:
191191
ctx.load_cert_chain(
192-
certfile=cert_path, keyfile=key_path, password=passphrase
192+
certfile=cert_path,
193+
keyfile=key_path,
194+
password=passphrase or "",
193195
)
194196
except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc:
195197
raise exceptions.MutualTLSChannelError(

packages/google-auth/tests/transport/test__mtls_helper.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import os
1616
import re
17+
import sys
18+
import tempfile
1719
from unittest import mock
1820

1921
from cryptography.hazmat.primitives import hashes, serialization
@@ -992,3 +994,279 @@ def test_call_client_cert_callback(self, mock_get_client_ssl_credentials):
992994
mock_get_client_ssl_credentials.assert_called_once_with(
993995
generate_encrypted_key=True
994996
)
997+
998+
999+
class TestSecureCertKeyPaths(object):
1000+
def test_tier1_pass_through(self):
1001+
with _mtls_helper.secure_cert_key_paths(
1002+
"/path/to/cert", "/path/to/key", b"passphrase"
1003+
) as (cert_path, key_path, passphrase):
1004+
assert cert_path == "/path/to/cert"
1005+
assert key_path == "/path/to/key"
1006+
assert passphrase == b"passphrase"
1007+
1008+
@mock.patch.object(sys, "platform", "linux")
1009+
@mock.patch.object(os, "memfd_create", create=True)
1010+
@mock.patch.object(_mtls_helper, "_memfd_cert_key_paths", autospec=True)
1011+
def test_tier2_memfd_success(self, mock_memfd_cm, mock_memfd_create):
1012+
mock_memfd_ctx = mock.MagicMock()
1013+
mock_memfd_ctx.__enter__.return_value = (
1014+
"/proc/self/fd/3",
1015+
"/proc/self/fd/4",
1016+
)
1017+
mock_memfd_cm.return_value = mock_memfd_ctx
1018+
1019+
with mock.patch.object(os.path, "exists", return_value=True):
1020+
with _mtls_helper.secure_cert_key_paths(
1021+
pytest.public_cert_bytes,
1022+
pytest.private_key_bytes,
1023+
b"passphrase",
1024+
) as (cert_path, key_path, passphrase):
1025+
assert cert_path == "/proc/self/fd/3"
1026+
assert key_path == "/proc/self/fd/4"
1027+
assert passphrase == b"passphrase"
1028+
assert mock_memfd_ctx.__exit__.called
1029+
1030+
@mock.patch.object(sys, "platform", "linux")
1031+
@mock.patch.object(os, "memfd_create", create=True)
1032+
@mock.patch.object(_mtls_helper, "_memfd_cert_key_paths", autospec=True)
1033+
@mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True)
1034+
def test_tier2_restricted_filesystem(
1035+
self, mock_tempfile_cm, mock_memfd_cm, mock_memfd_create
1036+
):
1037+
mock_memfd_ctx = mock.MagicMock()
1038+
mock_memfd_ctx.__enter__.return_value = (
1039+
"/proc/self/fd/3",
1040+
"/proc/self/fd/4",
1041+
)
1042+
mock_memfd_cm.return_value = mock_memfd_ctx
1043+
1044+
mock_tempfile_ctx = mock.MagicMock()
1045+
mock_tempfile_ctx.__enter__.return_value = (
1046+
"/tmp/cert",
1047+
"/tmp/key",
1048+
b"new_pass",
1049+
)
1050+
mock_tempfile_cm.return_value = mock_tempfile_ctx
1051+
1052+
with mock.patch.object(os.path, "exists", return_value=False):
1053+
with _mtls_helper.secure_cert_key_paths(
1054+
pytest.public_cert_bytes, pytest.private_key_bytes, b"passphrase"
1055+
) as (cert_path, key_path, passphrase):
1056+
assert cert_path == "/tmp/cert"
1057+
assert key_path == "/tmp/key"
1058+
assert passphrase == b"new_pass"
1059+
mock_memfd_ctx.__exit__.assert_called_once_with(None, None, None)
1060+
1061+
@mock.patch.object(sys, "platform", "linux")
1062+
@mock.patch.object(os, "memfd_create", create=True)
1063+
@mock.patch.object(_mtls_helper, "_memfd_cert_key_paths", autospec=True)
1064+
@mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True)
1065+
def test_tier2_fallback_to_tier3_on_oserror(
1066+
self, mock_tempfile_cm, mock_memfd_cm, mock_memfd_create
1067+
):
1068+
mock_memfd_ctx = mock.MagicMock()
1069+
mock_memfd_ctx.__enter__.side_effect = OSError("memfd failed")
1070+
mock_memfd_cm.return_value = mock_memfd_ctx
1071+
1072+
mock_tempfile_ctx = mock.MagicMock()
1073+
mock_tempfile_ctx.__enter__.return_value = (
1074+
"/tmp/cert",
1075+
"/tmp/key",
1076+
b"new_pass",
1077+
)
1078+
mock_tempfile_cm.return_value = mock_tempfile_ctx
1079+
1080+
with _mtls_helper.secure_cert_key_paths(
1081+
pytest.public_cert_bytes, pytest.private_key_bytes, b"passphrase"
1082+
) as (cert_path, key_path, passphrase):
1083+
assert cert_path == "/tmp/cert"
1084+
assert key_path == "/tmp/key"
1085+
assert passphrase == b"new_pass"
1086+
1087+
@mock.patch.object(sys, "platform", "darwin")
1088+
@mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True)
1089+
def test_tier3_tempfile_success_non_linux(self, mock_tempfile_cm):
1090+
mock_tempfile_ctx = mock.MagicMock()
1091+
mock_tempfile_ctx.__enter__.return_value = (
1092+
"/tmp/cert",
1093+
"/tmp/key",
1094+
b"new_pass",
1095+
)
1096+
mock_tempfile_cm.return_value = mock_tempfile_ctx
1097+
1098+
with _mtls_helper.secure_cert_key_paths(
1099+
pytest.public_cert_bytes, pytest.private_key_bytes, b"passphrase"
1100+
) as (cert_path, key_path, passphrase):
1101+
assert cert_path == "/tmp/cert"
1102+
assert key_path == "/tmp/key"
1103+
assert passphrase == b"new_pass"
1104+
1105+
@mock.patch.object(sys, "platform", "darwin")
1106+
@mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True)
1107+
def test_hybrid_inputs(self, mock_tempfile_cm):
1108+
mock_tempfile_ctx = mock.MagicMock()
1109+
mock_tempfile_ctx.__enter__.return_value = (
1110+
None,
1111+
"/tmp/key",
1112+
b"new_pass",
1113+
)
1114+
mock_tempfile_cm.return_value = mock_tempfile_ctx
1115+
1116+
with _mtls_helper.secure_cert_key_paths(
1117+
"/pass/through/cert.pem", pytest.private_key_bytes, b"passphrase"
1118+
) as (cert_path, key_path, passphrase):
1119+
assert cert_path == "/pass/through/cert.pem"
1120+
assert key_path == "/tmp/key"
1121+
assert passphrase == b"new_pass"
1122+
1123+
1124+
class TestMemfdCertKeyPaths(object):
1125+
@mock.patch.object(os, "memfd_create", create=True)
1126+
@mock.patch.object(os, "write")
1127+
@mock.patch.object(os, "close")
1128+
def test_success_both_bytes(self, mock_close, mock_write, mock_memfd_create):
1129+
mock_memfd_create.side_effect = [10, 11]
1130+
with _mtls_helper._memfd_cert_key_paths(b"cert", b"key") as (
1131+
cert_path,
1132+
key_path,
1133+
):
1134+
assert cert_path == "/proc/self/fd/10"
1135+
assert key_path == "/proc/self/fd/11"
1136+
mock_write.assert_has_calls([mock.call(10, b"cert"), mock.call(11, b"key")])
1137+
assert mock_close.call_count == 2
1138+
1139+
@mock.patch.object(os, "memfd_create", create=True)
1140+
@mock.patch.object(os, "write")
1141+
@mock.patch.object(os, "close")
1142+
def test_close_ignores_oserror(self, mock_close, mock_write, mock_memfd_create):
1143+
mock_memfd_create.return_value = 12
1144+
mock_close.side_effect = OSError("close error")
1145+
with _mtls_helper._memfd_cert_key_paths(b"cert", None) as (cert_path, key_path):
1146+
assert cert_path == "/proc/self/fd/12"
1147+
assert key_path is None
1148+
mock_close.assert_called_once_with(12)
1149+
1150+
@mock.patch.object(os, "memfd_create", create=True)
1151+
@mock.patch.object(os, "write")
1152+
@mock.patch.object(os, "close")
1153+
def test_write_oserror_prevents_fd_leak(
1154+
self, mock_close, mock_write, mock_memfd_create
1155+
):
1156+
mock_memfd_create.return_value = 15
1157+
mock_write.side_effect = OSError("write fault")
1158+
with pytest.raises(OSError):
1159+
with _mtls_helper._memfd_cert_key_paths(b"cert", None):
1160+
pass
1161+
mock_close.assert_called_once_with(15)
1162+
1163+
1164+
class TestTempfileCertKeyPaths(object):
1165+
@mock.patch.object(os.path, "isdir", return_value=True)
1166+
@mock.patch.object(tempfile, "mkstemp")
1167+
@mock.patch.object(os, "fdopen")
1168+
@mock.patch.object(_mtls_helper, "_encrypt_key_if_plaintext", autospec=True)
1169+
@mock.patch.object(_mtls_helper, "_secure_wipe_and_remove", autospec=True)
1170+
def test_success_shm(
1171+
self,
1172+
mock_wipe,
1173+
mock_encrypt,
1174+
mock_fdopen,
1175+
mock_mkstemp,
1176+
mock_isdir,
1177+
):
1178+
mock_mkstemp.side_effect = [(1, "/shm/cert"), (2, "/shm/key")]
1179+
mock_encrypt.return_value = (b"encrypted_key", b"new_pass")
1180+
mock_file = mock.MagicMock()
1181+
mock_file.fileno.return_value = 1
1182+
mock_fdopen.return_value.__enter__.return_value = mock_file
1183+
1184+
with mock.patch.object(os, "remove") as mock_remove, mock.patch.object(
1185+
os.path, "exists", return_value=True
1186+
):
1187+
with _mtls_helper._tempfile_cert_key_paths(b"cert", b"key", b"pass") as (
1188+
cert_path,
1189+
key_path,
1190+
passphrase,
1191+
):
1192+
assert cert_path == "/shm/cert"
1193+
assert key_path == "/shm/key"
1194+
assert passphrase == b"new_pass"
1195+
mock_remove.assert_called_once_with("/shm/cert")
1196+
1197+
mock_mkstemp.assert_has_calls(
1198+
[mock.call(dir="/dev/shm"), mock.call(dir="/dev/shm")]
1199+
)
1200+
mock_wipe.assert_called_once_with("/shm/key")
1201+
1202+
@mock.patch.object(os.path, "isdir", return_value=True)
1203+
@mock.patch.object(tempfile, "mkstemp")
1204+
@mock.patch.object(os, "fdopen")
1205+
@mock.patch.object(_mtls_helper, "_encrypt_key_if_plaintext", autospec=True)
1206+
@mock.patch.object(_mtls_helper, "_secure_wipe_and_remove", autospec=True)
1207+
def test_permission_error_loop_resilience(
1208+
self,
1209+
mock_wipe,
1210+
mock_encrypt,
1211+
mock_fdopen,
1212+
mock_mkstemp,
1213+
mock_isdir,
1214+
):
1215+
mock_mkstemp.side_effect = [(1, "/shm/cert"), (2, "/shm/key")]
1216+
mock_encrypt.return_value = (b"encrypted_key", b"new_pass")
1217+
mock_file = mock.MagicMock()
1218+
mock_file.fileno.return_value = 1
1219+
mock_fdopen.return_value.__enter__.return_value = mock_file
1220+
1221+
mock_wipe.side_effect = PermissionError("lock error")
1222+
1223+
with mock.patch.object(os, "remove") as mock_remove, mock.patch.object(
1224+
os.path, "exists", return_value=True
1225+
):
1226+
with _mtls_helper._tempfile_cert_key_paths(b"cert", b"key", b"pass"):
1227+
pass
1228+
mock_remove.assert_called_once_with("/shm/cert")
1229+
1230+
1231+
class TestEncryptKeyIfPlaintext(object):
1232+
def test_encrypts_plaintext_key(self):
1233+
encrypted_bytes, passphrase = _mtls_helper._encrypt_key_if_plaintext(
1234+
pytest.private_key_bytes, b"my_passphrase"
1235+
)
1236+
assert passphrase == b"my_passphrase"
1237+
assert encrypted_bytes != pytest.private_key_bytes
1238+
assert b"ENCRYPTED PRIVATE KEY" in encrypted_bytes
1239+
1240+
decrypted = serialization.load_pem_private_key(
1241+
encrypted_bytes, password=b"my_passphrase"
1242+
)
1243+
assert decrypted
1244+
1245+
@mock.patch("secrets.token_hex", return_value="0123456789abcdef0123456789abcdef")
1246+
def test_default_passphrase_generation(self, mock_secrets):
1247+
encrypted_bytes, passphrase = _mtls_helper._encrypt_key_if_plaintext(
1248+
pytest.private_key_bytes, None
1249+
)
1250+
assert passphrase == b"0123456789abcdef0123456789abcdef"
1251+
assert b"ENCRYPTED PRIVATE KEY" in encrypted_bytes
1252+
1253+
1254+
class TestSecureWipeAndRemove(object):
1255+
@mock.patch.object(os.path, "exists", return_value=True)
1256+
@mock.patch.object(os.path, "getsize", return_value=10)
1257+
@mock.patch("builtins.open", autospec=True)
1258+
@mock.patch.object(os, "fsync")
1259+
@mock.patch.object(os, "remove")
1260+
def test_success(
1261+
self, mock_remove, mock_fsync, mock_open, mock_getsize, mock_exists
1262+
):
1263+
mock_fh = mock.MagicMock()
1264+
mock_fh.fileno.return_value = 1
1265+
mock_open.return_value.__enter__.return_value = mock_fh
1266+
1267+
_mtls_helper._secure_wipe_and_remove("/path/to/secret")
1268+
1269+
mock_open.assert_called_once_with("/path/to/secret", "r+b")
1270+
mock_fh.write.assert_called_once_with(b"\0" * 10)
1271+
mock_fsync.assert_called_once()
1272+
mock_remove.assert_called_once_with("/path/to/secret")

0 commit comments

Comments
 (0)