Skip to content
This repository was archived by the owner on Mar 13, 2026. It is now read-only.

Commit 5f851b8

Browse files
committed
Support Named Schemas
1 parent e17c5ef commit 5f851b8

File tree

2 files changed

+93
-1
lines changed

2 files changed

+93
-1
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler):
233233

234234
compound_keywords = _compound_keywords
235235

236+
def __init__(self, *args, **kwargs):
237+
self.tablealiases = {}
238+
super().__init__(*args, **kwargs)
239+
236240
def get_from_hint_text(self, _, text):
237241
"""Return a hint text.
238242
@@ -378,8 +382,10 @@ def limit_clause(self, select, **kw):
378382
return text
379383

380384
def returning_clause(self, stmt, returning_cols, **kw):
385+
# Set include_table=False because although table names are allowed in
386+
# RETURNING clauses, schema names are not.
381387
columns = [
382-
self._label_select_column(None, c, True, False, {})
388+
self._label_select_column(None, c, True, False, {}, include_table=False)
383389
for c in expression._select_iterables(returning_cols)
384390
]
385391

@@ -391,6 +397,66 @@ def visit_sequence(self, seq, **kw):
391397
seq
392398
)
393399

400+
def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs):
401+
"""Build the table name.
402+
403+
Schema names are not allowed in Spanner SELECT statements. When selecting
404+
from a schema-qualified table, alias the table to produce SQL like:
405+
406+
SELECT tbl_1.id, tbl_1.col
407+
FROM schema.tbl AS tbl_1
408+
"""
409+
# This closely code mirrors the mssql dialect which also
410+
# avoids schema-qualified columns in SELECTs, although the
411+
# behaviour is currently behind a deprecated
412+
# 'legacy_schema_aliasing' flag.
413+
if spanner_aliased is table or iscrud:
414+
return super().visit_table(table, **kwargs)
415+
416+
# alias schema-qualified tables
417+
alias = self._schema_aliased_table(table)
418+
if alias is not None:
419+
return self.process(alias, spanner_aliased=table, **kwargs)
420+
else:
421+
return super().visit_table(table, **kwargs)
422+
423+
def visit_alias(self, alias, **kw):
424+
# translate for schema-qualified table aliases
425+
kw["spanner_aliased"] = alias.element
426+
return super().visit_alias(alias, **kw)
427+
428+
def visit_column(self, column, add_to_result_map=None, **kw):
429+
if (
430+
column.table is not None
431+
and (not self.isupdate and not self.isdelete and not self.isinsert)
432+
or self.is_subquery()
433+
):
434+
# translate for schema-qualified table aliases
435+
t = self._schema_aliased_table(column.table)
436+
if t is not None:
437+
converted = elements._corresponding_column_or_error(t, column)
438+
if add_to_result_map is not None:
439+
add_to_result_map(
440+
column.name,
441+
column.name,
442+
(column, column.name, column.key),
443+
column.type,
444+
)
445+
446+
return super().visit_column(converted, **kw)
447+
448+
return super().visit_column(
449+
column, add_to_result_map=add_to_result_map, **kw
450+
)
451+
452+
def _schema_aliased_table(self, table):
453+
if getattr(table, "schema", None) is not None:
454+
if table not in self.tablealiases:
455+
self.tablealiases[table] = table.alias()
456+
return self.tablealiases[table]
457+
else:
458+
return None
459+
394460

395461
class SpannerDDLCompiler(DDLCompiler):
396462
"""Spanner DDL statements compiler."""

test/system/test_basics.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ def define_tables(cls, metadata):
5858
Column("name", String(20)),
5959
)
6060

61+
with cls.bind.begin() as conn:
62+
conn.execute(text('CREATE SCHEMA IF NOT EXISTS schema'))
63+
Table(
64+
"users",
65+
metadata,
66+
Column("ID", Integer, primary_key=True),
67+
Column("name", String(20)),
68+
schema="schema"
69+
)
70+
6171
def test_hello_world(self, connection):
6272
greeting = connection.execute(text("select 'Hello World'"))
6373
eq_("Hello World", greeting.fetchone()[0])
@@ -139,6 +149,12 @@ class User(Base):
139149
ID: Mapped[int] = mapped_column(primary_key=True)
140150
name: Mapped[str] = mapped_column(String(20))
141151

152+
class SchemaUser(Base):
153+
__tablename__ = "users"
154+
__table_args__ = {'schema': 'schema'}
155+
ID: Mapped[int] = mapped_column(primary_key=True)
156+
name: Mapped[str] = mapped_column(String(20))
157+
142158
engine = connection.engine
143159
with Session(engine) as session:
144160
number = Number(
@@ -156,3 +172,13 @@ class User(Base):
156172
users = session.scalars(statement).all()
157173
eq_(1, len(users))
158174
is_true(users[0].ID > 0)
175+
176+
with Session(engine) as session:
177+
user = SchemaUser(name="SchemaTest")
178+
session.add(user)
179+
session.commit()
180+
181+
statement = select(SchemaUser).filter_by(name="SchemaTest")
182+
users = session.scalars(statement).all()
183+
eq_(1, len(users))
184+
is_true(users[0].ID > 0)

0 commit comments

Comments
 (0)