Skip to content

Commit 5398ec0

Browse files
fix(hive): support metastore partition key extraction (#27029)
1 parent 5dbb3ea commit 5398ec0

2 files changed

Lines changed: 84 additions & 22 deletions

File tree

ingestion/src/metadata/ingestion/source/database/hive/metadata.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -185,21 +185,48 @@ def get_table_partition_details(
185185
partition_keys: List[str] = []
186186
in_partition_section = False
187187
try:
188-
with self.engine.connect() as conn:
189-
rows = conn.execute(
190-
text(f"DESCRIBE FORMATTED `{schema_name}`.`{table_name}`")
188+
drivername = getattr(getattr(self.engine, "url", None), "drivername", "")
189+
if drivername in {"hive+mysql", "hive+postgres"}:
190+
query = (
191+
"""
192+
SELECT pk.PKEY_NAME
193+
FROM PARTITION_KEYS pk
194+
JOIN TBLS tbl ON pk.TBL_ID = tbl.TBL_ID
195+
JOIN DBS db ON tbl.DB_ID = db.DB_ID
196+
WHERE db.NAME = :schema AND tbl.TBL_NAME = :table_name
197+
ORDER BY pk.INTEGER_IDX
198+
"""
199+
if drivername == "hive+mysql"
200+
else """
201+
SELECT pk."PKEY_NAME"
202+
FROM "PARTITION_KEYS" pk
203+
JOIN "TBLS" tbl ON pk."TBL_ID" = tbl."TBL_ID"
204+
JOIN "DBS" db ON tbl."DB_ID" = db."DB_ID"
205+
WHERE db."NAME" = :schema AND tbl."TBL_NAME" = :table_name
206+
ORDER BY pk."INTEGER_IDX"
207+
"""
191208
)
192-
for row in rows:
193-
col_name = row[0].strip() if row[0] else ""
194-
if col_name == "# Partition Information":
195-
in_partition_section = True
196-
continue
197-
if in_partition_section:
198-
if not col_name or col_name.startswith("# Detailed"):
199-
break
200-
if col_name.startswith("#"):
209+
rows = self.connection.execute(
210+
text(query),
211+
{"table_name": table_name, "schema": schema_name},
212+
).fetchall()
213+
partition_keys = [row[0] for row in rows if row and row[0]]
214+
else:
215+
with self.engine.connect() as conn:
216+
rows = conn.execute(
217+
text(f"DESCRIBE FORMATTED `{schema_name}`.`{table_name}`")
218+
)
219+
for row in rows:
220+
col_name = row[0].strip() if row[0] else ""
221+
if col_name == "# Partition Information":
222+
in_partition_section = True
201223
continue
202-
partition_keys.append(col_name)
224+
if in_partition_section:
225+
if not col_name or col_name.startswith("# Detailed"):
226+
break
227+
if col_name.startswith("#"):
228+
continue
229+
partition_keys.append(col_name)
203230
except Exception as exc:
204231
logger.debug(traceback.format_exc())
205232
logger.warning(

ingestion/tests/unit/topology/database/test_hive.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,15 @@ def __init__(
360360
self.thread_id = self.hive.context.get_current_thread_id()
361361
self.hive._inspector_map[self.thread_id] = types.SimpleNamespace()
362362

363-
self.hive._inspector_map[
364-
self.thread_id
365-
].get_pk_constraint = lambda table_name, schema_name: []
366-
self.hive._inspector_map[
367-
self.thread_id
368-
].get_unique_constraints = lambda table_name, schema_name: []
369-
self.hive._inspector_map[
370-
self.thread_id
371-
].get_foreign_keys = lambda table_name, schema_name: []
363+
self.hive._inspector_map[self.thread_id].get_pk_constraint = (
364+
lambda table_name, schema_name: []
365+
)
366+
self.hive._inspector_map[self.thread_id].get_unique_constraints = (
367+
lambda table_name, schema_name: []
368+
)
369+
self.hive._inspector_map[self.thread_id].get_foreign_keys = (
370+
lambda table_name, schema_name: []
371+
)
372372

373373
def test_yield_database(self):
374374
assert EXPECTED_DATABASE == [
@@ -518,6 +518,41 @@ def test_get_columns_deduplicates_partition_column_with_sentinel(self):
518518
)
519519
self.assertEqual(col_names, ["id", "name", "dt"])
520520

521+
def test_get_table_partition_details_from_metastore_mysql(self):
522+
self.hive.engine = types.SimpleNamespace(
523+
url=types.SimpleNamespace(drivername="hive+mysql")
524+
)
525+
mock_connection = Mock()
526+
mock_connection.execute.return_value.fetchall.return_value = [
527+
("dt",),
528+
("region",),
529+
]
530+
self.hive._connection_map[self.thread_id] = mock_connection
531+
532+
is_partitioned, partition = self.hive.get_table_partition_details(
533+
table_name="sample_table", schema_name="sample_schema", inspector=Mock()
534+
)
535+
536+
self.assertTrue(is_partitioned)
537+
self.assertIsNotNone(partition)
538+
self.assertEqual([c.columnName for c in partition.columns], ["dt", "region"])
539+
540+
def test_get_table_partition_details_from_metastore_postgres(self):
541+
self.hive.engine = types.SimpleNamespace(
542+
url=types.SimpleNamespace(drivername="hive+postgres")
543+
)
544+
mock_connection = Mock()
545+
mock_connection.execute.return_value.fetchall.return_value = [("dt",)]
546+
self.hive._connection_map[self.thread_id] = mock_connection
547+
548+
is_partitioned, partition = self.hive.get_table_partition_details(
549+
table_name="sample_table", schema_name="sample_schema", inspector=Mock()
550+
)
551+
552+
self.assertTrue(is_partitioned)
553+
self.assertIsNotNone(partition)
554+
self.assertEqual([c.columnName for c in partition.columns], ["dt"])
555+
521556
def test_ssl_connection_configuration(self):
522557
"""
523558
Test SSL configuration in Hive connection

0 commit comments

Comments (0)