|
34 | 34 |
|
35 | 35 | class TestSessionsMtls: |
36 | 36 | @pytest.mark.asyncio |
37 | | - @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) |
38 | | - @mock.patch("os.path.exists") |
39 | | - @mock.patch( |
40 | | - "builtins.open", |
41 | | - new_callable=mock.mock_open, |
42 | | - read_data=json.dumps(VALID_WORKLOAD_CONFIG), |
43 | | - ) |
44 | | - @mock.patch("google.auth.aio.transport.mtls.get_client_cert_and_key") |
45 | | - @mock.patch("ssl.create_default_context") |
46 | | - async def test_configure_mtls_channel( |
47 | | - self, mock_ssl, mock_helper, mock_file, mock_exists |
48 | | - ): |
| 37 | + async def test_configure_mtls_channel(self): |
49 | 38 | """ |
50 | 39 | Tests that the mTLS channel configures correctly when a |
51 | 40 | valid workload config is mocked. |
52 | 41 | """ |
53 | | - mock_exists.return_value = True |
54 | | - mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data") |
55 | | - |
56 | | - mock_context = mock.Mock(spec=ssl.SSLContext) |
57 | | - mock_ssl.return_value = mock_context |
| 42 | + with mock.patch.dict( |
| 43 | + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"} |
| 44 | + ), mock.patch("os.path.exists") as mock_exists, mock.patch( |
| 45 | + "builtins.open", mock.mock_open(read_data=json.dumps(VALID_WORKLOAD_CONFIG)) |
| 46 | + ), mock.patch( |
| 47 | + "google.auth.aio.transport.mtls.get_client_cert_and_key" |
| 48 | + ) as mock_helper, mock.patch( |
| 49 | + "ssl.create_default_context" |
| 50 | + ) as mock_ssl: |
| 51 | + mock_exists.return_value = True |
| 52 | + mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data") |
| 53 | + |
| 54 | + mock_context = mock.Mock(spec=ssl.SSLContext) |
| 55 | + mock_ssl.return_value = mock_context |
| 56 | + |
| 57 | + # Use AsyncMock for credentials to avoid "coroutine never awaited" warnings |
| 58 | + mock_creds = mock.AsyncMock(spec=credentials.Credentials) |
| 59 | + session = sessions.AsyncAuthorizedSession(mock_creds) |
58 | 60 |
|
59 | | - mock_creds = mock.Mock(spec=credentials.Credentials) |
60 | | - session = sessions.AsyncAuthorizedSession(mock_creds) |
61 | | - await session.configure_mtls_channel() |
| 61 | + await session.configure_mtls_channel() |
62 | 62 |
|
63 | | - assert session._is_mtls is True |
64 | | - assert mock_context.load_cert_chain.called |
| 63 | + assert session._is_mtls is True |
| 64 | + assert mock_context.load_cert_chain.called |
65 | 65 |
|
66 | 66 | @pytest.mark.asyncio |
67 | | - @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) |
68 | | - @mock.patch("os.path.exists") |
69 | | - async def test_configure_mtls_channel_disabled(self, mock_exists): |
| 67 | + async def test_configure_mtls_channel_disabled(self): |
70 | 68 | """ |
71 | 69 | Tests behavior when the config file does not exist. |
72 | 70 | """ |
73 | | - mock_exists.return_value = False |
74 | | - mock_creds = mock.Mock(spec=credentials.Credentials) |
| 71 | + with mock.patch.dict( |
| 72 | + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"} |
| 73 | + ), mock.patch("os.path.exists") as mock_exists: |
| 74 | + mock_exists.return_value = False |
| 75 | + mock_creds = mock.AsyncMock(spec=credentials.Credentials) |
| 76 | + session = sessions.AsyncAuthorizedSession(mock_creds) |
75 | 77 |
|
76 | | - session = sessions.AsyncAuthorizedSession(mock_creds) |
77 | | - await session.configure_mtls_channel() |
| 78 | + await session.configure_mtls_channel() |
78 | 79 |
|
79 | | - # If the file doesn't exist, it shouldn't error; it just won't use mTLS |
80 | | - assert session._is_mtls is False |
| 80 | + # If the file doesn't exist, it shouldn't error; it just won't use mTLS |
| 81 | + assert session._is_mtls is False |
81 | 82 |
|
82 | 83 | @pytest.mark.asyncio |
83 | | - @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) |
84 | | - @mock.patch("os.path.exists") |
85 | | - @mock.patch( |
86 | | - "builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}' |
87 | | - ) |
88 | | - async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exists): |
| 84 | + async def test_configure_mtls_channel_invalid_format(self): |
89 | 85 | """ |
90 | 86 | Verifies that the MutualTLSChannelError is raised for bad formats. |
91 | 87 | """ |
92 | | - mock_exists.return_value = True |
93 | | - mock_creds = mock.Mock(spec=credentials.Credentials) |
| 88 | + with mock.patch.dict( |
| 89 | + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"} |
| 90 | + ), mock.patch("os.path.exists") as mock_exists, mock.patch( |
| 91 | + "builtins.open", mock.mock_open(read_data='{"invalid": "format"}') |
| 92 | + ): |
| 93 | + mock_exists.return_value = True |
| 94 | + mock_creds = mock.AsyncMock(spec=credentials.Credentials) |
| 95 | + session = sessions.AsyncAuthorizedSession(mock_creds) |
94 | 96 |
|
95 | | - session = sessions.AsyncAuthorizedSession(mock_creds) |
96 | | - with pytest.raises(exceptions.MutualTLSChannelError): |
97 | | - await session.configure_mtls_channel() |
| 97 | + with pytest.raises(exceptions.MutualTLSChannelError): |
| 98 | + await session.configure_mtls_channel() |
98 | 99 |
|
99 | 100 | @pytest.mark.asyncio |
100 | | - @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) |
101 | | - @mock.patch( |
102 | | - "google.auth.transport.mtls.has_default_client_cert_source", |
103 | | - return_value=True, |
104 | | - ) |
105 | | - async def test_configure_mtls_channel_mock_callback(self, mock_has_cert): |
| 101 | + async def test_configure_mtls_channel_mock_callback(self): |
106 | 102 | """ |
107 | 103 | Tests mTLS configuration using bytes-returning callback. |
108 | 104 | """ |
109 | 105 |
|
110 | 106 | def mock_callback(): |
111 | 107 | return (b"fake_cert_bytes", b"fake_key_bytes") |
112 | 108 |
|
113 | | - mock_creds = mock.Mock(spec=credentials.Credentials) |
114 | | - |
115 | | - with mock.patch("ssl.SSLContext.load_cert_chain"): |
| 109 | + with mock.patch.dict( |
| 110 | + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"} |
| 111 | + ), mock.patch( |
| 112 | + "google.auth.transport.mtls.has_default_client_cert_source", |
| 113 | + return_value=True, |
| 114 | + ), mock.patch( |
| 115 | + "ssl.SSLContext.load_cert_chain" |
| 116 | + ): |
| 117 | + mock_creds = mock.AsyncMock(spec=credentials.Credentials) |
116 | 118 | session = sessions.AsyncAuthorizedSession(mock_creds) |
| 119 | + |
117 | 120 | await session.configure_mtls_channel(client_cert_callback=mock_callback) |
118 | 121 |
|
119 | 122 | assert session._is_mtls is True |
0 commit comments