Skip to content

Commit a6c034f

Browse files
committed
Adding oauth2 and 2.1 support to trino-python-client
Added documentaton for scopes and audiences. moved set_session to a base class.
1 parent 2108c38 commit a6c034f

4 files changed

Lines changed: 318 additions & 2 deletions

File tree

README.md

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ the [`JWT` authentication type](https://trino.io/docs/current/security/jwt.html)
217217

218218
### OAuth2 authentication
219219

220+
Make sure that the OAuth2 support is installed using `pip install trino[oauth]`.
221+
222+
#### Interactive Browser authentication
223+
220224
The `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
221225
the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html).
222226

@@ -256,6 +260,152 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
256260
)
257261
```
258262

263+
#### Client Credentials authentication
264+
265+
The `ClientCredentials` class enables service-to-service authentication using standard OAuth2 Client Credentials flow.
266+
267+
- DBAPI
268+
269+
```python
270+
from trino.dbapi import connect
271+
from trino.auth import ClientCredentials, OidcConfig
272+
273+
conn = connect(
274+
user="<username>",
275+
auth=ClientCredentials(
276+
client_id="<client_id>",
277+
client_secret="<client_secret>",
278+
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
279+
),
280+
http_scheme="https",
281+
...
282+
)
283+
```
284+
285+
Or using manual URLs:
286+
287+
```python
288+
from trino.dbapi import connect
289+
from trino.auth import ClientCredentials, ManualUrlsConfig
290+
291+
conn = connect(
292+
user="<username>",
293+
auth=ClientCredentials(
294+
client_id="<client_id>",
295+
client_secret="<client_secret>",
296+
url_config=ManualUrlsConfig(token_endpoint="<token_endpoint>")
297+
),
298+
http_scheme="https",
299+
...
300+
)
301+
```
302+
303+
With optional `scope` and `audience`:
304+
305+
```python
306+
from trino.dbapi import connect
307+
from trino.auth import ClientCredentials, OidcConfig
308+
309+
conn = connect(
310+
user="<username>",
311+
auth=ClientCredentials(
312+
client_id="<client_id>",
313+
client_secret="<client_secret>",
314+
scope="<scope>",
315+
audience="<audience>",
316+
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
317+
),
318+
http_scheme="https",
319+
...
320+
)
321+
```
322+
323+
#### Device Code authentication
324+
325+
The `DeviceCode` class enables authentication on devices with limited input capabilities using OAuth2 Device Code flow.
326+
This flow prints tries to open an browser to the URL but it will also fall back and print the verification code and
327+
a URL to visit on another device to authenticate.
328+
329+
- DBAPI
330+
331+
```python
332+
from trino.dbapi import connect
333+
from trino.auth import DeviceCode, OidcConfig
334+
335+
conn = connect(
336+
user="<username>",
337+
auth=DeviceCode(
338+
client_id="<client_id>",
339+
client_secret="<client_secret>",
340+
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
341+
),
342+
http_scheme="https",
343+
...
344+
)
345+
```
346+
347+
With optional `scope` and `audience`:
348+
349+
```python
350+
from trino.dbapi import connect
351+
from trino.auth import DeviceCode, OidcConfig
352+
353+
conn = connect(
354+
user="<username>",
355+
auth=DeviceCode(
356+
client_id="<client_id>",
357+
client_secret="<client_secret>",
358+
scope="<scope>",
359+
audience="<audience>",
360+
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
361+
),
362+
http_scheme="https",
363+
...
364+
)
365+
```
366+
367+
#### Authorization Code authentication
368+
369+
The `AuthorizationCode` class enables the standard OAuth2 Authorization Code flow.
370+
371+
- DBAPI
372+
373+
```python
374+
from trino.dbapi import connect
375+
from trino.auth import AuthorizationCode, OidcConfig
376+
377+
conn = connect(
378+
user="<username>",
379+
auth=AuthorizationCode(
380+
client_id="<client_id>",
381+
client_secret="<client_secret>",
382+
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
383+
),
384+
http_scheme="https",
385+
...
386+
)
387+
```
388+
389+
With optional `scope` and `audience`:
390+
391+
```python
392+
from trino.dbapi import connect
393+
from trino.auth import AuthorizationCode, OidcConfig
394+
395+
conn = connect(
396+
user="<username>",
397+
auth=AuthorizationCode(
398+
client_id="<client_id>",
399+
client_secret="<client_secret>",
400+
scope="<scope>",
401+
audience="<audience>",
402+
url_config=OidcConfig(oidc_discovery_url="<oidc_discovery_url>")
403+
),
404+
http_scheme="https",
405+
...
406+
)
407+
```
408+
259409
### Certificate authentication
260410

261411
`CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). `CertificateAuthentication` requires paths to a valid client certificate and private key.

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
"krb5 == 0.5.1"]
3434
sqlalchemy_require = ["sqlalchemy >= 1.3"]
3535
external_authentication_token_cache_require = ["keyring"]
36+
oauth_require = ["trino.oauth2 @ git+https://github.com/dprophet/trino-python-oauth2"]
3637

3738
# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
38-
all_require = kerberos_require + sqlalchemy_require
39+
all_require = kerberos_require + sqlalchemy_require + oauth_require
3940

4041
tests_require = all_require + gssapi_require + [
4142
# httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
@@ -96,6 +97,7 @@
9697
"all": all_require,
9798
"kerberos": kerberos_require,
9899
"gssapi": gssapi_require,
100+
"oauth": oauth_require,
99101
"sqlalchemy": sqlalchemy_require,
100102
"tests": tests_require,
101103
"external-authentication-token-cache": external_authentication_token_cache_require,

trino/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
13+
1214
from . import auth
1315
from . import client
1416
from . import constants

trino/auth.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import List
2424
from typing import Optional
2525
from typing import Tuple
26+
from typing import Union
2627
from urllib.parse import urlparse
2728

2829
from requests import PreparedRequest
@@ -37,6 +38,12 @@
3738
from trino.constants import HEADER_ORIGINAL_USER
3839
from trino.constants import HEADER_USER
3940
from trino.constants import MAX_NT_PASSWORD_SIZE
41+
from trino.oauth2 import OAuth2Client
42+
from trino.oauth2.models import AuthorizationCodeConfig
43+
from trino.oauth2.models import ClientCredentialsConfig
44+
from trino.oauth2.models import DeviceCodeConfig
45+
from trino.oauth2.models import ManualUrlsConfig
46+
from trino.oauth2.models import OidcConfig
4047

4148
logger = trino.logging.get_logger(__name__)
4249

@@ -50,6 +57,23 @@ def get_exceptions(self) -> Tuple[Any, ...]:
5057
return tuple()
5158

5259

60+
class OAuth2TokenAuthentication(Authentication):
61+
"""Shared base for OAuth2 strategies that authenticate with a bearer token."""
62+
63+
def __init__(self) -> None:
64+
self._oauth2: Optional[OAuth2Client] = None
65+
66+
@property
67+
def oauth2(self) -> OAuth2Client:
68+
if self._oauth2 is None:
69+
raise RuntimeError("OAuth2 client not initialized")
70+
return self._oauth2
71+
72+
def set_http_session(self, http_session: Session) -> Session:
73+
http_session.auth = _BearerAuth(self.oauth2.token())
74+
return http_session
75+
76+
5377
class KerberosAuthentication(Authentication):
5478
MUTUAL_REQUIRED = 1
5579
MUTUAL_OPTIONAL = 2
@@ -276,6 +300,141 @@ def __eq__(self, other: object) -> bool:
276300
return self.token == other.token
277301

278302

303+
class ClientCredentials(OAuth2TokenAuthentication):
304+
def __init__(self,
305+
client_id: str,
306+
client_secret: str,
307+
url_config: Union[OidcConfig, ManualUrlsConfig],
308+
scope: Optional[str] = None,
309+
audience: Optional[str] = None):
310+
super().__init__()
311+
self.client_id = client_id
312+
self.client_secret = client_secret
313+
self.url_config = url_config
314+
self.scope = scope
315+
self.audience = audience
316+
317+
config_args = {
318+
"client_id": self.client_id,
319+
"client_secret": self.client_secret,
320+
"url_config": self.url_config,
321+
}
322+
if self.scope is not None:
323+
config_args["scope"] = self.scope
324+
if self.audience is not None:
325+
config_args["audience"] = self.audience
326+
327+
self._oauth2 = OAuth2Client(
328+
config=ClientCredentialsConfig(**config_args)
329+
)
330+
331+
def get_exceptions(self) -> Tuple[Any, ...]:
332+
return ()
333+
334+
def __eq__(self, other: object) -> bool:
335+
if not isinstance(other, ClientCredentials):
336+
return False
337+
return (
338+
self.client_id == other.client_id
339+
and self.client_secret == other.client_secret
340+
and self.url_config == other.url_config
341+
)
342+
343+
344+
class DeviceCode(OAuth2TokenAuthentication):
345+
def __init__(self,
346+
client_id: str,
347+
url_config: Union[OidcConfig, ManualUrlsConfig],
348+
client_secret: Optional[str] = None,
349+
scope: Optional[str] = None,
350+
audience: Optional[str] = None,
351+
automation_callback: Optional[Callable[[str], None]] = None):
352+
353+
super().__init__()
354+
self.client_id = client_id
355+
self.client_secret = client_secret
356+
self.url_config = url_config
357+
self.scope = scope
358+
self.audience = audience
359+
self.automation_callback = automation_callback
360+
361+
config_args = {
362+
"client_id": self.client_id,
363+
"url_config": self.url_config,
364+
}
365+
if self.client_secret is not None:
366+
config_args["client_secret"] = self.client_secret
367+
if self.scope is not None:
368+
config_args["scope"] = self.scope
369+
if self.audience is not None:
370+
config_args["audience"] = self.audience
371+
if self.automation_callback is not None:
372+
config_args["automation_callback"] = self.automation_callback
373+
374+
self._oauth2 = OAuth2Client(
375+
config=DeviceCodeConfig(**config_args)
376+
)
377+
378+
def get_exceptions(self) -> Tuple[Any, ...]:
379+
return ()
380+
381+
def __eq__(self, other: object) -> bool:
382+
if not isinstance(other, DeviceCode):
383+
return False
384+
return (
385+
self.client_id == other.client_id
386+
and self.client_secret == other.client_secret
387+
and self.url_config == other.url_config
388+
)
389+
390+
391+
class AuthorizationCode(OAuth2TokenAuthentication):
392+
def __init__(self,
393+
client_id: str,
394+
url_config: Union[OidcConfig, ManualUrlsConfig],
395+
client_secret: Optional[str] = None,
396+
scope: Optional[str] = None,
397+
audience: Optional[str] = None,
398+
automation_callback: Optional[Callable[[str], None]] = None):
399+
400+
super().__init__()
401+
self.client_id = client_id
402+
self.client_secret = client_secret
403+
self.url_config = url_config
404+
self.scope = scope
405+
self.audience = audience
406+
self.automation_callback = automation_callback
407+
408+
config_args = {
409+
"client_id": self.client_id,
410+
"url_config": self.url_config,
411+
}
412+
if self.client_secret is not None:
413+
config_args["client_secret"] = self.client_secret
414+
if self.scope is not None:
415+
config_args["scope"] = self.scope
416+
if self.audience is not None:
417+
config_args["audience"] = self.audience
418+
if self.automation_callback is not None:
419+
config_args["automation_callback"] = self.automation_callback
420+
421+
self._oauth2 = OAuth2Client(
422+
config=AuthorizationCodeConfig(**config_args)
423+
)
424+
425+
def get_exceptions(self) -> Tuple[Any, ...]:
426+
return ()
427+
428+
def __eq__(self, other: object) -> bool:
429+
if not isinstance(other, DeviceCode):
430+
return False
431+
return (
432+
self.client_id == other.client_id
433+
and self.client_secret == other.client_secret
434+
and self.url_config == other.url_config
435+
)
436+
437+
279438
class RedirectHandler(metaclass=abc.ABCMeta):
280439
"""
281440
Abstract class for OAuth redirect handlers, inherit from this class to implement your own redirect handler.
@@ -292,7 +451,10 @@ class ConsoleRedirectHandler(RedirectHandler):
292451
"""
293452

294453
def __call__(self, url: str) -> None:
295-
print(f"Open the following URL in browser for the external authentication:\n{url}", flush=True)
454+
print(
455+
f"Open the following URL in browser for the external authentication:\n{url}",
456+
flush=True,
457+
)
296458

297459

298460
class WebBrowserRedirectHandler(RedirectHandler):

0 commit comments

Comments
 (0)