Skip to content

Commit 2f7a11d

Browse files
sdks/python: itest CloudSQLEnrichmentHandler
1 parent 95ec739 commit 2f7a11d

1 file changed

Lines changed: 396 additions & 0 deletions

File tree

Lines changed: 396 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,396 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
import logging
18+
import unittest
19+
from unittest.mock import MagicMock
20+
import pytest
21+
from pg8000.exceptions import DatabaseError as PgDatabaseError
22+
import apache_beam as beam
23+
from apache_beam.coders import coders
24+
from apache_beam.testing.test_pipeline import TestPipeline
25+
from apache_beam.testing.util import BeamAssertException
26+
from apache_beam.transforms.enrichment import Enrichment
27+
from apache_beam.transforms.enrichment_handlers.cloudsql import (
28+
CloudSQLEnrichmentHandler,
29+
DatabaseTypeAdapter,
30+
ExceptionLevel,
31+
)
32+
from testcontainers.redis import RedisContainer
33+
from google.cloud.sql.connector import Connector
34+
import os
35+
36+
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
37+
38+
39+
def _row_key_fn(request: beam.Row, key_id="product_id") -> tuple[str, str]:
  """Build the lookup key pair for a request row.

  Args:
    request: the input row; it must expose an attribute named `key_id`.
    key_id: name of the attribute to read. It is also returned unchanged as
      the first element of the result.

  Returns:
    A `(key_id, value)` pair, where `value` is the attribute converted with
    `str`.

  Raises:
    AttributeError: if `request` has no attribute named `key_id`.
  """
  return key_id, str(getattr(request, key_id))
42+
43+
44+
class ValidateResponse(beam.DoFn):
  """Validates that each `beam.Row` of a PCollection has the required fields.

  The DoFn emits nothing; it only raises `BeamAssertException` when an element
  does not have the expected shape.
  """
  def __init__(
      self,
      n_fields: int,
      fields: list[str],
      enriched_fields: dict[str, list[str]],
  ):
    """
    Args:
      n_fields: exact number of fields every element must have.
      fields: names of fields that must be present and not None.
      enriched_fields: names of columns that must be present; their values
        may be None. NOTE(review): annotated as a dict, but it is only
        iterated, so any iterable of names behaves the same -- the callers
        in this file pass a list.
    """
    self.n_fields = n_fields
    self._fields = fields
    self._enriched_fields = enriched_fields

  def process(self, element: beam.Row, *args, **kwargs):
    """Check one element.

    Raises:
      BeamAssertException: if the field count differs from `n_fields`, a
        required field is missing or None, or an enriched column is missing.
    """
    element_dict = element.as_dict()
    if len(element_dict) != self.n_fields:
      # Report what was actually received so a failure can be diagnosed from
      # the message alone.
      raise BeamAssertException(
          "Expected %d fields in enriched PCollection: got %d (%s)" %
          (self.n_fields, len(element_dict), sorted(element_dict)))

    for field in self._fields:
      if field not in element_dict or element_dict[field] is None:
        raise BeamAssertException(f"Expected a not None field: {field}")

    # Enriched columns only need to exist; a None value is accepted here.
    for key in self._enriched_fields:
      if key not in element_dict:
        raise BeamAssertException(
            f"Response from Cloud SQL should contain {key} column.")
71+
72+
73+
def create_rows(cursor):
  """Create the `products` table if needed and seed it with test rows.

  The rows are inserted with explicit `product_id` values 1-8 so that the
  `ON CONFLICT` clause can actually detect rows left by an earlier run. When
  the id is left to the SERIAL sequence no conflict can ever occur, and every
  run would append eight more rows with fresh ids.

  Args:
    cursor: an open DB-API cursor. This function only executes statements on
      it; committing them is left to the caller.
  """
  cursor.execute(
      """
      CREATE TABLE IF NOT EXISTS products (
        product_id SERIAL PRIMARY KEY,
        product_name VARCHAR(255),
        product_stock INT
      )
      """)
  # The ids follow the original insertion order, i.e. the values the SERIAL
  # column would assign on an empty table.
  cursor.execute(
      """
      INSERT INTO products (product_id, product_name, product_stock)
      VALUES
        (1, 'pixel 5', 2),
        (2, 'pixel 6', 4),
        (3, 'pixel 7', 20),
        (4, 'pixel 8', 10),
        (5, 'iphone 11', 3),
        (6, 'iphone 12', 7),
        (7, 'iphone 13', 8),
        (8, 'iphone 14', 3)
      ON CONFLICT (product_id) DO NOTHING
      """)
