2323from typing import List
2424from typing import Optional
2525from typing import Tuple
26+ from typing import Union
2627from urllib .parse import urlparse
2728
2829from requests import PreparedRequest
3738from trino .constants import HEADER_ORIGINAL_USER
3839from trino .constants import HEADER_USER
3940from trino .constants import MAX_NT_PASSWORD_SIZE
41+ from trino .oauth2 import OAuth2Client
42+ from trino .oauth2 .models import AuthorizationCodeConfig
43+ from trino .oauth2 .models import ClientCredentialsConfig
44+ from trino .oauth2 .models import DeviceCodeConfig
45+ from trino .oauth2 .models import ManualUrlsConfig
46+ from trino .oauth2 .models import OidcConfig
4047
4148logger = trino .logging .get_logger (__name__ )
4249
@@ -50,6 +57,23 @@ def get_exceptions(self) -> Tuple[Any, ...]:
5057 return tuple ()
5158
5259
60+ class OAuth2TokenAuthentication (Authentication ):
61+ """Shared base for OAuth2 strategies that authenticate with a bearer token."""
62+
63+ def __init__ (self ) -> None :
64+ self ._oauth2 : Optional [OAuth2Client ] = None
65+
66+ @property
67+ def oauth2 (self ) -> OAuth2Client :
68+ if self ._oauth2 is None :
69+ raise RuntimeError ("OAuth2 client not initialized" )
70+ return self ._oauth2
71+
72+ def set_http_session (self , http_session : Session ) -> Session :
73+ http_session .auth = _BearerAuth (self .oauth2 .token ())
74+ return http_session
75+
76+
5377class KerberosAuthentication (Authentication ):
5478 MUTUAL_REQUIRED = 1
5579 MUTUAL_OPTIONAL = 2
@@ -276,6 +300,141 @@ def __eq__(self, other: object) -> bool:
276300 return self .token == other .token
277301
278302
303+ class ClientCredentials (OAuth2TokenAuthentication ):
304+ def __init__ (self ,
305+ client_id : str ,
306+ client_secret : str ,
307+ url_config : Union [OidcConfig , ManualUrlsConfig ],
308+ scope : Optional [str ] = None ,
309+ audience : Optional [str ] = None ):
310+ super ().__init__ ()
311+ self .client_id = client_id
312+ self .client_secret = client_secret
313+ self .url_config = url_config
314+ self .scope = scope
315+ self .audience = audience
316+
317+ config_args = {
318+ "client_id" : self .client_id ,
319+ "client_secret" : self .client_secret ,
320+ "url_config" : self .url_config ,
321+ }
322+ if self .scope is not None :
323+ config_args ["scope" ] = self .scope
324+ if self .audience is not None :
325+ config_args ["audience" ] = self .audience
326+
327+ self ._oauth2 = OAuth2Client (
328+ config = ClientCredentialsConfig (** config_args )
329+ )
330+
331+ def get_exceptions (self ) -> Tuple [Any , ...]:
332+ return ()
333+
334+ def __eq__ (self , other : object ) -> bool :
335+ if not isinstance (other , ClientCredentials ):
336+ return False
337+ return (
338+ self .client_id == other .client_id
339+ and self .client_secret == other .client_secret
340+ and self .url_config == other .url_config
341+ )
342+
343+
344+ class DeviceCode (OAuth2TokenAuthentication ):
345+ def __init__ (self ,
346+ client_id : str ,
347+ url_config : Union [OidcConfig , ManualUrlsConfig ],
348+ client_secret : Optional [str ] = None ,
349+ scope : Optional [str ] = None ,
350+ audience : Optional [str ] = None ,
351+ automation_callback : Optional [Callable [[str ], None ]] = None ):
352+
353+ super ().__init__ ()
354+ self .client_id = client_id
355+ self .client_secret = client_secret
356+ self .url_config = url_config
357+ self .scope = scope
358+ self .audience = audience
359+ self .automation_callback = automation_callback
360+
361+ config_args = {
362+ "client_id" : self .client_id ,
363+ "url_config" : self .url_config ,
364+ }
365+ if self .client_secret is not None :
366+ config_args ["client_secret" ] = self .client_secret
367+ if self .scope is not None :
368+ config_args ["scope" ] = self .scope
369+ if self .audience is not None :
370+ config_args ["audience" ] = self .audience
371+ if self .automation_callback is not None :
372+ config_args ["automation_callback" ] = self .automation_callback
373+
374+ self ._oauth2 = OAuth2Client (
375+ config = DeviceCodeConfig (** config_args )
376+ )
377+
378+ def get_exceptions (self ) -> Tuple [Any , ...]:
379+ return ()
380+
381+ def __eq__ (self , other : object ) -> bool :
382+ if not isinstance (other , DeviceCode ):
383+ return False
384+ return (
385+ self .client_id == other .client_id
386+ and self .client_secret == other .client_secret
387+ and self .url_config == other .url_config
388+ )
389+
390+
391+ class AuthorizationCode (OAuth2TokenAuthentication ):
392+ def __init__ (self ,
393+ client_id : str ,
394+ url_config : Union [OidcConfig , ManualUrlsConfig ],
395+ client_secret : Optional [str ] = None ,
396+ scope : Optional [str ] = None ,
397+ audience : Optional [str ] = None ,
398+ automation_callback : Optional [Callable [[str ], None ]] = None ):
399+
400+ super ().__init__ ()
401+ self .client_id = client_id
402+ self .client_secret = client_secret
403+ self .url_config = url_config
404+ self .scope = scope
405+ self .audience = audience
406+ self .automation_callback = automation_callback
407+
408+ config_args = {
409+ "client_id" : self .client_id ,
410+ "url_config" : self .url_config ,
411+ }
412+ if self .client_secret is not None :
413+ config_args ["client_secret" ] = self .client_secret
414+ if self .scope is not None :
415+ config_args ["scope" ] = self .scope
416+ if self .audience is not None :
417+ config_args ["audience" ] = self .audience
418+ if self .automation_callback is not None :
419+ config_args ["automation_callback" ] = self .automation_callback
420+
421+ self ._oauth2 = OAuth2Client (
422+ config = AuthorizationCodeConfig (** config_args )
423+ )
424+
425+ def get_exceptions (self ) -> Tuple [Any , ...]:
426+ return ()
427+
428+ def __eq__ (self , other : object ) -> bool :
429+ if not isinstance (other , DeviceCode ):
430+ return False
431+ return (
432+ self .client_id == other .client_id
433+ and self .client_secret == other .client_secret
434+ and self .url_config == other .url_config
435+ )
436+
437+
279438class RedirectHandler (metaclass = abc .ABCMeta ):
280439 """
281440 Abstract class for OAuth redirect handlers, inherit from this class to implement your own redirect handler.
@@ -292,7 +451,10 @@ class ConsoleRedirectHandler(RedirectHandler):
292451 """
293452
294453 def __call__ (self , url : str ) -> None :
295- print (f"Open the following URL in browser for the external authentication:\n { url } " , flush = True )
454+ print (
455+ f"Open the following URL in browser for the external authentication:\n { url } " ,
456+ flush = True ,
457+ )
296458
297459
298460class WebBrowserRedirectHandler (RedirectHandler ):
0 commit comments