Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""StarRocks source module"""

import re
import traceback
from typing import Dict, Iterable, List, Optional, Tuple, cast
Expand Down Expand Up @@ -41,6 +42,7 @@
STARROCKS_GET_TABLE_NAMES,
STARROCKS_PARTITION_DETAILS,
STARROCKS_SHOW_FULL_COLUMNS,
STARROCKS_TABLE_COMMENTS,
)
from metadata.utils.logger import ingestion_logger
from metadata.utils.ssl_manager import SSLManager, check_ssl_and_init
Expand Down Expand Up @@ -279,26 +281,21 @@ def query_view_names_and_types(

return tables

@staticmethod
def get_table_description(
schema_name: str, table_name: str, inspector: Inspector
self, schema_name: str, table_name: str, _inspector: Inspector
) -> Optional[str]:
description = None
try:
table_info: dict = inspector.get_table_comment(table_name, schema_name)
row = self.connection.execute(
sql.text(STARROCKS_TABLE_COMMENTS),
{"table_name": table_name, "schema": schema_name},
).fetchone()
return row[0] if row else None
except Exception as exc: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
logger.warning(
f"Table description error for table [{schema_name}.{table_name}]: {exc}"
)
else:
description = table_info.get("text")

if description is None:
return None
if isinstance(description, (list, tuple)) and len(description) > 0:
return description[0]
return description
return None

def _get_columns(self, table_name, schema=None):
"""Get column information and primary key columns of the specified table"""
Expand Down
28 changes: 27 additions & 1 deletion ingestion/tests/unit/topology/database/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from unittest import TestCase
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from sqlalchemy import types as sqltypes
Expand Down Expand Up @@ -114,6 +114,32 @@ def test_close_connection(self, engine, connection):
connection.return_value = True
self.starrocks_source.close()

def test_get_table_description_returns_comment(self):
mock_result = Mock()
mock_result.fetchone.return_value = ("table comment",)
mock_connection = Mock()
mock_connection.execute.return_value = mock_result
self.starrocks_source.connection = mock_connection

result = self.starrocks_source.get_table_description(
schema_name="public", table_name="my_table", _inspector=Mock()
)

self.assertEqual(result, "table comment")

def test_get_table_description_returns_none_when_missing(self):
mock_result = Mock()
mock_result.fetchone.return_value = None
mock_connection = Mock()
mock_connection.execute.return_value = mock_result
self.starrocks_source.connection = mock_connection

result = self.starrocks_source.get_table_description(
schema_name="public", table_name="my_table", _inspector=Mock()
)

self.assertIsNone(result)


class StarRocksSSLUnitTest(TestCase):
@patch(
Expand Down
Loading