Skip to content

Commit 85b5198

Browse files
sdks/python: enrich data with CloudSQL
1 parent 3595a33 commit 85b5198

1 file changed

Lines changed: 169 additions & 0 deletions

File tree

  • sdks/python/apache_beam/transforms/enrichment_handlers
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
import logging
18+
from collections.abc import Callable
19+
from enum import Enum
20+
from typing import Any
21+
from typing import Optional
22+
23+
from google.cloud.sql.connector import Connector
24+
25+
import apache_beam as beam
26+
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
27+
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel
28+
29+
__all__ = [
30+
'CloudSQLEnrichmentHandler',
31+
]
32+
33+
# RowKeyFn takes beam.Row and returns tuple of (key_id, key_value).
34+
RowKeyFn = Callable[[beam.Row], tuple[str]]
35+
36+
_LOGGER = logging.getLogger(__name__)
37+
38+
39+
class DatabaseTypeAdapter(Enum):
40+
POSTGRESQL = "pg8000"
41+
MYSQL = "pymysql"
42+
SQLSERVER = "pytds"
43+
44+
45+
class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
46+
"""A handler for :class:`apache_beam.transforms.enrichment.Enrichment`
47+
transform to interact with Google Cloud SQL databases.
48+
49+
Args:
50+
project_id (str): GCP project-id of the Cloud SQL instance.
51+
region_id (str): GCP region-id of the Cloud SQL instance.
52+
instance_id (str): GCP instance-id of the Cloud SQL instance.
53+
database_type_adapter (DatabaseTypeAdapter): The type of database adapter to use.
54+
Supported adapters are: POSTGRESQL (pg8000), MYSQL (pymysql), and SQLSERVER (pytds).
55+
database_id (str): The id of the database to connect to.
56+
database_user (str): The username for connecting to the database.
57+
database_password (str): The password for connecting to the database.
58+
table_id (str): The name of the table to query.
59+
row_key (str): Field name from the input `beam.Row` object to use as
60+
identifier for database querying.
61+
row_key_fn: A lambda function that returns a string key from the
62+
input row. Used to build/extract the identifier for the database query.
63+
exception_level: A `enum.Enum` value from
64+
``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel``
65+
to set the level when no matching record is found from the database query.
66+
Defaults to ``ExceptionLevel.WARN``.
67+
"""
68+
def __init__(
69+
self,
70+
region_id: str,
71+
project_id: str,
72+
instance_id: str,
73+
database_type_adapter: DatabaseTypeAdapter,
74+
database_id: str,
75+
database_user: str,
76+
database_password: str,
77+
table_id: str,
78+
row_key: str = "",
79+
*,
80+
row_key_fn: Optional[RowKeyFn] = None,
81+
exception_level: ExceptionLevel = ExceptionLevel.WARN,
82+
):
83+
self._project_id = project_id
84+
self._region_id = region_id
85+
self._instance_id = instance_id
86+
self._database_type_adapter = database_type_adapter
87+
self._database_id = database_id
88+
self._database_user = database_user
89+
self._database_password = database_password
90+
self._table_id = table_id
91+
self._row_key = row_key
92+
self._row_key_fn = row_key_fn
93+
self._exception_level = exception_level
94+
if ((not self._row_key_fn and not self._row_key) or
95+
bool(self._row_key_fn and self._row_key)):
96+
raise ValueError(
97+
"Please specify exactly one of `row_key` or a lambda "
98+
"function with `row_key_fn` to extract the row key "
99+
"from the input row.")
100+
101+
def __enter__(self):
102+
"""Connect to the the Cloud SQL instance."""
103+
self.connector = Connector()
104+
self.client = self.connector.connect(
105+
f"{self._project_id}:{self._region_id}:{self._instance_id}",
106+
driver=self._database_type_adapter.value,
107+
db=self._database_id,
108+
user=self._database_user,
109+
password=self._database_password,
110+
)
111+
self.cursor = self.client.cursor()
112+
113+
def __call__(self, request: beam.Row, *args, **kwargs):
114+
"""
115+
Executes a query to the Cloud SQL instance and returns
116+
a `Tuple` of request and response.
117+
118+
Args:
119+
request: the input `beam.Row` to enrich.
120+
"""
121+
response_dict: dict[str, Any] = {}
122+
row_key_str: str = ""
123+
124+
try:
125+
if self._row_key_fn:
126+
self._row_key, row_key = self._row_key_fn(request)
127+
else:
128+
request_dict = request._asdict()
129+
row_key_str = str(request_dict[self._row_key])
130+
row_key = row_key_str
131+
132+
query = f"SELECT * FROM {self._table_id} WHERE {self._row_key} = %s"
133+
self.cursor.execute(query, (row_key, ))
134+
result = self.cursor.fetchone()
135+
136+
if result:
137+
columns = [col[0] for col in self.cursor.description]
138+
for i, value in enumerate(result):
139+
response_dict[columns[i]] = value
140+
elif self._exception_level == ExceptionLevel.WARN:
141+
_LOGGER.warning(
142+
'No matching record found for row_key: %s in table: %s',
143+
row_key_str,
144+
self._table_id)
145+
elif self._exception_level == ExceptionLevel.RAISE:
146+
raise ValueError(
147+
'No matching record found for row_key: %s in table: %s' %
148+
(row_key_str, self._table_id))
149+
except KeyError:
150+
raise KeyError('row_key %s not found in input PCollection.' % row_key_str)
151+
except Exception as e:
152+
raise e
153+
154+
return request, beam.Row(**response_dict)
155+
156+
def __exit__(self, exc_type, exc_val, exc_tb):
157+
"""Clean the instantiated Cloud SQL client."""
158+
self.cursor.close()
159+
self.client.close()
160+
self.connector.close()
161+
self.cursor, self.client, self.connector = None, None, None
162+
163+
def get_cache_key(self, request: beam.Row) -> str:
164+
"""Returns a string formatted with row key since it is unique to
165+
a request made to the Cloud SQL instance."""
166+
if self._row_key_fn:
167+
id, value = self._row_key_fn(request)
168+
return f"{id}: {value}"
169+
return f"{self._row_key}: {request._asdict()[self._row_key]}"

0 commit comments

Comments
 (0)