Skip to content

Commit 2e8e6c2

Browse files
authored
Merge branch 'main' into issue_183
2 parents 5bcc0e9 + 51dba31 commit 2e8e6c2

7 files changed

Lines changed: 82 additions & 27 deletions

File tree

.github/workflows/docs.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ jobs:
1010
runs-on: ubuntu-latest
1111
steps:
1212
- name: Checkout
13-
uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4
13+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
1414
- name: Setup Python
15-
uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5
15+
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
1616
with:
1717
python-version: "3.10"
1818
- name: Install nox
@@ -26,9 +26,9 @@ jobs:
2626
runs-on: ubuntu-latest
2727
steps:
2828
- name: Checkout
29-
uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4
29+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
3030
- name: Setup Python
31-
uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5
31+
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
3232
with:
3333
python-version: "3.10"
3434
- name: Install nox

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3
3535

3636
- name: Setup Python
37-
uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
37+
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
3838
with:
3939
python-version: "3.11"
4040

integration.cloudbuild.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ steps:
3636
timeout: "7200s"
3737
substitutions:
3838
_INSTANCE_ID: test-instance
39-
_GOOGLE_DATABASE: test-google-db
40-
_PG_DATABASE: test-pg-db
39+
_GOOGLE_DATABASE: test-gsql-db
40+
_PG_DATABASE: test-pgsql-db
4141
_VERSION: "3.9"
4242

4343
options:

src/langchain_google_spanner/graph_qa.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ class VerifyGqlOutput(BaseModel):
7171
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
7272

7373

