11import asyncio
22import enum
33import functools
4+ import ssl
45import os
56from typing import Optional , Union
67
78from .api import Api
9+ from .const import Transport
810from .exceptions import TarantoolDatabaseError , \
9- ErrorCode , TarantoolError
11+ ErrorCode , TarantoolError , SSLError
1012from .iproto import protocol
1113from .log import logger
1214from .stream import Stream
@@ -27,8 +29,10 @@ class ConnectionState(enum.IntEnum):
2729
2830class Connection (Api ):
2931 __slots__ = (
30- '_host' , '_port' , '_username' , '_password' ,
31- '_fetch_schema' , '_auto_refetch_schema' , '_initial_read_buffer_size' ,
32+ '_host' , '_port' , '_transport' , '_ssl_key_file' ,
33+ '_ssl_cert_file' , '_ssl_ca_file' , '_ssl_ciphers' ,
34+ '_username' , '_password' , '_fetch_schema' ,
35+ '_auto_refetch_schema' , '_initial_read_buffer_size' ,
3236 '_encoding' , '_connect_timeout' , '_reconnect_timeout' ,
3337 '_request_timeout' , '_ping_timeout' , '_loop' , '_state' , '_state_prev' ,
3438 '_connection_transport' , '_protocol' ,
@@ -40,6 +44,11 @@ class Connection(Api):
4044 def __init__ (self , * ,
4145 host : str = '127.0.0.1' ,
4246 port : Union [int , str ] = 3301 ,
47+ transport : Optional [Transport ] = Transport .DEFAULT ,
48+ ssl_key_file : Optional [str ] = None ,
49+ ssl_cert_file : Optional [str ] = None ,
50+ ssl_ca_file : Optional [str ] = None ,
51+ ssl_ciphers : Optional [str ] = None ,
4352 username : Optional [str ] = None ,
4453 password : Optional [str ] = None ,
4554 fetch_schema : bool = True ,
@@ -78,6 +87,22 @@ def __init__(self, *,
7887 :param port:
7988 Tarantool port
8089 (pass ``/path/to/sockfile`` to connect ot unix socket)
90+ :param transport:
91+ This parameter can be used to configure traffic encryption.
92+ Pass ``asynctnt.Transport.SSL`` value to enable SSL
93+ encryption (by default there is no encryption)
94+ :param ssl_key_file:
95+ A path to a private SSL key file.
96+ Optional, mandatory if server uses CA file
97+ :param ssl_cert_file:
98+ A path to an SSL certificate file.
99+ Optional, mandatory if server uses CA file
100+ :param ssl_ca_file:
101+ A path to a trusted certificate authorities (CA) file.
102+ Optional
103+ :param ssl_ciphers:
104+ A colon-separated (:) list of SSL cipher suites
105+ the connection can use. Optional
81106 :param username:
82107 Username to use for auth
83108 (if ``None`` you are connected as a guest)
@@ -116,6 +141,13 @@ def __init__(self, *,
116141 super ().__init__ ()
117142 self ._host = host
118143 self ._port = port
144+
145+ self ._transport = transport
146+ self ._ssl_key_file = ssl_key_file
147+ self ._ssl_cert_file = ssl_cert_file
148+ self ._ssl_ca_file = ssl_ca_file
149+ self ._ssl_ciphers = ssl_ciphers
150+
119151 self ._username = username
120152 self ._password = password
121153 self ._fetch_schema = False if fetch_schema is None else fetch_schema
@@ -220,6 +252,54 @@ def protocol_factory(self,
220252 on_connection_lost = self .connection_lost ,
221253 loop = self ._loop )
222254
255+ def _create_ssl_context (self ):
256+ try :
257+ if hasattr (ssl , 'TLSVersion' ):
258+ # Since python 3.7
259+ context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
260+ # Reset to default OpenSSL values.
261+ context .check_hostname = False
262+ context .verify_mode = ssl .CERT_NONE
263+ # Require TLSv1.2, because other protocol versions don't seem
264+ # to support the GOST cipher.
265+ context .minimum_version = ssl .TLSVersion .TLSv1_2
266+ context .maximum_version = ssl .TLSVersion .TLSv1_2
267+ else :
268+ # Deprecated, but it works for python < 3.7
269+ context = ssl .SSLContext (ssl .PROTOCOL_TLSv1_2 )
270+
271+ if self ._ssl_cert_file :
272+ # If the password argument is not specified and a password is
273+ # required, OpenSSL’s built-in password prompting mechanism
274+ # will be used to interactively prompt the user for a password.
275+ #
276+ # We should disable this behaviour, because a python
277+ # application that uses the connector unlikely assumes
278+ # interaction with a human + a Tarantool implementation does
279+ # not support this at least for now.
280+ def password_raise_error ():
281+ raise SSLError ("a password for decrypting the private " +
282+ "key is unsupported" )
283+ context .load_cert_chain (certfile = self ._ssl_cert_file ,
284+ keyfile = self ._ssl_key_file ,
285+ password = password_raise_error )
286+
287+ if self ._ssl_ca_file :
288+ context .load_verify_locations (cafile = self ._ssl_ca_file )
289+ context .verify_mode = ssl .CERT_REQUIRED
290+ # A Tarantool implementation does not check hostname. We don't
291+ # do that too. As a result we don't set here:
292+ # context.check_hostname = True
293+
294+ if self ._ssl_ciphers :
295+ context .set_ciphers (self ._ssl_ciphers )
296+
297+ return context
298+ except SSLError as e :
299+ raise
300+ except Exception as e :
301+ raise SSLError (e )
302+
223303 async def _connect (self , return_exceptions : bool = True ):
224304 if self ._loop is None :
225305 self ._loop = get_running_loop ()
@@ -246,6 +326,10 @@ async def full_connect():
246326 while True :
247327 connected_fut = _create_future (self ._loop )
248328
329+ ssl_context = None
330+ if self ._transport == Transport .SSL :
331+ ssl_context = self ._create_ssl_context ()
332+
249333 if self ._host .startswith ('unix/' ):
250334 unix_path = self ._port
251335 assert isinstance (unix_path , str ), \
@@ -260,13 +344,14 @@ async def full_connect():
260344 conn = self ._loop .create_unix_connection (
261345 functools .partial (self .protocol_factory ,
262346 connected_fut ),
263- unix_path
264- )
347+ unix_path ,
348+ ssl = ssl_context )
265349 else :
266350 conn = self ._loop .create_connection (
267351 functools .partial (self .protocol_factory ,
268352 connected_fut ),
269- self ._host , self ._port )
353+ self ._host , self ._port ,
354+ ssl = ssl_context )
270355
271356 tr , pr = await conn
272357
@@ -330,6 +415,8 @@ async def full_connect():
330415 logger .debug ("connect is cancelled" )
331416 self ._reconnect_task = None
332417 raise
418+ except ssl .SSLError as e :
419+ raise SSLError (e )
333420 except Exception as e :
334421 if self ._reconnect_timeout > 0 :
335422 await self ._wait_reconnect (e )
0 commit comments