1515# specific language governing permissions and limitations
1616# under the License.
1717import getpass
18+ import importlib
1819import logging
1920import socket
2021import time
22+ from collections import namedtuple
2123from types import TracebackType
2224from typing import (
2325 TYPE_CHECKING ,
3234)
3335from urllib .parse import urlparse
3436
35- from hive_metastore .ThriftHiveMetastore import Client
36- from hive_metastore .ttypes import (
37+ from hive_metastore .v3 . ThriftHiveMetastore import Client
38+ from hive_metastore .v3 . ttypes import (
3739 AlreadyExistsException ,
3840 CheckLockRequest ,
3941 EnvironmentContext ,
4042 FieldSchema ,
43+ GetTableRequest ,
44+ GetTablesRequest ,
4145 InvalidOperationException ,
4246 LockComponent ,
4347 LockLevel ,
5155 StorageDescriptor ,
5256 UnlockRequest ,
5357)
54- from hive_metastore .ttypes import Database as HiveDatabase
55- from hive_metastore .ttypes import Table as HiveTable
58+ from hive_metastore .v3 . ttypes import Database as HiveDatabase
59+ from hive_metastore .v3 . ttypes import Table as HiveTable
5660from tenacity import retry , retry_if_exception_type , stop_after_attempt , wait_exponential
5761from thrift .protocol import TBinaryProtocol
5862from thrift .transport import TSocket , TTransport
141145DEFAULT_LOCK_CHECK_RETRIES = 4
142146DO_NOT_UPDATE_STATS = "DO_NOT_UPDATE_STATS"
143147DO_NOT_UPDATE_STATS_DEFAULT = "true"
148+ HiveVersion = namedtuple ("HiveVersion" , "major minor patch" )
144149
145150logger = logging .getLogger (__name__ )
146151
@@ -150,6 +155,9 @@ class _HiveClient:
150155
151156 _transport : TTransport
152157 _ugi : Optional [List [str ]]
158+ _hive_version : HiveVersion = HiveVersion (4 , 0 , 0 )
159+ _hms_v3 : object
160+ _hms_v4 : object
153161
154162 def __init__ (
155163 self ,
@@ -163,9 +171,19 @@ def __init__(
163171 self ._kerberos_service_name = kerberos_service_name
164172 self ._ugi = ugi .split (":" ) if ugi else None
165173 self ._transport = self ._init_thrift_transport ()
174+ self .hms_v3 = importlib .import_module ("hive_metastore.v3.ThriftHiveMetastore" )
175+ self .hms_v4 = importlib .import_module ("hive_metastore.v4.ThriftHiveMetastore" )
176+ self ._hive_version = self ._get_hive_version ()
177+
def _get_hive_version(self) -> HiveVersion:
    """Ask the metastore for its version and parse it into a HiveVersion.

    Opens the client through its own context manager, calls ``getVersion()``
    and keeps the leading numeric part of the first three dot-separated
    components. Missing minor/patch components default to 0, and any
    non-numeric suffix (e.g. ``-SNAPSHOT``, ``-beta-1``) or additional
    vendor components are ignored.

    Returns:
        The parsed ``HiveVersion(major, minor, patch)``.

    Raises:
        ValueError: If the reported version does not start with a number.
    """
    import re  # local import: only needed for this one-off parse

    with self as open_client:
        reported = open_client.getVersion()

    numbers = []
    for component in str(reported).strip().split(".")[:3]:
        leading_digits = re.match(r"\d+", component.strip())
        if leading_digits is None:
            # Stop at the first component that carries no number at all.
            break
        numbers.append(int(leading_digits.group()))

    if not numbers:
        raise ValueError(f"Unable to parse Hive version: {reported!r}")

    # Pad so that e.g. "4" or "4.1" still yields a full (major, minor, patch).
    numbers.extend([0] * (3 - len(numbers)))
    return HiveVersion(*numbers)
166182
167183 def _init_thrift_transport (self ) -> TTransport :
168184 url_parts = urlparse (self ._uri )
185+ if not url_parts .hostname or not url_parts .port :
186+ raise ValueError ("hive hostname and port must be set" )
169187 socket = TSocket .TSocket (url_parts .hostname , url_parts .port )
170188 if not self ._kerberos_auth :
171189 return TTransport .TBufferedTransport (socket )
@@ -174,7 +192,8 @@ def _init_thrift_transport(self) -> TTransport:
174192
def _client(self) -> Client:
    """Create a Thrift metastore client on top of this object's transport.

    The client class comes from the v4 bindings when the detected metastore
    version is 4.0.1 or newer, and from the v3 bindings otherwise. When a
    UGI was configured, it is applied to the client before returning it.

    Returns:
        A ``Client`` instance speaking the binary protocol over
        ``self._transport``.
    """
    protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
    # Compare the whole (major, minor, patch) triple lexicographically.
    # Checking ``major >= 4 and patch > 0`` would send 4.1.0 or 5.0.0 back
    # to the v3 bindings because their patch component is 0.
    use_v4_bindings = tuple(self._hive_version) >= (4, 0, 1)
    hms = self.hms_v4 if use_v4_bindings else self.hms_v3
    client: Client = hms.Client(protocol)
    if self._ugi:
        client.set_ugi(*self._ugi)
    return client
@@ -387,11 +406,18 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
387406 except AlreadyExistsException as e :
388407 raise TableAlreadyExistsError (f"Table { hive_table .dbName } .{ hive_table .tableName } already exists" ) from e
389408
def _get_hive_table(self, open_client: Client, *, dbname: str, tbl_name: str) -> HiveTable:
    """Fetch a table definition from the metastore.

    Uses the request-object API (``get_table_req``) when the connected
    metastore reports version 4.0.1 or newer, and the positional
    ``get_table`` call otherwise.

    Args:
        open_client: An already opened metastore client.
        dbname: Name of the database that holds the table.
        tbl_name: Name of the table to fetch.

    Returns:
        The Hive table object returned by the metastore.

    Raises:
        NoSuchTableError: If the metastore raises ``NoSuchObjectException``.
    """
    # Lexicographic comparison of (major, minor, patch); a plain
    # ``major >= 4 and patch > 0`` test misclassifies 4.1.0 and 5.0.0.
    use_request_api = tuple(self._client._hive_version) >= (4, 0, 1)
    try:
        if use_request_api:
            return open_client.get_table_req(GetTableRequest(dbName=dbname, tblName=tbl_name)).table
        return open_client.get_table(dbname=dbname, tbl_name=tbl_name)
    except NoSuchObjectException as e:
        raise NoSuchTableError(f"Table does not exists: {tbl_name}") from e
416+
def _get_table_objects_by_name(self, open_client: Client, *, dbname: str, tbl_names: list[str]) -> list[HiveTable]:
    """Fetch several table definitions from one database in a single call.

    Uses the request-object API (``get_table_objects_by_name_req``) when the
    connected metastore reports version 4.0.1 or newer, and the positional
    ``get_table_objects_by_name`` call otherwise.

    Args:
        open_client: An already opened metastore client.
        dbname: Name of the database that holds the tables.
        tbl_names: Names of the tables to fetch.

    Returns:
        The Hive table objects returned by the metastore.
    """
    # Lexicographic comparison of (major, minor, patch); a plain
    # ``major >= 4 and patch > 0`` test misclassifies 4.1.0 and 5.0.0.
    if tuple(self._client._hive_version) >= (4, 0, 1):
        request = GetTablesRequest(dbName=dbname, tblNames=tbl_names)
        return open_client.get_table_objects_by_name_req(request).tables
    return open_client.get_table_objects_by_name(dbname=dbname, tbl_names=tbl_names)
395421
396422 def create_table (
397423 self ,
@@ -435,7 +461,7 @@ def create_table(
435461
436462 with self ._client as open_client :
437463 self ._create_hive_table (open_client , tbl )
438- hive_table = open_client . get_table ( dbname = database_name , tbl_name = table_name )
464+ hive_table = self . _get_hive_table ( open_client , dbname = database_name , tbl_name = table_name )
439465
440466 return self ._convert_hive_into_iceberg (hive_table )
441467
@@ -465,7 +491,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
465491 tbl = self ._convert_iceberg_into_hive (staged_table )
466492 with self ._client as open_client :
467493 self ._create_hive_table (open_client , tbl )
468- hive_table = open_client . get_table ( dbname = database_name , tbl_name = table_name )
494+ hive_table = self . _get_hive_table ( open_client , dbname = database_name , tbl_name = table_name )
469495
470496 return self ._convert_hive_into_iceberg (hive_table )
471497
@@ -538,7 +564,7 @@ def commit_table(
538564 hive_table : Optional [HiveTable ]
539565 current_table : Optional [Table ]
540566 try :
541- hive_table = self ._get_hive_table (open_client , database_name , table_name )
567+ hive_table = self ._get_hive_table (open_client , dbname = database_name , tbl_name = table_name )
542568 current_table = self ._convert_hive_into_iceberg (hive_table )
543569 except NoSuchTableError :
544570 hive_table = None
@@ -612,7 +638,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
612638 database_name , table_name = self .identifier_to_database_and_table (identifier , NoSuchTableError )
613639
614640 with self ._client as open_client :
615- hive_table = self ._get_hive_table (open_client , database_name , table_name )
641+ hive_table = self ._get_hive_table (open_client , dbname = database_name , tbl_name = table_name )
616642
617643 return self ._convert_hive_into_iceberg (hive_table )
618644
@@ -661,7 +687,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
661687
662688 try :
663689 with self ._client as open_client :
664- tbl = open_client . get_table ( dbname = from_database_name , tbl_name = from_table_name )
690+ tbl = self . _get_hive_table ( open_client , dbname = from_database_name , tbl_name = from_table_name )
665691 tbl .dbName = to_database_name
666692 tbl .tableName = to_table_name
667693 open_client .alter_table_with_environment_context (
@@ -733,8 +759,8 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
733759 with self ._client as open_client :
734760 return [
735761 (database_name , table .tableName )
736- for table in open_client . get_table_objects_by_name (
737- dbname = database_name , tbl_names = open_client .get_all_tables (db_name = database_name )
762+ for table in self . _get_table_objects_by_name (
763+ open_client , dbname = database_name , tbl_names = open_client .get_all_tables (db_name = database_name )
738764 )
739765 if table .parameters .get (TABLE_TYPE , "" ).lower () == ICEBERG
740766 ]
0 commit comments