97+
98+
99+
@pytest.mark.uses_testcontainer
class TestCloudSQLEnrichment(unittest.TestCase):
  """Integration tests for `CloudSQLEnrichmentHandler` on PostgreSQL.

  `setUpClass` opens one connection through the Cloud SQL connector and seeds
  the `products` table; each test then builds its own handler from the class
  attributes. Database credentials are read from the
  `BEAM_TEST_CLOUDSQL_PG_USER` and `BEAM_TEST_CLOUDSQL_PG_PASSWORD`
  environment variables.
  """
  @classmethod
  def setUpClass(cls):
    """Connect to the test instance and seed the `products` table."""
    cls.project_id = "apache-beam-testing"
    cls.region_id = "us-central1"
    cls.instance_id = "beam-test"
    cls.database_id = "postgres"
    cls.database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER")
    cls.database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD")
    cls.table_id = "products"
    cls.row_key = "product_id"
    cls.database_type_adapter = DatabaseTypeAdapter.POSTGRESQL
    cls.req = [
        beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
        beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
        beam.Row(sale_id=5, customer_id=5, product_id=3, quantity=2),
        beam.Row(sale_id=7, customer_id=7, product_id=4, quantity=1),
    ]
    cls.connector = Connector()
    cls.client = cls.connector.connect(
        f"{cls.project_id}:{cls.region_id}:{cls.instance_id}",
        driver=cls.database_type_adapter.value,
        db=cls.database_id,
        user=cls.database_user,
        password=cls.database_password,
    )
    cls.cursor = cls.client.cursor()
    create_rows(cls.cursor)
    # DB-API connections are not required to autocommit. Commit explicitly so
    # the seeded table is visible to the separate connections opened by the
    # handler under test.
    cls.client.commit()
    cls.cache_client_retries = 3

  def _start_cache_container(self):
    """Start a Redis container, retrying up to `cache_client_retries` times.

    On success sets `self.container`, `self.host`, `self.port` and
    `self.cache_client`. The exception from the last attempt is re-raised if
    every attempt fails.
    """
    for i in range(self.cache_client_retries):
      try:
        self.container = RedisContainer(image="redis:7.2.4")
        self.container.start()
        self.host = self.container.get_container_host_ip()
        self.port = self.container.get_exposed_port(6379)
        self.cache_client = self.container.get_client()
        break
      except Exception:
        # Earlier failures are retried silently; only the last one is logged
        # and propagated.
        if i == self.cache_client_retries - 1:
          _LOGGER.error(
              "Unable to start redis container for RRIO tests after %d "
              "retries.",
              self.cache_client_retries)
          raise

  @classmethod
  def tearDownClass(cls):
    """Close the cursor, connection and connector opened in `setUpClass`."""
    cls.cursor.close()
    cls.client.close()
    cls.connector.close()
    cls.cursor, cls.client, cls.connector = None, None, None

  def test_enrichment_with_cloudsql(self):
    """Rows keyed by `row_key` gain the product columns."""
    expected_fields = [
        "sale_id",
        "customer_id",
        "product_id",
        "quantity",
        "product_name",
        "product_stock",
    ]
    expected_enriched_fields = ["product_id", "product_name", "product_stock"]
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
        row_key=self.row_key,
    )
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

  def test_enrichment_with_cloudsql_no_enrichment(self):
    """A request expected to match no row keeps only its original fields."""
    expected_fields = ["sale_id", "customer_id", "product_id", "quantity"]
    expected_enriched_fields = {}
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
        row_key=self.row_key,
    )
    req = [beam.Row(sale_id=1, customer_id=1, product_id=99, quantity=1)]
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

  def test_enrichment_with_cloudsql_raises_key_error(self):
    """A `row_key` that the request rows do not have raises `KeyError`."""
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
        row_key="car_name",
    )
    with self.assertRaises(KeyError):
      test_pipeline = TestPipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()

  def test_enrichment_with_cloudsql_raises_not_found(self):
    """Raises a `RuntimeError` naming the missing relation when the GCP Cloud
    SQL table doesn't exist."""
    table_id = "invalid_table"
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=table_id,
        row_key=self.row_key,
    )
    with self.assertRaises(RuntimeError) as ctx:
      test_pipeline = beam.Pipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()
    self.assertIn(f'relation "{table_id}" does not exist', str(ctx.exception))

  def test_enrichment_with_cloudsql_exception_level(self):
    """Raises a `ValueError` exception when the GCP Cloud SQL query returns
    an empty row and the handler uses `ExceptionLevel.RAISE`."""
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
        row_key=self.row_key,
        exception_level=ExceptionLevel.RAISE,
    )
    req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
    with self.assertRaises(ValueError):
      test_pipeline = beam.Pipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()

  def test_cloudsql_enrichment_with_lambda(self):
    """Enrichment works when the key comes from `row_key_fn`."""
    expected_fields = [
        "sale_id",
        "customer_id",
        "product_id",
        "quantity",
        "product_name",
        "product_stock",
    ]
    expected_enriched_fields = ["product_id", "product_name", "product_stock"]
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
        row_key_fn=_row_key_fn,
    )
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields)))

  @pytest.fixture
  def cache_container(self):
    """Provide a running Redis container for the duration of one test."""
    # Setup phase: start the container.
    self._start_cache_container()

    # Hand control to the test.
    yield

    # Cleanup phase: stop the container. It runs after the test completion
    # even if it failed.
    self.container.stop()
    self.container = None

  @pytest.mark.usefixtures("cache_container")
  def test_cloudsql_enrichment_with_redis(self):
    """Responses are written to Redis and served from it on a second run."""
    expected_fields = [
        "sale_id",
        "customer_id",
        "product_id",
        "quantity",
        "product_name",
        "product_stock",
    ]
    expected_enriched_fields = ["product_id", "product_name", "product_stock"]
    cloudsql = CloudSQLEnrichmentHandler(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
        row_key_fn=_row_key_fn,
    )
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create1" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL1" >> Enrichment(cloudsql).with_redis_cache(
              self.host, self.port, 300)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

    # Manually check cache entry to verify entries were correctly stored.
    c = coders.StrUtf8Coder()
    for req in self.req:
      key = cloudsql.get_cache_key(req)
      response = self.cache_client.get(c.encode(key))
      if not response:
        raise ValueError("No cache entry found for %s" % key)

    # Mock the CloudSQL handler to avoid actual database calls. The mock
    # returns an empty response, so the validation below can only pass if the
    # enriched fields are served from the cache.
    actual = CloudSQLEnrichmentHandler.__call__
    CloudSQLEnrichmentHandler.__call__ = MagicMock(
        return_value=(
            beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
            beam.Row(),
        ))
    try:
      # Run a second pipeline to verify cache is being used.
      with TestPipeline(is_integration_test=True) as test_pipeline:
        _ = (
            test_pipeline
            | "Create2" >> beam.Create(self.req)
            | "Enrich W/ CloudSQL2" >> Enrichment(cloudsql).with_redis_cache(
                self.host, self.port)
            | "Validate Response" >> beam.ParDo(
                ValidateResponse(
                    len(expected_fields),
                    expected_fields,
                    expected_enriched_fields)))
    finally:
      # Always restore the real method: the patch is applied to the class
      # itself, so leaving it in place after a failure would affect every
      # later test that uses the handler.
      CloudSQLEnrichmentHandler.__call__ = actual
393+
394+
395+
if __name__ == "__main__":
  # Run the tests in this module when it is executed as a script.
  unittest.main()

0 commit comments

Comments
 (0)