1919import logging
2020import socket
2121import time
22+ from collections import namedtuple
2223from types import TracebackType
2324from typing import (
2425 TYPE_CHECKING ,
148149DEFAULT_LOCK_CHECK_RETRIES = 4
149150DO_NOT_UPDATE_STATS = "DO_NOT_UPDATE_STATS"
150151DO_NOT_UPDATE_STATS_DEFAULT = "true"
152+ HiveVersion = namedtuple ("HiveVersion" , "major minor patch" )
151153
152154logger = logging .getLogger (__name__ )
153155
@@ -157,7 +159,7 @@ class _HiveClient:
157159
158160 _transport : TTransport
159161 _ugi : Optional [List [str ]]
160- _hive_version : int = 4
162+ _hive_version : HiveVersion = HiveVersion ( 4 , 0 , 0 )
161163 _hms_v3 : object
162164 _hms_v4 : object
163165
@@ -177,10 +179,10 @@ def __init__(
177179 self .hms_v4 = importlib .import_module ("hive_metastore.v4.ThriftHiveMetastore" )
178180 self ._hive_version = self ._get_hive_version ()
179181
180- def _get_hive_version (self ) -> int :
182+ def _get_hive_version (self ) -> HiveVersion :
181183 with self as open_client :
182- major , * _ = open_client .getVersion ().split ("." )
183- return int ( major )
184+ version = map ( int , open_client .getVersion ().split ("." ) )
185+ return HiveVersion ( * version )
184186
185187 def _init_thrift_transport (self ) -> TTransport :
186188 url_parts = urlparse (self ._uri )
@@ -192,7 +194,7 @@ def _init_thrift_transport(self) -> TTransport:
192194
193195 def _client (self ) -> Client :
194196 protocol = TBinaryProtocol .TBinaryProtocol (self ._transport )
195- hms = self .hms_v3 if self ._hive_version < 4 else self .hms_v4
197+ hms = self .hms_v4 if all (( self ._hive_version . major >= 4 , self . _hive_version . minor > 0 )) else self .hms_v3
196198 client : Client = hms .Client (protocol )
197199 if self ._ugi :
198200 client .set_ugi (* self ._ugi )
@@ -407,14 +409,14 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
407409 raise TableAlreadyExistsError (f"Table { hive_table .dbName } .{ hive_table .tableName } already exists" ) from e
408410
409411 def _get_hive_table (self , open_client : Client , * , dbname : str , tbl_name : str ) -> HiveTable :
410- if self ._client ._hive_version < 4 :
411- return open_client .get_table ( dbname = dbname , tbl_name = tbl_name )
412- return open_client .get_table_req ( GetTableRequest ( dbName = dbname , tblName = tbl_name )). table
412+ if all (( self ._client ._hive_version . major >= 4 , self . _client . _hive_version . minor > 0 )) :
413+ return open_client .get_table_req ( GetTableRequest ( dbName = dbname , tblName = tbl_name )). table
414+ return open_client .get_table ( dbname = dbname , tbl_name = tbl_name )
413415
414416 def _get_table_objects_by_name (self , open_client : Client , * , dbname : str , tbl_names : list [str ]) -> list [HiveTable ]:
415- if self ._client ._hive_version < 4 :
416- return open_client .get_table_objects_by_name ( dbname = dbname , tbl_names = tbl_names )
417- return open_client .get_table_objects_by_name_req ( GetTablesRequest ( dbName = dbname , tblNames = tbl_names )). tables
417+ if all (( self ._client ._hive_version . major >= 4 , self . _client . _hive_version . minor > 0 )) :
418+ return open_client .get_table_objects_by_name_req ( GetTablesRequest ( dbName = dbname , tblNames = tbl_names )). tables
419+ return open_client .get_table_objects_by_name ( dbname = dbname , tbl_names = tbl_names )
418420
419421 def create_table (
420422 self ,
0 commit comments