@@ -360,15 +360,15 @@ def __init__(
360360 self .thread_id = self .hive .context .get_current_thread_id ()
361361 self .hive ._inspector_map [self .thread_id ] = types .SimpleNamespace ()
362362
363- self .hive ._inspector_map [
364- self . thread_id
365- ]. get_pk_constraint = lambda table_name , schema_name : []
366- self .hive ._inspector_map [
367- self . thread_id
368- ]. get_unique_constraints = lambda table_name , schema_name : []
369- self .hive ._inspector_map [
370- self . thread_id
371- ]. get_foreign_keys = lambda table_name , schema_name : []
363+ self .hive ._inspector_map [self . thread_id ]. get_pk_constraint = (
364+ lambda table_name , schema_name : []
365+ )
366+ self .hive ._inspector_map [self . thread_id ]. get_unique_constraints = (
367+ lambda table_name , schema_name : []
368+ )
369+ self .hive ._inspector_map [self . thread_id ]. get_foreign_keys = (
370+ lambda table_name , schema_name : []
371+ )
372372
373373 def test_yield_database (self ):
374374 assert EXPECTED_DATABASE == [
@@ -518,6 +518,41 @@ def test_get_columns_deduplicates_partition_column_with_sentinel(self):
518518 )
519519 self .assertEqual (col_names , ["id" , "name" , "dt" ])
520520
521+ def test_get_table_partition_details_from_metastore_mysql (self ):
522+ self .hive .engine = types .SimpleNamespace (
523+ url = types .SimpleNamespace (drivername = "hive+mysql" )
524+ )
525+ mock_connection = Mock ()
526+ mock_connection .execute .return_value .fetchall .return_value = [
527+ ("dt" ,),
528+ ("region" ,),
529+ ]
530+ self .hive ._connection_map [self .thread_id ] = mock_connection
531+
532+ is_partitioned , partition = self .hive .get_table_partition_details (
533+ table_name = "sample_table" , schema_name = "sample_schema" , inspector = Mock ()
534+ )
535+
536+ self .assertTrue (is_partitioned )
537+ self .assertIsNotNone (partition )
538+ self .assertEqual ([c .columnName for c in partition .columns ], ["dt" , "region" ])
539+
540+ def test_get_table_partition_details_from_metastore_postgres (self ):
541+ self .hive .engine = types .SimpleNamespace (
542+ url = types .SimpleNamespace (drivername = "hive+postgres" )
543+ )
544+ mock_connection = Mock ()
545+ mock_connection .execute .return_value .fetchall .return_value = [("dt" ,)]
546+ self .hive ._connection_map [self .thread_id ] = mock_connection
547+
548+ is_partitioned , partition = self .hive .get_table_partition_details (
549+ table_name = "sample_table" , schema_name = "sample_schema" , inspector = Mock ()
550+ )
551+
552+ self .assertTrue (is_partitioned )
553+ self .assertIsNotNone (partition )
554+ self .assertEqual ([c .columnName for c in partition .columns ], ["dt" ])
555+
521556 def test_ssl_connection_configuration (self ):
522557 """
523558 Test SSL configuration in Hive connection
0 commit comments