Skip to content

Commit 01ee941

Browse files
authored
feat(python-driver): add public API for connection pooling and model dict conversion (#2374)
* feat(python-driver): add public API for connection pooling and model dict conversion Add two enhancements to the Python driver's public API: 1. Add configure_connection() function that registers AGE agtype adapters on an existing psycopg connection without creating a new one. This enables use with external connection pools (e.g. psycopg_pool) and managed PostgreSQL services where LOAD 'age' may be restricted. Also explicitly export AgeLoader and ClientCursor as public symbols in age/__init__.py. (#2369) 2. Add to_dict() methods to Vertex, Edge, and Path model classes for conversion to plain Python dicts. This enables direct JSON serialization with json.dumps() without requiring custom conversion logic. (#2371) - Vertex.to_dict() returns {id, label, properties} - Edge.to_dict() returns {id, label, start_id, end_id, properties} - Path.to_dict() returns a list of to_dict() results Closes #2369 Closes #2371 * Fix configure_connection: correct parameter semantics, add load_from_plugins - Replace confusing `skip_load` (double-negative) with `load` (positive boolean, default False). The default now correctly matches the intent: no LOAD by default for connection pool / managed PostgreSQL use cases. - Add `load_from_plugins` parameter for parity with setUpAge(). - Fix docstring to accurately describe parameter behavior. - Add 6 unit tests for configure_connection covering: default no-load, explicit load, load_from_plugins, search_path always set, adapter registration, and graph_name check delegation. Made-with: Cursor * Address review feedback for configure_connection and to_dict - Move TypeInfo.fetch() inside cursor block so search_path change is visible regardless of transaction isolation mode - Raise ValueError when load_from_plugins=True but load=False - Add type annotations to configure_connection signature - Document shallow-copy semantics in Vertex/Edge to_dict() - Path.to_dict() uses str() fallback for non-AGObj entities to guarantee JSON-serializable output - Add test for AgeNotSet when TypeInfo.fetch returns None - Add test for load_from_plugins=True without load=True - Replace fragile string assertions with assert_called_with/assert_any_call Made-with: Cursor * Fix Path.to_dict() to preserve JSON-native types, add tests to suite - Path.to_dict(): leave dict/list/str/int/float/bool/None unchanged instead of converting to str(); handle entities=None safely - Add TestModelToDict, TestPublicImports, TestConfigureConnection to the __main__ suite so they run via direct script execution Made-with: Cursor * fix(python-driver): Python 3.9-safe hints and correct $user in search_path - Use Optional[str] for configure_connection graph_name (PEP 604 unions are invalid on Python 3.9). - Import Any/Optional from typing for annotations. - Quote $user in SET search_path; align unit test expectations. Made-with: Cursor * test(python-driver): add configure_connection + to_dict integration test Existing tests for the new public API are unit-only: - TestConfigureConnection mocks the psycopg connection, so it never proves that AgeLoader actually registers against real agtype OIDs. - TestModelToDict hand-constructs Vertex/Edge/Path via kwargs, so it never serialises objects produced by the ANTLR parser. Add a single TestAgeBasic.testConfigureConnection that: - opens a raw psycopg connection (bypassing age.connect()), - calls configure_connection(..., load=True) on it, - runs a Cypher CREATE/RETURN through the configured connection, - asserts the returned values are real Vertex/Edge instances and that their to_dict() output is JSON-serialisable with the expected label/start_id/end_id/properties shape, - repeats the round-trip for a Path returned by MATCH. This is the smallest test that proves the configure_connection + to_dict pipeline works end-to-end against a live AGE database. Made-with: Cursor
1 parent a1b749a commit 01ee941

4 files changed

Lines changed: 341 additions & 1 deletion

File tree

drivers/python/age/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import psycopg.conninfo as conninfo
1717
from . import age
1818
from .age import *
19+
from .age import AgeLoader, ClientCursor, configure_connection
1920
from .models import *
2021
from .builder import ResultHandler, DummyResultHandler, parseAgeValue, newResultHandler
2122
from . import VERSION

drivers/python/age/age.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# under the License.
1515

1616
import re
17+
from typing import Any, Optional
18+
1719
import psycopg
1820
from psycopg.types import TypeInfo
1921
from psycopg import sql
@@ -170,6 +172,71 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals
170172
if graphName != None:
171173
checkGraphCreated(conn, graphName)
172174

175+
176+
def configure_connection(
177+
conn: psycopg.connection,
178+
graph_name: Optional[str] = None,
179+
load: bool = False,
180+
load_from_plugins: bool = False,
181+
) -> None:
182+
"""Register AGE agtype adapters on an existing connection.
183+
184+
This enables use of AGE with externally-managed connections, such as
185+
those from psycopg_pool.ConnectionPool. By default the function does
186+
**not** execute ``LOAD 'age'``, making it safe for managed PostgreSQL
187+
services (Azure, AWS RDS) where the extension is pre-loaded via
188+
``shared_preload_libraries``.
189+
190+
Performs:
191+
- ``SET search_path`` to include ``ag_catalog``
192+
- Fetches agtype OIDs and registers ``AgeLoader``
193+
- Optionally loads the AGE extension (``load=True``)
194+
- Optionally checks/creates the graph
195+
196+
Args:
197+
conn: An existing psycopg connection.
198+
graph_name: Optional graph name to check/create.
199+
load: If True, execute ``LOAD 'age'`` (or the plugins path).
200+
Default False — suitable for environments where AGE is
201+
already loaded.
202+
load_from_plugins: If True (and ``load=True``), use
203+
``LOAD '$libdir/plugins/age'`` instead of ``LOAD 'age'``.
204+
205+
Raises:
206+
ValueError: If ``load_from_plugins=True`` but ``load=False``.
207+
AgeNotSet: If the agtype type is not found in the database.
208+
"""
209+
if load_from_plugins and not load:
210+
raise ValueError(
211+
"load_from_plugins=True requires load=True. "
212+
"Set load=True to enable extension loading."
213+
)
214+
215+
with conn.cursor() as cursor:
216+
if load:
217+
if load_from_plugins:
218+
cursor.execute("LOAD '$libdir/plugins/age';")
219+
else:
220+
cursor.execute("LOAD 'age';")
221+
222+
cursor.execute('SET search_path = ag_catalog, "$user", public;')
223+
224+
ag_info = TypeInfo.fetch(conn, 'agtype')
225+
226+
if not ag_info:
227+
raise AgeNotSet(
228+
"AGE agtype type not found. Ensure the AGE extension is "
229+
"installed and loaded in the current database. "
230+
"Run CREATE EXTENSION age; first."
231+
)
232+
233+
conn.adapters.register_loader(ag_info.oid, AgeLoader)
234+
conn.adapters.register_loader(ag_info.array_oid, AgeLoader)
235+
236+
if graph_name is not None:
237+
checkGraphCreated(conn, graph_name)
238+
239+
173240
# Create the graph, if it does not exist
174241
def checkGraphCreated(conn:psycopg.connection, graphName:str):
175242
validate_graph_name(graphName)

drivers/python/age/models.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,20 @@ def toJson(self) -> str:
118118

119119
return buf.getvalue()
120120

121+
def to_dict(self) -> list:
122+
# AGObj elements are recursively converted; JSON-native types
123+
# (dict, list, str, int, float, bool, None) pass through unchanged.
124+
# Non-serializable objects fall back to str() as a safety net.
125+
result = []
126+
for e in (self.entities or []):
127+
if isinstance(e, AGObj):
128+
result.append(e.to_dict())
129+
elif isinstance(e, (dict, list, str, int, float, bool, type(None))):
130+
result.append(e)
131+
else:
132+
result.append(str(e))
133+
return result
134+
121135

122136

123137

@@ -146,6 +160,18 @@ def __str__(self) -> str:
146160
def __repr__(self) -> str:
147161
return self.toString()
148162

163+
def to_dict(self) -> dict:
164+
"""Return a plain dict suitable for JSON serialization.
165+
166+
Properties are shallow-copied; nested mutable values will share
167+
references with the original Vertex.
168+
"""
169+
return {
170+
"id": self.id,
171+
"label": self.label,
172+
"properties": dict(self.properties) if self.properties else {},
173+
}
174+
149175
def toString(self) -> str:
150176
return nodeToString(self)
151177

@@ -186,6 +212,20 @@ def __str__(self) -> str:
186212
def __repr__(self) -> str:
187213
return self.toString()
188214

215+
def to_dict(self) -> dict:
216+
"""Return a plain dict suitable for JSON serialization.
217+
218+
Properties are shallow-copied; nested mutable values will share
219+
references with the original Edge.
220+
"""
221+
return {
222+
"id": self.id,
223+
"label": self.label,
224+
"start_id": self.start_id,
225+
"end_id": self.end_id,
226+
"properties": dict(self.properties) if self.properties else {},
227+
}
228+
189229
def extraStrFormat(node, buf):
190230
if node.start_id != None:
191231
buf.write(", start_id:")

0 commit comments

Comments
 (0)