Skip to content

Commit 360e0f7

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add database_role property to SpannerToolSettings and use it in execute_sql to support fine grained access controls
PiperOrigin-RevId: 884166439
1 parent 6c34694 commit 360e0f7

File tree

4 files changed

+84
-5
lines changed

4 files changed

+84
-5
lines changed

src/google/adk/tools/spanner/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,5 +263,8 @@ class SpannerToolSettings(BaseModel):
263263
query_result_mode: QueryResultMode = QueryResultMode.DEFAULT
264264
"""Mode for Spanner execute sql query result."""
265265

266+
database_role: Optional[str] = None
267+
"""Optional. The database role to use for the Spanner session."""
268+
266269
vector_store_settings: Optional[SpannerVectorStoreSettings] = None
267270
"""Settings for Spanner vector store and vector similarity search."""

src/google/adk/tools/spanner/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def execute_sql(
8282
project=project_id, credentials=credentials
8383
)
8484
instance = spanner_client.instance(instance_id)
85-
database = instance.database(database_id)
85+
database = instance.database(
86+
database_id, database_role=settings.database_role
87+
)
8688

8789
if database.database_dialect == DatabaseDialect.POSTGRESQL:
8890
return {
@@ -244,7 +246,10 @@ def __init__(
244246
self._vector_store_settings.instance_id
245247
)
246248
)
247-
self._database = instance.database(self._vector_store_settings.database_id)
249+
self._database = instance.database(
250+
self._vector_store_settings.database_id,
251+
database_role=self._settings.database_role,
252+
)
248253
if not self._database.exists():
249254
raise ValueError(
250255
"Database id {} doesn't exist.".format(

tests/unittests/tools/spanner/test_spanner_tool_settings.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ def test_spanner_vector_store_settings_invalid_vector_length():
8383

8484

8585
@pytest.mark.parametrize(
86-
"settings_args, expected_rows, expected_mode",
86+
"settings_args, expected_rows, expected_mode, expected_role",
8787
[
88-
({}, 50, QueryResultMode.DEFAULT),
88+
({}, 50, QueryResultMode.DEFAULT, None),
8989
(
9090
{
9191
"capabilities": [Capabilities.DATA_READ],
@@ -94,12 +94,22 @@ def test_spanner_vector_store_settings_invalid_vector_length():
9494
},
9595
100,
9696
QueryResultMode.DICT_LIST,
97+
None,
98+
),
99+
(
100+
{"database_role": "test-role"},
101+
50,
102+
QueryResultMode.DEFAULT,
103+
"test-role",
97104
),
98105
],
99106
)
100-
def test_spanner_tool_settings(settings_args, expected_rows, expected_mode):
107+
def test_spanner_tool_settings(
108+
settings_args, expected_rows, expected_mode, expected_role
109+
):
101110
"""Test SpannerToolSettings with different values."""
102111
settings = SpannerToolSettings(**settings_args)
103112
assert settings.capabilities == [Capabilities.DATA_READ]
104113
assert settings.max_executed_query_result_rows == expected_rows
105114
assert settings.query_result_mode == expected_mode
115+
assert settings.database_role == expected_role

tests/unittests/tools/spanner/test_utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,64 @@ def test_create_vector_search_index_fails(
411411
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
412412
with pytest.raises(RuntimeError, match="DDL failed"):
413413
vector_store.create_vector_search_index()
414+
415+
416+
@mock.patch.object(spanner_utils.client, "get_spanner_client", autospec=True)
417+
def test_execute_sql_with_database_role(mock_get_spanner_client):
418+
"""Test that execute_sql passes database_role to instance.database."""
419+
mock_spanner_client = mock.MagicMock()
420+
mock_instance = mock.MagicMock()
421+
mock_database = mock.MagicMock()
422+
mock_snapshot = mock.MagicMock()
423+
424+
mock_snapshot.execute_sql.return_value = iter([["row1"]])
425+
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
426+
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
427+
mock_instance.database.return_value = mock_database
428+
mock_spanner_client.instance.return_value = mock_instance
429+
mock_get_spanner_client.return_value = mock_spanner_client
430+
431+
database_role = "test-role"
432+
settings = SpannerToolSettings(database_role=database_role)
433+
434+
spanner_utils.execute_sql(
435+
project_id="test-project",
436+
instance_id="test-instance",
437+
database_id="test-database",
438+
query="SELECT 1",
439+
credentials=mock.Mock(),
440+
settings=settings,
441+
tool_context=mock.Mock(),
442+
)
443+
444+
mock_instance.database.assert_called_once_with(
445+
"test-database", database_role=database_role
446+
)
447+
448+
449+
@mock.patch.object(spanner_utils.client, "get_spanner_client", autospec=True)
450+
def test_spanner_vector_store_with_database_role(
451+
mock_get_spanner_client, vector_store_settings
452+
):
453+
"""Test that SpannerVectorStore passes database_role to instance.database."""
454+
mock_spanner_client = mock.MagicMock()
455+
mock_instance = mock.MagicMock()
456+
mock_database = mock.MagicMock()
457+
458+
mock_instance.database.return_value = mock_database
459+
mock_instance.exists.return_value = True
460+
mock_database.exists.return_value = True
461+
mock_spanner_client.instance.return_value = mock_instance
462+
mock_get_spanner_client.return_value = mock_spanner_client
463+
mock_spanner_client._client_info = mock.Mock(user_agent="test-agent")
464+
465+
database_role = "test-role"
466+
settings = SpannerToolSettings(
467+
database_role=database_role, vector_store_settings=vector_store_settings
468+
)
469+
470+
spanner_utils.SpannerVectorStore(settings)
471+
472+
mock_instance.database.assert_called_once_with(
473+
"test-database", database_role=database_role
474+
)

0 commit comments

Comments
 (0)