Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 8e2cf04

Browse files
author
Ilya Gurov
authored
refactor(db_api): cleanup the code (#636)
1 parent d760c2c commit 8e2cf04

6 files changed

Lines changed: 127 additions & 174 deletions

File tree

google/cloud/spanner_dbapi/_helpers.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,16 @@
1919

2020

2121
SQL_LIST_TABLES = """
22-
SELECT
23-
t.table_name
24-
FROM
25-
information_schema.tables AS t
26-
WHERE
27-
t.table_catalog = '' and t.table_schema = ''
28-
"""
29-
30-
SQL_GET_TABLE_COLUMN_SCHEMA = """SELECT
31-
COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
32-
FROM
33-
INFORMATION_SCHEMA.COLUMNS
34-
WHERE
35-
TABLE_SCHEMA = ''
36-
AND
37-
TABLE_NAME = @table_name
38-
"""
22+
SELECT table_name
23+
FROM information_schema.tables
24+
WHERE table_catalog = '' AND table_schema = ''
25+
"""
26+
27+
SQL_GET_TABLE_COLUMN_SCHEMA = """
28+
SELECT COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
29+
FROM INFORMATION_SCHEMA.COLUMNS
30+
WHERE TABLE_SCHEMA = '' AND TABLE_NAME = @table_name
31+
"""
3932

4033
# This table maps spanner_types to Spanner's data type sizes as per
4134
# https://cloud.google.com/spanner/docs/data-types#allowable-types
@@ -64,10 +57,9 @@ def _execute_insert_heterogenous(transaction, sql_params_list):
6457

6558
def _execute_insert_homogenous(transaction, parts):
6659
# Perform an insert in one shot.
67-
table = parts.get("table")
68-
columns = parts.get("columns")
69-
values = parts.get("values")
70-
return transaction.insert(table, columns, values)
60+
return transaction.insert(
61+
parts.get("table"), parts.get("columns"), parts.get("values")
62+
)
7163

7264

7365
def handle_insert(connection, sql, params):

google/cloud/spanner_dbapi/connection.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,23 @@
4040
MAX_INTERNAL_RETRIES = 50
4141

4242

43+
def check_not_closed(function):
44+
"""`Connection` class methods decorator.
45+
46+
Raise an exception if the connection is closed.
47+
48+
:raises: :class:`InterfaceError` if the connection is closed.
49+
"""
50+
51+
def wrapper(connection, *args, **kwargs):
52+
if connection.is_closed:
53+
raise InterfaceError("Connection is already closed")
54+
55+
return function(connection, *args, **kwargs)
56+
57+
return wrapper
58+
59+
4360
class Connection:
4461
"""Representation of a DB-API connection to a Cloud Spanner database.
4562
@@ -328,15 +345,6 @@ def snapshot_checkout(self):
328345

329346
return self._snapshot
330347

331-
def _raise_if_closed(self):
332-
"""Helper to check the connection state before running a query.
333-
Raises an exception if this connection is closed.
334-
335-
:raises: :class:`InterfaceError`: if this connection is closed.
336-
"""
337-
if self.is_closed:
338-
raise InterfaceError("connection is already closed")
339-
340348
def close(self):
341349
"""Closes this connection.
342350
@@ -391,15 +399,13 @@ def rollback(self):
391399
self._release_session()
392400
self._statements = []
393401

402+
@check_not_closed
394403
def cursor(self):
395-
"""Factory to create a DB-API Cursor."""
396-
self._raise_if_closed()
397-
404+
"""Factory to create a DB API Cursor."""
398405
return Cursor(self)
399406

407+
@check_not_closed
400408
def run_prior_DDL_statements(self):
401-
self._raise_if_closed()
402-
403409
if self._ddl_statements:
404410
ddl_statements = self._ddl_statements
405411
self._ddl_statements = []
@@ -454,6 +460,7 @@ def run_statement(self, statement, retried=False):
454460
ResultsChecksum() if retried else statement.checksum,
455461
)
456462

463+
@check_not_closed
457464
def validate(self):
458465
"""
459466
Execute a minimal request to check if the connection
@@ -468,8 +475,6 @@ def validate(self):
468475
:raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance
469476
or database doesn't exist.
470477
"""
471-
self._raise_if_closed()
472-
473478
with self.database.snapshot() as snapshot:
474479
result = list(snapshot.execute_sql("SELECT 1"))
475480
if result != [[1]]:

0 commit comments

Comments
 (0)