Skip to content

Commit ea8fc46

Browse files
committed
Fix tests
1 parent 0425097 commit ea8fc46

4 files changed

Lines changed: 132 additions & 187 deletions

File tree

README.md

Lines changed: 68 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
252252
from trino.auth import OAuth2Authentication
253253

254254
engine = create_engine(
255-
"trino://<username>@<host>:<port>/<catalog>",
255+
"trino://<username>@<host>:<port>/<catalog>",
256256
connect_args={
257257
"auth": OAuth2Authentication(),
258258
"http_scheme": "https",
@@ -262,156 +262,84 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
262262

263263
#### Client Credentials authentication
264264

265-
The `ClientCredentials` class enables service-to-service authentication using standard OAuth2 Client Credentials flow.
266-
267-
- DBAPI
268-
269-
export KEYRING_CRYPTFILE_PASSWORD=password123
270-
271-
```python
272-
from trino.dbapi import connect
273-
from trino.auth import ClientCredentials, OidcConfig
274-
275-
conn = connect(
276-
user="<username>",
277-
auth=ClientCredentials(
278-
client_id="<client_id>",
279-
client_secret="<client_secret>",
280-
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
281-
),
282-
http_scheme="https",
283-
...
284-
)
285-
```
286-
287-
Or using manual URLs:
288-
289-
```python
290-
from trino.dbapi import connect
291-
from trino.auth import ClientCredentials, ManualUrlsConfig
292-
293-
conn = connect(
294-
user="<username>",
295-
auth=ClientCredentials(
296-
client_id="<client_id>",
297-
client_secret="<client_secret>",
298-
url_config=ManualUrlsConfig(token_endpoint="<token_endpoint>")
299-
),
300-
http_scheme="https",
301-
...
302-
)
303-
```
304-
305-
With optional `scope` and `audience`:
306-
307-
```python
308-
from trino.dbapi import connect
309-
from trino.auth import ClientCredentials, OidcConfig
265+
```python
266+
from trino.dbapi import connect
267+
from trino.auth import ClientCredentials
268+
from trino.oauth2.models import OidcConfig
269+
270+
auth = ClientCredentials(
271+
client_id="<client_id>",
272+
client_secret="<client_secret>",
273+
url_config=OidcConfig(
274+
token_endpoint="<token_endpoint>",
275+
# other endpoints if needed
276+
),
277+
scope="<number of scopes>", # optional
278+
audience="<audience>", # optional
279+
)
310280

311-
conn = connect(
312-
user="<username>",
313-
auth=ClientCredentials(
314-
client_id="<client_id>",
315-
client_secret="<client_secret>",
316-
scope="<scope>",
317-
audience="<audience>",
318-
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
319-
),
320-
http_scheme="https",
321-
...
322-
)
323-
```
281+
conn = connect(
282+
user="<username>",
283+
auth=auth,
284+
http_scheme="https",
285+
...
286+
)
287+
```
324288

325289
#### Device Code authentication
326290

327-
The `DeviceCode` class enables authentication on devices with limited input capabilities using OAuth2 Device Code flow.
328-
This flow prints tries to open an browser to the URL but it will also fall back and print the verification code and
329-
a URL to visit on another device to authenticate.
330-
331-
- DBAPI
332-
333-
export KEYRING_CRYPTFILE_PASSWORD=password123
334-
335-
```python
336-
from trino.dbapi import connect
337-
from trino.auth import DeviceCode, OidcConfig
338-
339-
conn = connect(
340-
user="<username>",
341-
auth=DeviceCode(
342-
client_id="<client_id>",
343-
client_secret="<client_secret>",
344-
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
345-
),
346-
http_scheme="https",
347-
...
348-
)
349-
```
350-
351-
With optional `scope` and `audience`:
352-
353-
```python
354-
from trino.dbapi import connect
355-
from trino.auth import DeviceCode, OidcConfig
291+
```python
292+
from trino.dbapi import connect
293+
from trino.auth import DeviceCode
294+
from trino.oauth2.models import OidcConfig
295+
296+
auth = DeviceCode(
297+
client_id="<client_id>",
298+
url_config=OidcConfig(
299+
token_endpoint="<token_endpoint>",
300+
device_authorization_endpoint="<device_authorization_endpoint>",
301+
),
302+
scope="<scope>", # optional
303+
audience="<audience>", # optional
304+
)
356305

357-
conn = connect(
358-
user="<username>",
359-
auth=DeviceCode(
360-
client_id="<client_id>",
361-
client_secret="<client_secret>",
362-
scope="<scope>",
363-
audience="<audience>",
364-
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
365-
),
366-
http_scheme="https",
367-
...
368-
)
369-
```
306+
conn = connect(
307+
user="<username>",
308+
auth=auth,
309+
http_scheme="https",
310+
...
311+
)
312+
```
370313

371314
#### Authorization Code authentication
372315

373-
The `AuthorizationCode` class enables the standard OAuth2 Authorization Code flow.
374-
375-
- DBAPI
376-
377-
export KEYRING_CRYPTFILE_PASSWORD=password123
378-
379-
```python
380-
from trino.dbapi import connect
381-
from trino.auth import AuthorizationCode, OidcConfig
382-
383-
conn = connect(
384-
user="<username>",
385-
auth=AuthorizationCode(
386-
client_id="<client_id>",
387-
client_secret="<client_secret>",
388-
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
389-
),
390-
http_scheme="https",
391-
...
392-
)
393-
```
394-
395-
With optional `scope` and `audience`:
316+
```python
317+
from trino.dbapi import connect
318+
from trino.auth import AuthorizationCode
319+
from trino.oauth2.models import OidcConfig
320+
321+
auth = AuthorizationCode(
322+
client_id="<client_id>",
323+
client_secret="<client_secret>", # optional
324+
url_config=OidcConfig(
325+
token_endpoint="<token_endpoint>",
326+
authorization_endpoint="<authorization_endpoint>",
327+
),
328+
scope="<scope>", # optional
329+
audience="<audience>", # optional
330+
)
396331

397-
```python
398-
from trino.dbapi import connect
399-
from trino.auth import AuthorizationCode, OidcConfig
332+
conn = connect(
333+
user="<username>",
334+
auth=auth,
335+
http_scheme="https",
336+
...
337+
)
338+
```
400339

401-
conn = connect(
402-
user="<username>",
403-
auth=AuthorizationCode(
404-
client_id="<client_id>",
405-
client_secret="<client_secret>",
406-
scope="<scope>",
407-
audience="<audience>",
408-
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
409-
),
410-
http_scheme="https",
411-
...
412-
)
413-
```
340+
### Reference
414341

342+
For further details, please consult [Trino documentation](https://trino.io/docs/current).
415343

416344
### Secure Token Storage
417345

tests/unit/oauth_test_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,54 @@ def get_token_callback(self, request, uri, response_headers):
150150
if challenge.attempts == 0:
151151
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
152152
return [200, response_headers, f'{{"nextUri": "{uri}"}}']
153+
154+
155+
import keyring.backend
156+
157+
158+
class MockKeyring(keyring.backend.KeyringBackend):
159+
priority = 1
160+
161+
def __init__(self):
162+
self.file_location = self._generate_test_root_dir()
163+
164+
@staticmethod
165+
def _generate_test_root_dir():
166+
import tempfile
167+
168+
return tempfile.mkdtemp(prefix="trino-python-client-unit-test-")
169+
170+
def _get_file_path(self, servicename, username):
171+
from os.path import join
172+
173+
file_location = self.file_location
174+
file_name = f"{servicename}_{username}.txt"
175+
return join(file_location, file_name)
176+
177+
def set_password(self, servicename, username, password):
178+
file_path = self._get_file_path(servicename, username)
179+
180+
with open(file_path, "w") as file:
181+
file.write(password)
182+
183+
def get_password(self, servicename, username):
184+
import os
185+
186+
file_path = self._get_file_path(servicename, username)
187+
if not os.path.exists(file_path):
188+
return None
189+
190+
with open(file_path, "r") as file:
191+
password = file.read()
192+
193+
return password
194+
195+
def delete_password(self, servicename, username):
196+
import os
197+
198+
file_path = self._get_file_path(servicename, username)
199+
if not os.path.exists(file_path):
200+
return None
201+
202+
os.remove(file_path)
203+

tests/unit/test_client.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from tests.unit.oauth_test_utils import RedirectHandlerWithException
4848
from tests.unit.oauth_test_utils import SERVER_ADDRESS
4949
from tests.unit.oauth_test_utils import TOKEN_RESOURCE
50+
from tests.unit.oauth_test_utils import MockKeyring
5051
from trino import __version__
5152
from trino import constants
5253
from trino.auth import _OAuth2KeyRingTokenCache
@@ -1406,47 +1407,3 @@ def test_store_long_password(self):
14061407
retrieved_password = cache.get_token_from_cache(host)
14071408
self.assertEqual(long_password, retrieved_password)
14081409

1409-
1410-
class MockKeyring(keyring.backend.KeyringBackend):
1411-
def __init__(self):
1412-
self.file_location = self._generate_test_root_dir()
1413-
1414-
@staticmethod
1415-
def _generate_test_root_dir():
1416-
import tempfile
1417-
1418-
return tempfile.mkdtemp(prefix="trino-python-client-unit-test-")
1419-
1420-
def file_path(self, servicename, username):
1421-
from os.path import join
1422-
1423-
file_location = self.file_location
1424-
file_name = f"{servicename}_{username}.txt"
1425-
return join(file_location, file_name)
1426-
1427-
def set_password(self, servicename, username, password):
1428-
file_path = self.file_path(servicename, username)
1429-
1430-
with open(file_path, "w") as file:
1431-
file.write(password)
1432-
1433-
def get_password(self, servicename, username):
1434-
import os
1435-
1436-
file_path = self.file_path(servicename, username)
1437-
if not os.path.exists(file_path):
1438-
return None
1439-
1440-
with open(file_path, "r") as file:
1441-
password = file.read()
1442-
1443-
return password
1444-
1445-
def delete_password(self, servicename, username):
1446-
import os
1447-
1448-
file_path = self.file_path(servicename, username)
1449-
if not os.path.exists(file_path):
1450-
return None
1451-
1452-
os.remove(file_path)

tests/unit/test_dbapi.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import httpretty
1717
import pytest
18+
import keyring
1819
from httpretty import httprettified
1920
from requests import Session
2021

@@ -26,6 +27,7 @@
2627
from tests.unit.oauth_test_utils import RedirectHandler
2728
from tests.unit.oauth_test_utils import SERVER_ADDRESS
2829
from tests.unit.oauth_test_utils import TOKEN_RESOURCE
30+
from tests.unit.oauth_test_utils import MockKeyring
2931
from trino import constants
3032
from trino.auth import OAuth2Authentication
3133
from trino.dbapi import connect
@@ -58,8 +60,15 @@ def test_http_session_is_defaulted_when_not_specified(mock_client):
5860
assert mock_client.TrinoRequest.http.Session.return_value in request_args
5961

6062

63+
@pytest.fixture
64+
def mock_keyring():
65+
mk = MockKeyring()
66+
keyring.set_keyring(mk)
67+
return mk
68+
69+
6170
@httprettified
62-
def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data):
71+
def test_token_retrieved_once_per_auth_instance(mock_keyring, sample_post_response_data, sample_get_response_data):
6372
token = str(uuid.uuid4())
6473
challenge_id = str(uuid.uuid4())
6574

@@ -123,7 +132,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl
123132

124133

125134
@httprettified
126-
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data,
135+
def test_token_retrieved_once_when_authentication_instance_is_shared(mock_keyring, sample_post_response_data,
127136
sample_get_response_data):
128137
token = str(uuid.uuid4())
129138
challenge_id = str(uuid.uuid4())
@@ -189,7 +198,7 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
189198

190199

191200
@httprettified
192-
def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data):
201+
def test_token_retrieved_once_when_multithreaded(mock_keyring, sample_post_response_data, sample_get_response_data):
193202
token = str(uuid.uuid4())
194203
challenge_id = str(uuid.uuid4())
195204

0 commit comments

Comments
 (0)