74+
class InvalidGQLGenerationError(ValueError):
75+
def __init__(self, message, intermediate_steps=None):
76+
self.intermediate_steps = intermediate_steps
77+
super().__init__(message)
78+
79+
7480
class SpannerGraphQAChain(Chain):
7581
"""Chain for question-answering against a Spanner Graph database by
7682
generating GQL statements from natural language questions.
@@ -268,7 +274,9 @@ def execute_with_retry(
268274
finally:
269275
retries += 1
270276

271-
raise ValueError("The generated gql query is invalid")
277+
raise InvalidGQLGenerationError(
278+
"The generated gql query is invalid", intermediate_steps
279+
)
272280

273281
def log_invalid_query(
274282
self,
@@ -309,6 +317,7 @@ def _call(
309317
}
310318
)
311319
if "verified_gql" in verify_response:
320+
intermediate_steps.append({"raw_generated_gql": generated_gql})
312321
verified_gql = fix_gql_syntax(verify_response["verified_gql"])
313322
intermediate_steps.append({"verified_gql": verified_gql})
314323
else:
@@ -322,7 +331,9 @@ def _call(
322331
_run_manager, intermediate_steps, question, verified_gql
323332
)
324333
if not final_gql:
325-
raise ValueError("No GQL was generated.")
334+
raise InvalidGQLGenerationError(
335+
"No GQL was generated.", intermediate_steps
336+
)
326337
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
327338
_run_manager.on_text(
328339
str(context), color="green", end="\n", verbose=self.verbose

src/langchain_google_spanner/graph_store.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -958,8 +958,8 @@ def __repr__(self) -> str:
958958
"Node properties per node label": {
959959
label: [
960960
{
961-
"property name": name,
962-
"property type": properties[name],
961+
"name": name,
962+
"type": properties[name],
963963
}
964964
for name in sorted(self.labels[label].prop_names)
965965
]
@@ -968,8 +968,8 @@ def __repr__(self) -> str:
968968
"Edge properties per edge label": {
969969
label: [
970970
{
971-
"property name": name,
972-
"property type": properties[name],
971+
"name": name,
972+
"type": properties[name],
973973
}
974974
for name in sorted(self.labels[label].prop_names)
975975
]
@@ -1230,6 +1230,7 @@ def __init__(
12301230
instance_id: str,
12311231
database_id: str,
12321232
client: Optional[spanner.Client] = None,
1233+
timeout: Optional[float] = None,
12331234
):
12341235
"""Initializes the Spanner implementation.
12351236
@@ -1241,11 +1242,14 @@ def __init__(
12411242
self.client = client or spanner.Client()
12421243
self.instance = self.client.instance(instance_id)
12431244
self.database = self.instance.database(database_id)
1245+
self.timeout = timeout
12441246

12451247
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
12461248
param_types = {k: TypeUtility.value_to_param_type(v) for k, v in params.items()}
12471249
with self.database.snapshot() as snapshot:
1248-
rows = snapshot.execute_sql(query, params=params, param_types=param_types)
1250+
rows = snapshot.execute_sql(
1251+
query, params=params, param_types=param_types, timeout=self.timeout
1252+
)
12491253
return [
12501254
{
12511255
column: value
@@ -1286,6 +1290,7 @@ def __init__(
12861290
static_node_properties: List[str] = [],
12871291
static_edge_properties: List[str] = [],
12881292
impl: Optional[SpannerInterface] = None,
1293+
timeout: Optional[float] = None,
12891294
):
12901295
"""Initializes SpannerGraphStore.
12911296
@@ -1300,11 +1305,13 @@ def __init__(
13001305
properties as static;
13011306
static_edge_properties: in flexible schema, treat these edge
13021307
properties as static.
1308+
timeout (Optional[float]): The timeout for queries in seconds.
13031309
"""
13041310
self.impl = impl or SpannerImpl(
13051311
instance_id,
13061312
database_id,
13071313
client_with_user_agent(client, USER_AGENT_GRAPH_STORE),
1314+
timeout=timeout,
13081315
)
13091316
self.schema = SpannerGraphSchema(
13101317
graph_name,

tests/integration/test_spanner_loader.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,41 @@ def client() -> Client:
3535
return Client(project=project_id)
3636

3737

38+
@pytest.fixture(scope="class")
39+
def cleanupGSQL(client):
40+
yield
41+
42+
print("\nPerforming GSQL cleanup after each test...")
43+
44+
database = client.instance(instance_id).database(google_database)
45+
operation = database.update_ddl(
46+
[
47+
f"DROP TABLE IF EXISTS {table_name}",
48+
]
49+
)
50+
operation.result(OPERATION_TIMEOUT_SECONDS)
51+
52+
# Code to perform teardown after each test goes here
53+
print("\nGSQL Cleanup complete.")
54+
55+
56+
@pytest.fixture(scope="class")
57+
def cleanupPGSQL(client):
58+
yield
59+
60+
print("\nPerforming PGSQL cleanup after each test...")
61+
62+
database = client.instance(instance_id).database(pg_database)
63+
operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
64+
operation.result(OPERATION_TIMEOUT_SECONDS)
65+
66+
# Code to perform teardown after each test goes here
67+
print("\n PGSQL Cleanup complete.")
68+
69+
3870
class TestSpannerDocumentLoaderGoogleSQL:
3971
@pytest.fixture(autouse=True, scope="class")
40-
def setup_database(self, client):
72+
def setup_database(self, client, cleanupGSQL):
4173
database = client.instance(instance_id).database(google_database)
4274
operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
4375
operation.result(OPERATION_TIMEOUT_SECONDS)
@@ -455,7 +487,7 @@ def test_loader_custom_json_metadata(self, client):
455487

456488
class TestSpannerDocumentLoaderPostgreSQL:
457489
@pytest.fixture(autouse=True, scope="class")
458-
def setup_database(self, client):
490+
def setup_database(self, client, cleanupPGSQL):
459491
database = client.instance(instance_id).database(pg_database)
460492
operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
461493
operation.result(OPERATION_TIMEOUT_SECONDS)
@@ -872,15 +904,15 @@ def test_loader_custom_json_metadata(self, client):
872904

873905
class TestSpannerDocumentSaver:
874906
@pytest.fixture(name="google_client")
875-
def setup_google_client(self, client) -> Client:
907+
def setup_google_client(self, client, cleanupGSQL) -> Client:
876908
database = client.instance(instance_id).database(google_database)
877909
operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
878910
print("table dropped")
879911
operation.result(OPERATION_TIMEOUT_SECONDS)
880912
yield client
881913

882914
@pytest.fixture(name="pg_client")
883-
def setup_pg_client(self, client) -> Client:
915+
def setup_pg_client(self, client, cleanupPGSQL) -> Client:
884916
database = client.instance(instance_id).database(pg_database)
885917
operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
886918
operation.result(OPERATION_TIMEOUT_SECONDS)

tests/integration/test_spanner_vector_store.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import pytest
2222
from google.cloud.spanner import Client # type: ignore
23-
from langchain_community.document_loaders import HNLoader
23+
from langchain_community.document_loaders import RecursiveUrlLoader
2424
from langchain_community.embeddings import FakeEmbeddings
2525

2626
from langchain_google_spanner.vector_store import ( # type: ignore
@@ -245,11 +245,13 @@ def setup_database(self, client):
245245
id_column="row_id",
246246
metadata_columns=[
247247
TableColumn(name="metadata", type="JSON", is_null=True),
248-
TableColumn(name="title", type="STRING(MAX)", is_null=False),
248+
TableColumn(name="title", type="STRING(MAX)"),
249249
],
250250
)
251251

252-
loader = HNLoader("https://news.ycombinator.com/item?id=34817881")
252+
loader = RecursiveUrlLoader(
253+
"https://news.ycombinator.com/item?id=1", max_depth=1
254+
)
253255

254256
embeddings = FakeEmbeddings(size=3)
255257

@@ -327,7 +329,7 @@ def test_spanner_vector_delete_data(self, setup_database):
327329

328330
docs = loader.load()
329331

330-
deleted = db.delete(documents=[docs[0], docs[1]])
332+
deleted = db.delete(documents=docs)
331333

332334
assert deleted
333335

@@ -459,7 +461,9 @@ def setup_database(self, client):
459461
],
460462
)
461463

462-
loader = HNLoader("https://news.ycombinator.com/item?id=34817881")
464+
loader = RecursiveUrlLoader(
465+
"https://news.ycombinator.com/item?id=1", max_depth=1
466+
)
463467
embeddings = FakeEmbeddings(size=title_vector_size)
464468

465469
def cleanup_db():
@@ -552,7 +556,7 @@ def test_delete(self, setup_database):
552556
)
553557

554558
docs = loader.load()
555-
deleted = db.delete(documents=[docs[0], docs[1]])
559+
deleted = db.delete(documents=docs)
556560

557561
assert deleted
558562

@@ -677,8 +681,9 @@ def setup_database(self, client):
677681
],
678682
)
679683

680-
loader = HNLoader("https://news.ycombinator.com/item?id=34817881")
681-
684+
loader = RecursiveUrlLoader(
685+
"https://news.ycombinator.com/item?id=1", max_depth=1
686+
)
682687
embeddings = FakeEmbeddings(size=3)
683688

684689
yield loader, embeddings
@@ -755,7 +760,7 @@ def test_spanner_vector_delete_data(self, setup_database):
755760

756761
docs = loader.load()
757762

758-
deleted = db.delete(documents=[docs[0], docs[1]])
763+
deleted = db.delete(documents=docs)
759764

760765
assert deleted
761766

0 commit comments

Comments
 (0)