1+ #
2+ # Licensed to the Apache Software Foundation (ASF) under one or more
3+ # contributor license agreements. See the NOTICE file distributed with
4+ # this work for additional information regarding copyright ownership.
5+ # The ASF licenses this file to You under the Apache License, Version 2.0
6+ # (the "License"); you may not use this file except in compliance with
7+ # the License. You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+ #
17+ import logging
18+ from collections .abc import Callable
19+ from enum import Enum
20+ from typing import Any
21+ from typing import Optional
22+
23+ from google .cloud .sql .connector import Connector
24+
25+ import apache_beam as beam
26+ from apache_beam .transforms .enrichment import EnrichmentSourceHandler
27+ from apache_beam .transforms .enrichment_handlers .utils import ExceptionLevel
28+
29+ __all__ = [
30+ 'CloudSQLEnrichmentHandler' ,
31+ ]
32+
33+ RowKeyFn = Callable [[beam .Row ], str ]
34+
35+ _LOGGER = logging .getLogger (__name__ )
36+
37+ class DatabaseTypeAdapter (Enum ):
38+ POSTGRESQL = "pg8000"
39+ MYSQL = "pymysql"
40+ SQLSERVER = "pytds"
41+
42+ def __str__ (self ):
43+ return self .value
44+
45+ class CloudSQLEnrichmentHandler (EnrichmentSourceHandler [beam .Row , beam .Row ]):
46+ """A handler for :class:`apache_beam.transforms.enrichment.Enrichment`
47+ transform to interact with Google Cloud SQL databases.
48+
49+ Args:
50+ project_id (str): GCP project-id of the Cloud SQL instance.
51+ region_id (str): GCP region-id of the Cloud SQL instance.
52+ instance_id (str): GCP instance-id of the Cloud SQL instance.
53+ database_type_adapter (DatabaseTypeAdapter): The type of database adapter to use.
54+ Supported adapters are: POSTGRESQL (pg8000), MYSQL (pymysql), and SQLSERVER (pytds).
55+ database_name (str): The name of the database to connect to.
56+ database_user (str): The username for connecting to the database.
57+ database_password (str): The password for connecting to the database.
58+ table_id (str): The name of the table to query.
59+ row_key (str): Field name from the input `beam.Row` object to use as
60+ identifier for database querying.
61+ row_key_fn: A lambda function that returns a string key from the
62+ input row. Used to build/extract the identifier for the database query.
63+ exception_level: A `enum.Enum` value from
64+ ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel``
65+ to set the level when no matching record is found from the database query.
66+ Defaults to ``ExceptionLevel.WARN``.
67+ """
68+ def __init__ (
69+ self ,
70+ region_id : str ,
71+ project_id : str ,
72+ instance_id : str ,
73+ database_type_adapter : DatabaseTypeAdapter ,
74+ database_name : str ,
75+ database_user : str ,
76+ database_password : str ,
77+ table_id : str ,
78+ row_key : str = "" ,
79+ * ,
80+ row_key_fn : Optional [RowKeyFn ] = None ,
81+ exception_level : ExceptionLevel = ExceptionLevel .WARN ,
82+ ):
83+ self ._project_id = project_id
84+ self ._region_id = region_id
85+ self ._instance_id = instance_id
86+ self ._database_type_adapter = database_type_adapter
87+ self ._database_name = database_name
88+ self ._database_user = database_user
89+ self ._database_password = database_password
90+ self ._table_id = table_id
91+ self ._row_key = row_key
92+ self ._row_key_fn = row_key_fn
93+ self ._exception_level = exception_level
94+ if ((not self ._row_key_fn and not self ._row_key ) or
95+ bool (self ._row_key_fn and self ._row_key )):
96+ raise ValueError (
97+ "Please specify exactly one of `row_key` or a lambda "
98+ "function with `row_key_fn` to extract the row key "
99+ "from the input row." )
100+
101+ def __enter__ (self ):
102+ """Connect to the the Cloud SQL instance."""
103+ connector = Connector ()
104+ self .client = connector .connect (
105+ f"{ self ._project_id } :{ self ._region_id } :{ self ._instance_id } " ,
106+ driver = self ._database_type_adapter .value ,
107+ user = self ._database_user ,
108+ password = self ._database_password ,
109+ db = self ._database_name
110+ )
111+ self .cursor = self .client .cursor ()
112+
113+ def __call__ (self , request : beam .Row , * args , ** kwargs ):
114+ """
115+ Executes a query to the Cloud SQL instance and returns
116+ a `Tuple` of request and response.
117+
118+ Args:
119+ request: the input `beam.Row` to enrich.
120+ """
121+ response_dict : dict [str , Any ] = {}
122+ row_key_str : str = ""
123+
124+ try :
125+ if self ._row_key_fn :
126+ row_key = self ._row_key_fn (request )
127+ else :
128+ request_dict = request ._asdict ()
129+ row_key_str = str (request_dict [self ._row_key ])
130+ row_key = row_key_str
131+
132+ query = f"SELECT * FROM { self ._table_id } WHERE { self ._row_key } = %s"
133+ self .cursor .execute (query , (row_key ,))
134+ result = self .cursor .fetchone ()
135+
136+ if result :
137+ columns = [col [0 ] for col in self .cursor .description ]
138+ for i , value in enumerate (result ):
139+ response_dict [columns [i ]] = value
140+ elif self ._exception_level == ExceptionLevel .WARN :
141+ _LOGGER .warning (
142+ 'No matching record found for row_key: %s in table: %s' ,
143+ row_key_str , self ._table_id )
144+ elif self ._exception_level == ExceptionLevel .RAISE :
145+ raise ValueError (
146+ 'No matching record found for row_key: %s in table: %s' %
147+ (row_key_str , self ._table_id ))
148+ except KeyError :
149+ raise KeyError ('row_key %s not found in input PCollection.' % row_key_str )
150+ except Exception as e :
151+ raise e
152+
153+ return request , beam .Row (** response_dict )
154+
155+ def __exit__ (self , exc_type , exc_val , exc_tb ):
156+ """Clean the instantiated the Cloud SQL client."""
157+ self .cursor .close ()
158+ self .client .close ()
159+ self .cursor , self .client = None , None
160+
161+ def get_cache_key (self , request : beam .Row ) -> str :
162+ """Returns a string formatted with row key since it is unique to
163+ a request made to the Cloud SQL instance."""
164+ if self ._row_key_fn :
165+ return f"row_key: { str (self ._row_key_fn (request ))} "
166+ return f"{ self ._row_key } : { request ._asdict ()[self ._row_key ]} "
0 commit comments