diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 538325c4..e27c8a02 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -36,8 +36,8 @@ steps: timeout: "7200s" substitutions: _INSTANCE_ID: test-instance - _GOOGLE_DATABASE: test-google-db - _PG_DATABASE: test-pg-db + _GOOGLE_DATABASE: test-gsql-db + _PG_DATABASE: test-pgsql-db _VERSION: "3.9" options: diff --git a/tests/integration/test_spanner_loader.py b/tests/integration/test_spanner_loader.py index fd8f0284..99bb1b0e 100644 --- a/tests/integration/test_spanner_loader.py +++ b/tests/integration/test_spanner_loader.py @@ -35,9 +35,41 @@ def client() -> Client: return Client(project=project_id) +@pytest.fixture(scope="class") +def cleanupGSQL(client): + yield + + print("\nPerforming GSQL cleanup after each test...") + + database = client.instance(instance_id).database(google_database) + operation = database.update_ddl( + [ + f"DROP TABLE IF EXISTS {table_name}", + ] + ) + operation.result(OPERATION_TIMEOUT_SECONDS) + + # Code to perform teardown after each test goes here + print("\nGSQL Cleanup complete.") + + +@pytest.fixture(scope="class") +def cleanupPGSQL(client): + yield + + print("\nPerforming PGSQL cleanup after each test...") + + database = client.instance(instance_id).database(pg_database) + operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"]) + operation.result(OPERATION_TIMEOUT_SECONDS) + + # Code to perform teardown after each test goes here + print("\n PGSQL Cleanup complete.") + + class TestSpannerDocumentLoaderGoogleSQL: @pytest.fixture(autouse=True, scope="class") - def setup_database(self, client): + def setup_database(self, client, cleanupGSQL): database = client.instance(instance_id).database(google_database) operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"]) operation.result(OPERATION_TIMEOUT_SECONDS) @@ -455,7 +487,7 @@ def test_loader_custom_json_metadata(self, client): class TestSpannerDocumentLoaderPostgreSQL: @pytest.fixture(autouse=True, scope="class") - def setup_database(self, client): + def setup_database(self, client, cleanupPGSQL): database = client.instance(instance_id).database(pg_database) operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"]) operation.result(OPERATION_TIMEOUT_SECONDS) @@ -872,7 +904,7 @@ def test_loader_custom_json_metadata(self, client): class TestSpannerDocumentSaver: @pytest.fixture(name="google_client") - def setup_google_client(self, client) -> Client: + def setup_google_client(self, client, cleanupGSQL) -> Client: database = client.instance(instance_id).database(google_database) operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"]) print("table dropped") @@ -880,7 +912,7 @@ def setup_google_client(self, client) -> Client: yield client @pytest.fixture(name="pg_client") - def setup_pg_client(self, client) -> Client: + def setup_pg_client(self, client, cleanupPGSQL) -> Client: database = client.instance(instance_id).database(pg_database) operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"]) operation.result(OPERATION_TIMEOUT_SECONDS) diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 8cc25a0b..41bd4901 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -20,7 +20,7 @@ import pytest from google.cloud.spanner import Client # type: ignore -from langchain_community.document_loaders import HNLoader +from langchain_community.document_loaders import RecursiveUrlLoader from langchain_community.embeddings import FakeEmbeddings from langchain_google_spanner.vector_store import ( # type: ignore @@ -245,11 +245,13 @@ def setup_database(self, client): id_column="row_id", metadata_columns=[ TableColumn(name="metadata", type="JSON", is_null=True), - TableColumn(name="title", type="STRING(MAX)", is_null=False), + TableColumn(name="title", type="STRING(MAX)"), ], ) - loader = HNLoader("https://news.ycombinator.com/item?id=34817881") + loader = RecursiveUrlLoader( + "https://news.ycombinator.com/item?id=1", max_depth=1 + ) embeddings = FakeEmbeddings(size=3) @@ -327,7 +329,7 @@ def test_spanner_vector_delete_data(self, setup_database): docs = loader.load() - deleted = db.delete(documents=[docs[0], docs[1]]) + deleted = db.delete(documents=docs) assert deleted @@ -459,7 +461,9 @@ def setup_database(self, client): ], ) - loader = HNLoader("https://news.ycombinator.com/item?id=34817881") + loader = RecursiveUrlLoader( + "https://news.ycombinator.com/item?id=1", max_depth=1 + ) embeddings = FakeEmbeddings(size=title_vector_size) def cleanup_db(): @@ -552,7 +556,7 @@ def test_delete(self, setup_database): ) docs = loader.load() - deleted = db.delete(documents=[docs[0], docs[1]]) + deleted = db.delete(documents=docs) assert deleted @@ -677,8 +681,9 @@ def setup_database(self, client): ], ) - loader = HNLoader("https://news.ycombinator.com/item?id=34817881") - + loader = RecursiveUrlLoader( + "https://news.ycombinator.com/item?id=1", max_depth=1 + ) embeddings = FakeEmbeddings(size=3) yield loader, embeddings @@ -755,7 +760,7 @@ def test_spanner_vector_delete_data(self, setup_database): docs = loader.load() - deleted = db.delete(documents=[docs[0], docs[1]]) + deleted = db.delete(documents=docs) assert deleted