Skip to content

Commit ce50045

Browse files
authored
Python driver: Add skip_load parameter to skip LOAD 'age' statement (#2366)
* Add skip_load parameter to connect and setUpAge functions to control plugin loading. * Add tests and README docs for skip_load parameter - Add 4 unit tests for setUpAge() skip_load behavior: skip_load=True skips LOAD, skip_load=False executes LOAD, load_from_plugins integration, and search_path always set. - Document skip_load in README under new "Managed PostgreSQL Usage" section for Azure/AWS RDS/etc. environments. - Fix syntax in existing load_from_plugins code example. Made-with: Cursor * Address review feedback: ValueError for contradictory flags, e2e test - Raise ValueError when skip_load=True and load_from_plugins=True are both set (contradictory combination) - Add end-to-end test verifying skip_load is forwarded through the full age.connect() → Age.connect() → setUpAge() call chain - Replace fragile string assertions with assert_called_with/assert_any_call - README: mention configure_connection() as the pool-based alternative for managed PostgreSQL environments Made-with: Cursor * Fix README: use setUpAge(skip_load=True) for pool example configure_connection() is not part of this PR; use the available setUpAge() API with skip_load=True for the connection pool example. Made-with: Cursor * fix(python-driver): use quoted $user in search_path and run TestSetUpAge in CI PostgreSQL treats single-quoted '$user' as a literal schema name; use "$user" so the session user schema is included. Include TestSetUpAge in the test_age_py __main__ suite so skip_load tests run in GitHub Actions. Made-with: Cursor
1 parent f1a9b1d commit ce50045

4 files changed

Lines changed: 109 additions & 12 deletions

File tree

drivers/python/README.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,25 @@ SET search_path = ag_catalog, "$user", public;
8989
* Make sure to give your non-superuser db account proper permissions to the graph schemas and corresponding objects
9090
* Make sure to initiate the Apache Age python driver with the ```load_from_plugins``` parameter. This parameter tries to
9191
load the Apache Age extension from the PostgreSQL plugins directory located at ```$libdir/plugins/age```. Example:
92-
```python.
92+
```python
9393
ag = age.connect(host='localhost', port=5432, user='dbuser', password='strong_password',
94-
dbname=postgres, load_from_plugins=True, graph='graph_name)
94+
dbname='postgres', load_from_plugins=True, graph='graph_name')
95+
```
96+
97+
### Managed PostgreSQL Usage (Azure, AWS RDS, etc.)
98+
* On managed PostgreSQL services where the AGE extension is loaded server-side via ```shared_preload_libraries```,
99+
the ```LOAD 'age'``` command may fail because the binary is not at the expected file path. Use the ```skip_load```
100+
parameter to skip the ```LOAD``` statement while still performing all other setup:
101+
```python
102+
ag = age.connect(host='myserver.postgres.database.azure.com', port=5432,
103+
user='dbuser', password='strong_password',
104+
dbname='postgres', skip_load=True, graph='graph_name')
105+
```
106+
* **Connection pools:** If you manage connections externally (e.g. via ```psycopg_pool.ConnectionPool```),
107+
you can call ```setUpAge()``` with ```skip_load=True``` on each pooled connection:
108+
```python
109+
from age.age import setUpAge
110+
setUpAge(conn, 'graph_name', skip_load=True)
95111
```
96112

97113
### License

drivers/python/age/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ def version():
2525

2626

2727
def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=ClientCursor, load_from_plugins=False,
28-
**kwargs):
28+
skip_load=False, **kwargs):
2929

3030
dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs)
3131

3232
ag = Age()
3333
ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory,
34-
load_from_plugins=load_from_plugins, **kwargs)
34+
load_from_plugins=load_from_plugins, skip_load=skip_load, **kwargs)
3535
return ag
3636

3737
# Dummy ResultHandler

drivers/python/age/age.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,22 @@ def load(self, data: bytes | bytearray | memoryview) -> Any | None:
137137
return parseAgeValue(data_bytes.decode('utf-8'))
138138

139139

140-
def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False):
140+
def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False, skip_load:bool=False):
141+
if skip_load and load_from_plugins:
142+
raise ValueError(
143+
"skip_load=True and load_from_plugins=True are contradictory. "
144+
"Set skip_load=False to load the extension from the plugins path, "
145+
"or remove load_from_plugins to skip loading entirely."
146+
)
147+
141148
with conn.cursor() as cursor:
142-
if load_from_plugins:
143-
cursor.execute("LOAD '$libdir/plugins/age';")
144-
else:
145-
cursor.execute("LOAD 'age';")
149+
if not skip_load:
150+
if load_from_plugins:
151+
cursor.execute("LOAD '$libdir/plugins/age';")
152+
else:
153+
cursor.execute("LOAD 'age';")
146154

147-
cursor.execute("SET search_path = ag_catalog, '$user', public;")
155+
cursor.execute('SET search_path = ag_catalog, "$user", public;')
148156

149157
ag_info = TypeInfo.fetch(conn, 'agtype')
150158

@@ -333,9 +341,9 @@ def __init__(self):
333341

334342
# Connect to PostgreSQL Server and establish session and type extension environment.
335343
def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor,
336-
load_from_plugins:bool=False, **kwargs):
344+
load_from_plugins:bool=False, skip_load:bool=False, **kwargs):
337345
conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
338-
setUpAge(conn, graph, load_from_plugins)
346+
setUpAge(conn, graph, load_from_plugins, skip_load=skip_load)
339347
self.connection = conn
340348
self.graphName = graph
341349
return self

drivers/python/test_age_py.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from age.models import Vertex
1818
import unittest
19+
import unittest.mock
1920
import decimal
2021
import age
2122
import argparse
@@ -28,6 +29,76 @@
2829
TEST_GRAPH_NAME = "test_graph"
2930

3031

32+
class TestSetUpAge(unittest.TestCase):
33+
"""Unit tests for setUpAge() skip_load parameter — no DB required."""
34+
35+
def _make_mock_conn(self):
36+
mock_conn = unittest.mock.MagicMock()
37+
mock_cursor = unittest.mock.MagicMock()
38+
mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor)
39+
mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False)
40+
mock_conn.adapters = unittest.mock.MagicMock()
41+
mock_type_info = unittest.mock.MagicMock()
42+
mock_type_info.oid = 1
43+
mock_type_info.array_oid = 2
44+
return mock_conn, mock_cursor, mock_type_info
45+
46+
def test_skip_load_true_does_not_execute_load(self):
47+
"""When skip_load=True, LOAD 'age' must not be executed."""
48+
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
49+
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
50+
unittest.mock.patch("age.age.checkGraphCreated"):
51+
age.age.setUpAge(mock_conn, "test_graph", skip_load=True)
52+
mock_cursor.execute.assert_called_once_with(
53+
'SET search_path = ag_catalog, "$user", public;'
54+
)
55+
56+
def test_skip_load_false_executes_load(self):
57+
"""When skip_load=False (default), LOAD 'age' must be executed."""
58+
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
59+
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
60+
unittest.mock.patch("age.age.checkGraphCreated"):
61+
age.age.setUpAge(mock_conn, "test_graph", skip_load=False)
62+
mock_cursor.execute.assert_any_call("LOAD 'age';")
63+
64+
def test_skip_load_with_load_from_plugins(self):
65+
"""When skip_load=False and load_from_plugins=True, LOAD from plugins path."""
66+
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
67+
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
68+
unittest.mock.patch("age.age.checkGraphCreated"):
69+
age.age.setUpAge(mock_conn, "test_graph", load_from_plugins=True, skip_load=False)
70+
mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';")
71+
72+
def test_skip_load_true_still_sets_search_path(self):
73+
"""When skip_load=True, search_path must still be set."""
74+
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
75+
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
76+
unittest.mock.patch("age.age.checkGraphCreated"):
77+
age.age.setUpAge(mock_conn, "test_graph", skip_load=True)
78+
mock_cursor.execute.assert_any_call(
79+
'SET search_path = ag_catalog, "$user", public;'
80+
)
81+
82+
def test_contradictory_skip_load_and_load_from_plugins_raises(self):
83+
"""skip_load=True + load_from_plugins=True must raise ValueError."""
84+
mock_conn, _, _ = self._make_mock_conn()
85+
with self.assertRaises(ValueError):
86+
age.age.setUpAge(mock_conn, "test_graph", load_from_plugins=True, skip_load=True)
87+
88+
def test_connect_forwards_skip_load_to_setup(self):
89+
"""age.connect(skip_load=True) must forward skip_load through the full call chain."""
90+
with unittest.mock.patch("age.age.psycopg.connect") as mock_psycopg, \
91+
unittest.mock.patch("age.age.setUpAge") as mock_setup:
92+
mock_psycopg.return_value = unittest.mock.MagicMock()
93+
age.connect(dsn="host=localhost", graph="test_graph", skip_load=True)
94+
mock_setup.assert_called_once()
95+
_, kwargs = mock_setup.call_args
96+
self.assertTrue(
97+
kwargs.get("skip_load", False),
98+
"skip_load must be forwarded from age.connect() to setUpAge()"
99+
)
100+
101+
31102
class TestAgeBasic(unittest.TestCase):
32103
ag = None
33104
args: argparse.Namespace = argparse.Namespace(
@@ -485,6 +556,8 @@ def testSerialization(self):
485556

486557
args = parser.parse_args()
487558
suite = unittest.TestSuite()
559+
loader = unittest.TestLoader()
560+
suite.addTests(loader.loadTestsFromTestCase(TestSetUpAge))
488561
suite.addTest(TestAgeBasic("testExec"))
489562
suite.addTest(TestAgeBasic("testQuery"))
490563
suite.addTest(TestAgeBasic("testChangeData"))

0 commit comments

Comments
 (0)