11import asyncio
22import functools
3+ import logging
34import sys
45import typing
56from types import TracebackType
1617else : # pragma: no cover
1718 from aiocontextvars import ContextVar
1819
20+ try : # pragma: no cover
21+ import click
22+
23+ # Extra log info for optional coloured terminal outputs.
24+ LOG_EXTRA = {
25+ "color_message" : "Query: " + click .style ("%s" , bold = True ) + " Args: %s"
26+ }
27+ CONNECT_EXTRA = {
28+ "color_message" : "Connected to database " + click .style ("%s" , bold = True )
29+ }
30+ DISCONNECT_EXTRA = {
31+ "color_message" : "Disconnected from database " + click .style ("%s" , bold = True )
32+ }
33+ except ImportError : # pragma: no cover
34+ LOG_EXTRA = {}
35+ CONNECT_EXTRA = {}
36+ DISCONNECT_EXTRA = {}
37+
38+
39+ logger = logging .getLogger ("databases" )
40+
1941
2042class Database :
2143 SUPPORTED_BACKENDS = {
2244 "postgresql" : "databases.backends.postgres:PostgresBackend" ,
2345 "postgresql+aiopg" : "databases.backends.aiopg:AiopgBackend" ,
46+ "postgres" : "databases.backends.postgres:PostgresBackend" ,
2447 "mysql" : "databases.backends.mysql:MySQLBackend" ,
2548 "sqlite" : "databases.backends.sqlite:SQLiteBackend" ,
2649 }
@@ -51,23 +74,27 @@ def __init__(
5174 self ._global_connection = None # type: typing.Optional[Connection]
5275 self ._global_transaction = None # type: typing.Optional[Transaction]
5376
54- if self ._force_rollback :
55- self ._global_connection = Connection (self ._backend )
56- self ._global_transaction = self ._global_connection .transaction (
57- force_rollback = True
58- )
59-
6077 async def connect (self ) -> None :
6178 """
6279 Establish the connection pool.
6380 """
6481 assert not self .is_connected , "Already connected."
6582
6683 await self ._backend .connect ()
84+ logger .info (
85+ "Connected to database %s" , self .url .obscure_password , extra = CONNECT_EXTRA
86+ )
6787 self .is_connected = True
6888
6989 if self ._force_rollback :
70- assert self ._global_transaction is not None
90+ assert self ._global_connection is None
91+ assert self ._global_transaction is None
92+
93+ self ._global_connection = Connection (self ._backend )
94+ self ._global_transaction = self ._global_connection .transaction (
95+ force_rollback = True
96+ )
97+
7198 await self ._global_transaction .__aenter__ ()
7299
73100 async def disconnect (self ) -> None :
@@ -77,10 +104,20 @@ async def disconnect(self) -> None:
77104 assert self .is_connected , "Already disconnected."
78105
79106 if self ._force_rollback :
107+ assert self ._global_connection is not None
80108 assert self ._global_transaction is not None
109+
81110 await self ._global_transaction .__aexit__ ()
82111
112+ self ._global_transaction = None
113+ self ._global_connection = None
114+
83115 await self ._backend .disconnect ()
116+ logger .info (
117+ "Disconnected from database %s" ,
118+ self .url .obscure_password ,
119+ extra = DISCONNECT_EXTRA ,
120+ )
84121 self .is_connected = False
85122
86123 async def __aenter__ (self ) -> "Database" :
@@ -367,7 +404,10 @@ def netloc(self) -> typing.Optional[str]:
367404
368405 @property
369406 def database (self ) -> str :
370- return self .components .path .lstrip ("/" )
407+ path = self .components .path
408+ if path .startswith ("/" ):
409+ path = path [1 :]
410+ return path
371411
372412 @property
373413 def options (self ) -> dict :
@@ -414,14 +454,17 @@ def replace(self, **kwargs: typing.Any) -> "DatabaseURL":
414454 components = self .components ._replace (** kwargs )
415455 return self .__class__ (components .geturl ())
416456
457+ @property
458+ def obscure_password (self ) -> str :
459+ if self .password :
460+ return self .replace (password = "********" )._url
461+ return self ._url
462+
417463 def __str__ (self ) -> str :
418464 return self ._url
419465
420466 def __repr__ (self ) -> str :
421- url = str (self )
422- if self .password :
423- url = str (self .replace (password = "********" ))
424- return f"{ self .__class__ .__name__ } ({ repr (url )} )"
467+ return f"{ self .__class__ .__name__ } ({ repr (self .obscure_password )} )"
425468
426469 def __eq__ (self , other : typing .Any ) -> bool :
427470 return str (self ) == str (other )
0 commit comments