diff --git a/ingestion/src/metadata/ingestion/source/database/starrocks/metadata.py b/ingestion/src/metadata/ingestion/source/database/starrocks/metadata.py
index 0548b6851f7c..59cd63a06e86 100644
--- a/ingestion/src/metadata/ingestion/source/database/starrocks/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/starrocks/metadata.py
@@ -43,6 +43,7 @@
     STARROCKS_PARTITION_DETAILS,
     STARROCKS_SHOW_FULL_COLUMNS,
 )
+from metadata.ingestion.source.database.starrocks.utils import get_table_comment
 from metadata.utils.logger import ingestion_logger
 from metadata.utils.ssl_manager import SSLManager, check_ssl_and_init
 
@@ -258,21 +259,16 @@ def query_view_names_and_types(self, schema_name: str) -> Iterable[TableNameAndT
 
         return tables
 
-    @staticmethod
-    def get_table_description(schema_name: str, table_name: str, inspector: Inspector) -> Optional[str]:
+    # pylint: disable=arguments-differ,unused-argument
+    def get_table_description(self, schema_name: str, table_name: str, inspector: Inspector) -> Optional[str]:
         description = None
         try:
-            table_info: dict = inspector.get_table_comment(table_name, schema_name)
+            table_info = get_table_comment(None, self.connection, table_name, schema=schema_name)  # None fills `_`
         except Exception as exc:  # pylint: disable=broad-except
             logger.debug(traceback.format_exc())
             logger.warning(f"Table description error for table [{schema_name}.{table_name}]: {exc}")
         else:
             description = table_info.get("text")
-
-        if description is None:
-            return None
-        if isinstance(description, (list, tuple)) and len(description) > 0:
-            return description[0]
         return description
 
     def _get_columns(self, table_name, schema=None):
diff --git a/ingestion/src/metadata/ingestion/source/database/starrocks/utils.py b/ingestion/src/metadata/ingestion/source/database/starrocks/utils.py
index 0025c307392b..a136b00eb64d 100644
--- a/ingestion/src/metadata/ingestion/source/database/starrocks/utils.py
+++ b/ingestion/src/metadata/ingestion/source/database/starrocks/utils.py
@@ -42,7 +42,7 @@ def get_table_comment(_, connection, table_name, schema=None, **kw):
         {"table_name": table_name, "schema": schema},
         **kw,
     )
-    for table_comment in rows:
-        comment = table_comment
+    for table_comment in rows.mappings():  # iterate rows addressable by column name
+        comment = table_comment["TABLE_COMMENT"]  # keep only the comment column, not the whole row
         break
     return {"text": comment}
diff --git a/ingestion/tests/unit/topology/database/test_starrocks.py b/ingestion/tests/unit/topology/database/test_starrocks.py
index 60e9030bb93e..71f3ef831860 100644
--- a/ingestion/tests/unit/topology/database/test_starrocks.py
+++ b/ingestion/tests/unit/topology/database/test_starrocks.py
@@ -14,7 +14,7 @@
 """
 
 from unittest import TestCase
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from sqlalchemy import types as sqltypes
@@ -26,6 +26,7 @@
     StarRocksSource,
     _get_sqlalchemy_type,
 )
+from metadata.ingestion.source.database.starrocks.utils import get_table_comment
 
 mock_starrocks_config = {
     "source": {
@@ -171,3 +172,79 @@ def test_iceberg_relkind_mapping(self):
         from metadata.ingestion.source.database.starrocks.metadata import RELKIND_MAP
 
         assert RELKIND_MAP["ICEBERG"] == TableType.Iceberg
+
+
+class TestStarRocksGetTableDescription(TestCase):
+    """Tests for get_table_description delegating to utils.get_table_comment"""
+
+    @patch("metadata.ingestion.source.database.common_db_source.CommonDbSourceService.test_connection")
+    def setUp(self, test_connection):
+        test_connection.return_value = False
+        self.config = OpenMetadataWorkflowConfig.model_validate(mock_starrocks_config)
+        self.source = StarRocksSource.create(
+            mock_starrocks_config["source"],
+            self.config.workflowConfig.openMetadataServerConfig,
+        )
+        thread_id = self.source.context.get_current_thread_id()
+        self.source._connection_map[thread_id] = MagicMock()  # NOTE(review): presumably backs self.source.connection
+
+    @patch("metadata.ingestion.source.database.starrocks.metadata.get_table_comment")
+    def test_returns_table_comment(self, mock_get_table_comment):
+        mock_get_table_comment.return_value = {"text": "审计日志表"}
+
+        description = self.source.get_table_description(
+            schema_name="test_db",
+            table_name="audit_tbl",
+            inspector=MagicMock(),
+        )
+        assert description == "审计日志表"
+        mock_get_table_comment.assert_called_once_with(None, self.source.connection, "audit_tbl", schema="test_db")
+
+    @patch("metadata.ingestion.source.database.starrocks.metadata.get_table_comment")
+    def test_returns_none_for_empty_comment(self, mock_get_table_comment):
+        mock_get_table_comment.return_value = {"text": None}
+
+        description = self.source.get_table_description(
+            schema_name="test_db",
+            table_name="no_comment_tbl",
+            inspector=MagicMock(),
+        )
+        assert description is None
+
+    @patch("metadata.ingestion.source.database.starrocks.metadata.get_table_comment")
+    def test_returns_none_on_exception(self, mock_get_table_comment):
+        mock_get_table_comment.side_effect = Exception("connection error")
+
+        description = self.source.get_table_description(
+            schema_name="test_db",
+            table_name="error_tbl",
+            inspector=MagicMock(),
+        )
+        assert description is None
+
+
+class TestGetTableComment(TestCase):
+    """Tests for utils.get_table_comment row parsing"""
+
+    def _make_connection(self, rows):
+        connection = MagicMock()
+        connection.info = {}  # NOTE(review): presumably consulted by get_table_comment; confirm
+        result = MagicMock()
+        result.mappings.return_value = iter(rows)  # stands in for the rows.mappings() iteration
+        connection.execute.return_value = result
+        return connection
+
+    def test_returns_comment_from_row(self):
+        row = {"TABLE_COMMENT": "审计日志表"}
+        connection = self._make_connection([row])
+
+        result = get_table_comment(None, connection, "audit_tbl", schema="test_db")
+
+        assert result == {"text": "审计日志表"}
+
+    def test_returns_none_when_no_rows(self):
+        connection = self._make_connection([])
+
+        result = get_table_comment(None, connection, "missing_tbl", schema="test_db")
+
+        assert result == {"text": None}