diff --git a/ingestion/src/metadata/domain/tags/canonicalizer.py b/ingestion/src/metadata/domain/tags/canonicalizer.py
index e532873d7422..cd47b166c77f 100644
--- a/ingestion/src/metadata/domain/tags/canonicalizer.py
+++ b/ingestion/src/metadata/domain/tags/canonicalizer.py
@@ -28,7 +28,9 @@
wait_random_exponential,
)
-from metadata.generated.schema.entity.classification.classification import Classification
+from metadata.generated.schema.entity.classification.classification import (
+ Classification,
+)
from metadata.generated.schema.entity.classification.tag import Tag
from metadata.generated.schema.type.basic import ProviderType
from metadata.ingestion.ometa.ometa_api import OpenMetadata
@@ -87,10 +89,15 @@ def classification(
results = self._es_search(Classification, name)
canonical = Canonical(name=name, description=default_description)
for entity in results:
- if entity.provider == ProviderType.system and entity.name.root.lower() == key:
+ if (
+ entity.provider == ProviderType.system
+ and entity.name.root.lower() == key
+ ):
canonical = Canonical(
name=entity.name.root,
- description=entity.description.root if entity.description else default_description,
+ description=entity.description.root
+ if entity.description
+ else default_description,
)
break
@@ -113,7 +120,9 @@ def tag(
"""
tag_fqn = cast(
"str",
- fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name),
+ fqn.build(
+ None, Tag, classification_name=classification_name, tag_name=tag_name
+ ),
)
key = tag_fqn.lower()
with self._lock:
@@ -131,7 +140,9 @@ def tag(
):
canonical = Canonical(
name=entity.name.root,
- description=entity.description.root if entity.description else default_tag_description,
+ description=entity.description.root
+ if entity.description
+ else default_tag_description,
)
break
@@ -142,4 +153,9 @@ def tag(
@_es_retry
def _es_search(self, entity_type: Any, search_string: str) -> Iterable[Any]:
"""Run an ES search by FQN with retries."""
- return self._metadata.es_search_from_fqn(entity_type=entity_type, fqn_search_string=search_string) or []
+ return (
+ self._metadata.es_search_from_fqn(
+ entity_type=entity_type, fqn_search_string=search_string
+ )
+ or []
+ )
diff --git a/ingestion/src/metadata/domain/tags/registry.py b/ingestion/src/metadata/domain/tags/registry.py
index 28973d6a8c9c..5b35aa0c9a22 100644
--- a/ingestion/src/metadata/domain/tags/registry.py
+++ b/ingestion/src/metadata/domain/tags/registry.py
@@ -82,14 +82,24 @@ def __init__(self, metadata: OpenMetadata) -> None:
self._lock = threading.Lock()
def _intern_tag_label_locked(
- self, *, classification_name: str, tag_name: str, label_type: LabelType, state: State
+ self,
+ *,
+ classification_name: str,
+ tag_name: str,
+ label_type: LabelType,
+ state: State,
) -> TagLabel:
"""Return the shared ``TagLabel`` for the given key. Caller must hold ``self._lock``."""
key = _TagLabelKey(classification_name, tag_name, label_type, state)
cached = self._tag_label_cache.get(key)
if cached is not None:
return cached
- tag_fqn = cast("str", fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name))
+ tag_fqn = cast(
+ "str",
+ fqn.build(
+ None, Tag, classification_name=classification_name, tag_name=tag_name
+ ),
+ )
cached = TagLabel( # pyright: ignore[reportCallIssue]
tagFQN=TagFQN(tag_fqn),
labelType=label_type,
@@ -113,7 +123,10 @@ def attach(
) -> None:
"""Register a tag <-> entity association."""
if not tag_name or not tag_name.strip():
- logger.debug("TagRegistry: skipping empty tag for classification %s", classification_name)
+ logger.debug(
+ "TagRegistry: skipping empty tag for classification %s",
+ classification_name,
+ )
return
with self._lock:
@@ -164,11 +177,19 @@ def clear_scope(self, scope_fqn: str) -> None:
with self._lock:
self._cleared_scopes.add(scope_fqn)
- kept = {k: v for k, v in self._labels_by_entity.items() if k != scope_fqn and not k.startswith(prefix)}
+ kept = {
+ k: v
+ for k, v in self._labels_by_entity.items()
+ if k != scope_fqn and not k.startswith(prefix)
+ }
dropped = len(self._labels_by_entity) - len(kept)
self._labels_by_entity = kept
if dropped:
- logger.debug("TagRegistry: cleared scope %s (%d entity labels dropped)", scope_fqn, dropped)
+ logger.debug(
+ "TagRegistry: cleared scope %s (%d entity labels dropped)",
+ scope_fqn,
+ dropped,
+ )
def is_known(self, tag_fqn: str) -> bool:
"""Return True if the tag FQN has been recorded (case-sensitive match)."""
@@ -183,7 +204,9 @@ def ensure_known(self, tag_fqn: str) -> bool:
if self.is_known(tag_fqn):
return True
- logger.debug("TagRegistry: cache miss for %s; fetching from OpenMetadata.", tag_fqn)
+ logger.debug(
+ "TagRegistry: cache miss for %s; fetching from OpenMetadata.", tag_fqn
+ )
try:
entity = self._metadata.get_by_name(entity=Tag, fqn=tag_fqn)
except Exception:
@@ -192,7 +215,8 @@ def ensure_known(self, tag_fqn: str) -> bool:
if entity is None:
logger.warning(
- "TagRegistry: tag %s not found in OpenMetadata; labels referencing it will be skipped.", tag_fqn
+ "TagRegistry: tag %s not found in OpenMetadata; labels referencing it will be skipped.",
+ tag_fqn,
)
return False
diff --git a/ingestion/src/metadata/ingestion/models/topology.py b/ingestion/src/metadata/ingestion/models/topology.py
index 56d72efcf4cb..bbb67a6949b0 100644
--- a/ingestion/src/metadata/ingestion/models/topology.py
+++ b/ingestion/src/metadata/ingestion/models/topology.py
@@ -14,7 +14,16 @@
import queue
import threading
from functools import cache, singledispatchmethod
-from typing import Annotated, Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035
+from typing import ( # noqa: UP035
+ Annotated,
+ Any,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Type,
+ TypeVar,
+)
from pydantic import BaseModel, ConfigDict, Field, create_model
@@ -126,7 +135,9 @@ class TopologyNode(BaseModel):
] = None
threads: Annotated[
bool,
- Field(description="Flag that defines if a node is open to MultiThreading processing."),
+ Field(
+ description="Flag that defines if a node is open to MultiThreading processing."
+ ),
] = False
diff --git a/ingestion/src/metadata/ingestion/source/dashboard/powerbi/metadata.py b/ingestion/src/metadata/ingestion/source/dashboard/powerbi/metadata.py
index f841e5e11e15..64dcfd701098 100644
--- a/ingestion/src/metadata/ingestion/source/dashboard/powerbi/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/dashboard/powerbi/metadata.py
@@ -792,7 +792,9 @@ def create_report_dashboard_lineage(
) -> Iterable[Either[CreateDashboardRequest]]:
"""Create lineage between report and dashboard"""
try:
- logger.debug(f"Processing to create report and dashboard lineage for dashboard: {dashboard_details.id}")
+ logger.debug(
+ f"Processing to create report and dashboard lineage for dashboard: {dashboard_details.id}"
+ )
charts = dashboard_details.tiles
dashboard_fqn = fqn.build(
self.metadata,
@@ -811,9 +813,13 @@ def create_report_dashboard_lineage(
return
for chart in charts or []:
if chart.reportId:
- logger.debug(f"Dashboard's chart {chart.id} is linked with report id: {str(chart.reportId)}") # noqa: RUF010
+ logger.debug(
+ f"Dashboard's chart {chart.id} is linked with report id: {str(chart.reportId)}"
+ ) # noqa: RUF010
else:
- logger.debug(f"Dashboard's chart {chart.id} is not linked with any report")
+ logger.debug(
+ f"Dashboard's chart {chart.id} is not linked with any report"
+ )
continue
report = self._fetch_report_from_workspace(chart.reportId)
if report:
@@ -838,7 +844,9 @@ def create_report_dashboard_lineage(
logger.debug(
f"Creating lineage between report={report.id} and dashboard={dashboard_details.id}"
)
- yield self._get_add_lineage_request(to_entity=dashboard_entity, from_entity=report_entity)
+ yield self._get_add_lineage_request(
+ to_entity=dashboard_entity, from_entity=report_entity
+ )
else:
logger.debug(
f"Could not fetch report with report id: {str(chart.reportId)} from workspace data to create lineage with dashboard: {dashboard_details.id}" # noqa: RUF010
@@ -873,7 +881,9 @@ def _get_dataset_ids_from_report_datasources(self, report_id: str) -> List[str]:
if match:
dataset_ids.append(match.group(1))
if dataset_ids:
- logger.debug(f"Extracted dataset IDs from report datasources API call for report_id={report_id}")
+ logger.debug(
+ f"Extracted dataset IDs from report datasources API call for report_id={report_id}"
+ )
return dataset_ids
def create_datamodel_report_lineage(
@@ -885,7 +895,9 @@ def create_datamodel_report_lineage(
create the lineage between datamodel and report
"""
try:
- logger.debug(f"Processing to create datamodel and report lineage for report: {dashboard_details.id}")
+ logger.debug(
+ f"Processing to create datamodel and report lineage for report: {dashboard_details.id}"
+ )
report_fqn = fqn.build(
self.metadata,
entity_type=Dashboard,
@@ -903,13 +915,17 @@ def create_datamodel_report_lineage(
return
dataset_ids = []
if dashboard_details.datasetId:
- logger.debug(f"Report linked datasetId is present in api response for report: {dashboard_details.id}")
+ logger.debug(
+ f"Report linked datasetId is present in api response for report: {dashboard_details.id}"
+ )
dataset_ids = [dashboard_details.datasetId]
else:
logger.debug(
f"Processing to get report datasources from API to extract datasetIds for report: {dashboard_details.id} as datasetId is not present in api response"
)
- dataset_ids = self._get_dataset_ids_from_report_datasources(report_id=dashboard_details.id)
+ dataset_ids = self._get_dataset_ids_from_report_datasources(
+ report_id=dashboard_details.id
+ )
if dataset_ids:
for dataset_id in dataset_ids:
@@ -2365,7 +2381,9 @@ def _fetch_dataset_from_workspace(
return dataset_data
return None
- def _fetch_report_from_workspace(self, report_id: Optional[str]) -> Optional[PowerBIReport]: # noqa: UP045
+ def _fetch_report_from_workspace(
+ self, report_id: Optional[str]
+ ) -> Optional[PowerBIReport]: # noqa: UP045
"""
Method to search the report using id in the workspace dict
"""
diff --git a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py
index 7400f35643d8..00399726a542 100644
--- a/ingestion/src/metadata/ingestion/source/database/athena/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/athena/metadata.py
@@ -166,7 +166,9 @@ def prepare(self):
def get_schema_description(self, schema_name: str) -> Optional[str]:
return self.schema_description_map.get(schema_name)
- def query_table_names_and_types(self, schema_name: str) -> Iterable[TableNameAndType]:
+ def query_table_names_and_types(
+ self, schema_name: str
+ ) -> Iterable[TableNameAndType]:
"""Return tables with proper type detection using a single Glue API pass."""
if self.glue_client:
try:
@@ -176,13 +178,19 @@ def query_table_names_and_types(self, schema_name: str) -> Iterable[TableNameAnd
for table in page.get("TableList", []):
params = table.get("Parameters", {})
table_type = (
- TableType.Iceberg if params.get("table_type") == ICEBERG_TABLE_TYPE else TableType.External
+ TableType.Iceberg
+ if params.get("table_type") == ICEBERG_TABLE_TYPE
+ else TableType.External
+ )
+ results.append(
+ TableNameAndType(name=table["Name"], type_=table_type)
)
- results.append(TableNameAndType(name=table["Name"], type_=table_type))
return results # noqa: TRY300
except Exception as exc:
logger.debug(traceback.format_exc())
- logger.warning(f"Failed to fetch Glue table metadata for schema [{schema_name}]: {exc}")
+ logger.warning(
+ f"Failed to fetch Glue table metadata for schema [{schema_name}]: {exc}"
+ )
return [
TableNameAndType(name=name, type_=TableType.External)
for name in self.inspector.get_table_names(schema_name)
@@ -335,9 +343,9 @@ def get_table_description(
try:
table_info: dict = inspector.get_table_comment(table_name, schema_name)
table_option = inspector.get_table_options(table_name, schema_name)
- self.external_location_map[(self.context.get().database, schema_name, table_name)] = table_option.get(
- "awsathena_location"
- )
+ self.external_location_map[
+ (self.context.get().database, schema_name, table_name)
+ ] = table_option.get("awsathena_location")
# Catch any exception without breaking the ingestion
except Exception as exc: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
@@ -368,7 +376,9 @@ def _get_columns_internal(
glue_client=self.glue_client,
)
- def get_table_extensions(self, table_name: str, table_type: TableType | None = None) -> dict[str, str] | None:
+ def get_table_extensions(
+ self, table_name: str, table_type: TableType | None = None
+ ) -> dict[str, str] | None:
if not getattr(self.source_config, "includeCustomProperties", False):
return None
if not self._string_property_type_ref:
@@ -383,9 +393,13 @@ def get_table_extensions(self, table_name: str, table_type: TableType | None = N
for prop_name, prop_value in tbl_properties.items():
if not prop_value:
continue
- sanitized_name = PROPERTY_NAME_INVALID_CHARS_PATTERN.sub(PROPERTY_NAME_REPLACEMENT, prop_name)
+ sanitized_name = PROPERTY_NAME_INVALID_CHARS_PATTERN.sub(
+ PROPERTY_NAME_REPLACEMENT, prop_name
+ )
if len(sanitized_name) > PROPERTY_NAME_MAX_LENGTH:
- sanitized_name = hashlib.md5(prop_name.encode("utf-8"), usedforsecurity=False).hexdigest()
+ sanitized_name = hashlib.md5(
+ prop_name.encode("utf-8"), usedforsecurity=False
+ ).hexdigest()
if sanitized_name not in self._processed_prop:
try:
self.metadata.create_or_update_custom_property( # pyright: ignore[reportUnknownMemberType, reportUnusedCallResult]
@@ -410,14 +424,24 @@ def get_table_extensions(self, table_name: str, table_type: TableType | None = N
registered_properties[sanitized_name] = prop_value
return registered_properties or None
- def _fetch_iceberg_properties(self, schema_name: str, table_name: str) -> dict[str, str]:
+ def _fetch_iceberg_properties(
+ self, schema_name: str, table_name: str
+ ) -> dict[str, str]:
"""Read Iceberg native properties from Athena's `
$properties` metatable."""
- query = text(f'SELECT key, value FROM "{schema_name}"."{table_name}$properties"')
+ query = text(
+ f'SELECT key, value FROM "{schema_name}"."{table_name}$properties"'
+ )
try:
with self.engine.connect() as conn:
result = conn.execute(query)
- return {str(row[0]): str(row[1]) for row in result if row[0] is not None and row[1] is not None}
+ return {
+ str(row[0]): str(row[1])
+ for row in result
+ if row[0] is not None and row[1] is not None
+ }
except Exception as exc:
- logger.debug(f"Unable to read Iceberg $properties for [{schema_name}.{table_name}]: {exc}")
+ logger.debug(
+ f"Unable to read Iceberg $properties for [{schema_name}.{table_name}]: {exc}"
+ )
logger.debug(traceback.format_exc())
return {}
diff --git a/ingestion/src/metadata/ingestion/source/database/common_db_source.py b/ingestion/src/metadata/ingestion/source/database/common_db_source.py
index 287303771cce..f73aceba5f85 100644
--- a/ingestion/src/metadata/ingestion/source/database/common_db_source.py
+++ b/ingestion/src/metadata/ingestion/source/database/common_db_source.py
@@ -385,12 +385,16 @@ def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
try:
table_iter = self.query_table_names_and_types(schema_name)
except Exception as err:
- logger.warning(f"Fetching table list failed for schema {schema_name} due to - {err}")
+ logger.warning(
+ f"Fetching table list failed for schema {schema_name} due to - {err}"
+ )
logger.debug(traceback.format_exc())
table_iter = []
for table_and_type in table_iter:
try:
- table_name = self.standardize_table_name(schema_name, table_and_type.name)
+ table_name = self.standardize_table_name(
+ schema_name, table_and_type.name
+ )
table_fqn = fqn.build(
self.metadata,
entity_type=Table,
@@ -414,7 +418,9 @@ def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
)
continue
except Exception as err:
- logger.warning(f"Skipping table {table_and_type.name!r} in schema {schema_name} due to - {err}")
+ logger.warning(
+ f"Skipping table {table_and_type.name!r} in schema {schema_name} due to - {err}"
+ )
logger.debug(traceback.format_exc())
continue
yield table_name, table_and_type.type_
@@ -423,12 +429,16 @@ def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
try:
view_iter = self.query_view_names_and_types(schema_name)
except Exception as err:
- logger.warning(f"Fetching view list failed for schema {schema_name} due to - {err}")
+ logger.warning(
+ f"Fetching view list failed for schema {schema_name} due to - {err}"
+ )
logger.debug(traceback.format_exc())
view_iter = []
for view_and_type in view_iter:
try:
- view_name = self.standardize_table_name(schema_name, view_and_type.name)
+ view_name = self.standardize_table_name(
+ schema_name, view_and_type.name
+ )
view_fqn = fqn.build(
self.metadata,
entity_type=Table,
@@ -452,7 +462,9 @@ def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
)
continue
except Exception as err:
- logger.warning(f"Skipping view {view_and_type.name!r} in schema {schema_name} due to - {err}")
+ logger.warning(
+ f"Skipping view {view_and_type.name!r} in schema {schema_name} due to - {err}"
+ )
logger.debug(traceback.format_exc())
continue
yield view_name, view_and_type.type_
@@ -634,8 +646,12 @@ def yield_table(
table_type=table_type,
),
owners=self.get_owner_ref(table_name=table_name),
- locationPath=self.get_location_path(table_name=table_name, schema_name=schema_name),
- extension=self.get_table_extensions(table_name=table_name, table_type=table_type),
+ locationPath=self.get_location_path(
+ table_name=table_name, schema_name=schema_name
+ ),
+ extension=self.get_table_extensions(
+ table_name=table_name, table_type=table_type
+ ),
)
is_partitioned, partition_details = self.get_table_partition_details(
diff --git a/ingestion/src/metadata/ingestion/source/database/database_service.py b/ingestion/src/metadata/ingestion/source/database/database_service.py
index 5031b26c6920..b8ccea3bbb61 100644
--- a/ingestion/src/metadata/ingestion/source/database/database_service.py
+++ b/ingestion/src/metadata/ingestion/source/database/database_service.py
@@ -247,7 +247,9 @@ def tags_registry(self) -> TagRegistry:
cached = instance_dict.get("tags_registry")
if cached is not None:
return cached
- return instance_dict.setdefault("tags_registry", TagRegistry(metadata=self.metadata))
+ return instance_dict.setdefault(
+ "tags_registry", TagRegistry(metadata=self.metadata)
+ )
@property
def tag_canonicalizer(self) -> TagCanonicalizer:
@@ -256,7 +258,9 @@ def tag_canonicalizer(self) -> TagCanonicalizer:
cached = instance_dict.get("tag_canonicalizer")
if cached is not None:
return cached
- return instance_dict.setdefault("tag_canonicalizer", TagCanonicalizer(metadata=self.metadata))
+ return instance_dict.setdefault(
+ "tag_canonicalizer", TagCanonicalizer(metadata=self.metadata)
+ )
@property
def name(self) -> str:
@@ -926,7 +930,9 @@ def yield_life_cycle_data(self, _) -> Iterable[Either[OMetaLifeCycleData]]:
def clear_schema_tag_scope(self):
"""Drop tag-registry state for the current schema scope."""
- schema_name = self.context.get().database_schema # pyright: ignore[reportAttributeAccessIssue]
+ schema_name = (
+ self.context.get().database_schema
+ ) # pyright: ignore[reportAttributeAccessIssue]
if schema_name:
schema_fqn = cast(
"str",
@@ -943,7 +949,9 @@ def clear_schema_tag_scope(self):
def clear_database_tag_scope(self):
"""Drop tag-registry state for the current database scope."""
- database_name = self.context.get().database # pyright: ignore[reportAttributeAccessIssue]
+ database_name = (
+ self.context.get().database
+ ) # pyright: ignore[reportAttributeAccessIssue]
if database_name:
database_fqn = cast(
"str",
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py b/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py
index d1716405a7f5..abffccf477f7 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/metadata.py
@@ -150,13 +150,17 @@ def _fetch_nested_descriptions_via_describe_json(
return {}
try:
result = connection.execute(
- text(f"DESCRIBE TABLE EXTENDED `{db_name}`.`{schema}`.`{table_name}` AS JSON")
+ text(
+ f"DESCRIBE TABLE EXTENDED `{db_name}`.`{schema}`.`{table_name}` AS JSON"
+ )
).fetchone()
if not result or not result[0]:
return {}
payload = json.loads(result[0])
except Exception as err: # pylint: disable=broad-except
- logger.debug(f"DESCRIBE AS JSON unavailable or unparseable for {db_name}.{schema}.{table_name}: {err}")
+ logger.debug(
+ f"DESCRIBE AS JSON unavailable or unparseable for {db_name}.{schema}.{table_name}: {err}"
+ )
return {}
return _build_column_descriptions_map(payload)
@@ -341,7 +345,9 @@ def get_columns(self, connection, table_name, schema=None, **kw):
sub_rows = {
r[0]: r[1]
for r in connection.execute(
- text(f"DESCRIBE TABLE `{kw.get('db_name')}`.`{schema}`.`{table_name}` `{col_name}`")
+ text(
+ f"DESCRIBE TABLE `{kw.get('db_name')}`.`{schema}`.`{table_name}` `{col_name}`"
+ )
).fetchall()
}
col_info["system_data_type"] = sub_rows["data_type"]
@@ -361,10 +367,14 @@ def get_columns(self, connection, table_name, schema=None, **kw):
)
if supports_nested_descriptions:
if nested_descriptions_by_column is None:
- nested_descriptions_by_column = _fetch_nested_descriptions_via_describe_json(
- connection, kw.get("db_name"), schema, table_name
+ nested_descriptions_by_column = (
+ _fetch_nested_descriptions_via_describe_json(
+ connection, kw.get("db_name"), schema, table_name
+ )
)
- nested_descriptions = nested_descriptions_by_column.get(col_name)
+ nested_descriptions = nested_descriptions_by_column.get(
+ col_name
+ )
if nested_descriptions:
col_info["nested_descriptions"] = nested_descriptions
except (DatabaseError, KeyError) as err:
@@ -374,7 +384,9 @@ def get_columns(self, connection, table_name, schema=None, **kw):
logger.debug(traceback.format_exc())
result.append(col_info)
except Exception as err: # pylint: disable=broad-except
- logger.warning(f"Skipping column '{col_name}' in {schema}.{table_name} due to unexpected error: {err}")
+ logger.warning(
+ f"Skipping column '{col_name}' in {schema}.{table_name} due to unexpected error: {err}"
+ )
logger.debug(traceback.format_exc())
return result
diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py
index 86141bdf3ff5..019b6e50c1ce 100644
--- a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py
@@ -50,10 +50,7 @@
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
-from metadata.generated.schema.type.basic import (
- EntityName,
- SourceUrl,
-)
+from metadata.generated.schema.type.basic import EntityName, SourceUrl
from metadata.generated.schema.type.entityReferenceList import EntityReferenceList
from metadata.generated.schema.type.tagLabel import TagLabel
from metadata.ingestion.api.delete import delete_entity_by_name
@@ -578,13 +575,17 @@ def yield_tag(
)
continue
- entity_fqn = fqn._build(self.context.get().database_service, *fqn_elements) # pyright: ignore[reportAttributeAccessIssue]
+ entity_fqn = fqn._build(
+ self.context.get().database_service, *fqn_elements
+ ) # pyright: ignore[reportAttributeAccessIssue]
try:
classification = self.tag_canonicalizer.classification(
row[0], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
)
tag = self.tag_canonicalizer.tag(
- classification.name, row[1], default_tag_description=SNOWFLAKE_TAG_DESCRIPTION
+ classification.name,
+ row[1],
+ default_tag_description=SNOWFLAKE_TAG_DESCRIPTION,
)
self.tags_registry.attach(
@@ -611,7 +612,8 @@ def yield_tag(
for tag_info in self.schema_tags_map[schema_name]:
try:
classification = self.tag_canonicalizer.classification(
- tag_info["tag_name"], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
+ tag_info["tag_name"],
+ default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
)
tag = self.tag_canonicalizer.tag(
classification.name,
@@ -637,9 +639,13 @@ def yield_tag(
),
right=None,
)
- yield from (Either(left=None, right=record) for record in self.tags_registry.drain())
+ yield from (
+ Either(left=None, right=record) for record in self.tags_registry.drain()
+ )
- def yield_database_tag(self, database_name: str) -> Iterable[Either[OMetaTagAndClassification]]:
+ def yield_database_tag(
+ self, database_name: str
+ ) -> Iterable[Either[OMetaTagAndClassification]]:
"""Yield database-level tags for the topology."""
if not self.source_config.includeTags:
return
@@ -659,10 +665,13 @@ def yield_database_tag(self, database_name: str) -> Iterable[Either[OMetaTagAndC
for tag_info in self.database_tags_map[database_name]:
try:
classification = self.tag_canonicalizer.classification(
- tag_info["tag_name"], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
+ tag_info["tag_name"],
+ default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
)
tag = self.tag_canonicalizer.tag(
- classification.name, tag_info["tag_value"], default_tag_description=SNOWFLAKE_TAG_DESCRIPTION
+ classification.name,
+ tag_info["tag_value"],
+ default_tag_description=SNOWFLAKE_TAG_DESCRIPTION,
)
self.tags_registry.attach(
@@ -683,7 +692,9 @@ def yield_database_tag(self, database_name: str) -> Iterable[Either[OMetaTagAndC
),
right=None,
)
- yield from (Either(left=None, right=record) for record in self.tags_registry.drain())
+ yield from (
+ Either(left=None, right=record) for record in self.tags_registry.drain()
+ )
def _get_table_names_and_types(
self, schema_name: str, table_type: TableType = TableType.Regular
@@ -702,7 +713,11 @@ def _get_table_names_and_types(
)
deleted_fqns = []
- for table in snowflake_tables.get_deleted(): # pyright: ignore[reportAttributeAccessIssue]
+ for (
+ table
+ ) in (
+ snowflake_tables.get_deleted()
+ ): # pyright: ignore[reportAttributeAccessIssue]
try:
deleted_fqns.append(
fqn.build(
@@ -715,11 +730,16 @@ def _get_table_names_and_types(
)
)
except Exception as err:
- logger.warning(f"Skipping deleted-table FQN for {table.name!r} in schema {schema_name}: {err}")
+ logger.warning(
+ f"Skipping deleted-table FQN for {table.name!r} in schema {schema_name}: {err}"
+ )
logger.debug(traceback.format_exc())
self.context.get_global().deleted_tables.extend(deleted_fqns)
- return [TableNameAndType(name=table.name, type_=table.type_) for table in snowflake_tables.get_not_deleted()] # pyright: ignore[reportAttributeAccessIssue]
+ return [
+ TableNameAndType(name=table.name, type_=table.type_)
+ for table in snowflake_tables.get_not_deleted()
+ ] # pyright: ignore[reportAttributeAccessIssue]
def _get_stream_names_and_types(self, schema_name: str) -> List[TableNameAndType]:
table_type = TableType.Stream
@@ -902,8 +922,8 @@ def _get_stored_procedures_internal(
f"Missing ownership permissions on procedure {stored_procedure.name}."
" Trying to fetch description via DESCRIBE."
)
- stored_procedure.definition = (
- self.describe_procedure_definition(stored_procedure)
+ stored_procedure.definition = self.describe_procedure_definition(
+ stored_procedure
)
if self.is_stored_procedure_filtered(stored_procedure.name):
continue
@@ -1160,14 +1180,18 @@ def _get_classification_name(self, tag_label: TagLabel) -> str:
parts = fqn.split(tag_fqn) if tag_fqn else []
return parts[0] if parts else tag_fqn
- def _has_classification(self, classification_name: str, tag_list: List[TagLabel]) -> bool: # noqa: UP006
+ def _has_classification(
+ self, classification_name: str, tag_list: List[TagLabel]
+ ) -> bool: # noqa: UP006
"""Check if a tag with the given classification name already exists"""
for tag in tag_list: # noqa: SIM110
if self._get_classification_name(tag) == classification_name:
return True
return False
- def get_database_tag_labels(self, database_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
+ def get_database_tag_labels(
+ self, database_name: str
+ ) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
"""Return tags for the database entity from registry."""
database_fqn = cast(
"str",
@@ -1180,7 +1204,9 @@ def get_database_tag_labels(self, database_name: str) -> Optional[List[TagLabel]
)
return self.tags_registry.labels_for(database_fqn) or None
- def get_column_tag_labels(self, table_name: str, column: dict) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
+ def get_column_tag_labels(
+ self, table_name: str, column: dict
+ ) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
"""Return tags for a column entity from the registry.
Column tags don't inherit from parent entities (table/schema/database)
@@ -1201,7 +1227,9 @@ def get_column_tag_labels(self, table_name: str, column: dict) -> Optional[List[
)
return self.tags_registry.labels_for(col_fqn) or None
- def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
+ def get_schema_tag_labels(
+ self, schema_name: str
+ ) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
"""
Return tags for schema entity including:
1. Snowflake schema-level tags
@@ -1231,7 +1259,9 @@ def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]:
# Add inherited database tags (only if classification doesn't already exist)
for label in self.tags_registry.labels_for(database_fqn):
- if not self._has_classification(self._get_classification_name(label), schema_tags):
+ if not self._has_classification(
+ self._get_classification_name(label), schema_tags
+ ):
schema_tags.append(label)
return schema_tags if schema_tags else None
@@ -1282,12 +1312,16 @@ def get_tag_labels(self, table_name: str) -> Optional[List[TagLabel]]:
# Add inherited schema tags (only if classification doesn't already exist)
for label in self.tags_registry.labels_for(schema_fqn):
- if not self._has_classification(self._get_classification_name(label), table_tags):
+ if not self._has_classification(
+ self._get_classification_name(label), table_tags
+ ):
table_tags.append(label)
# Add inherited database tags (only if classification doesn't already exist)
for label in self.tags_registry.labels_for(database_fqn):
- if not self._has_classification(self._get_classification_name(label), table_tags):
+ if not self._has_classification(
+ self._get_classification_name(label), table_tags
+ ):
table_tags.append(label)
return table_tags if table_tags else None
diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py b/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py
index 33bfaa60410e..c70a8af38ade 100644
--- a/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py
+++ b/ingestion/src/metadata/ingestion/source/database/snowflake/utils.py
@@ -441,7 +441,9 @@ def get_schema_columns(self, connection, schema, **kw):
ordinal_position,
) in result:
try:
- table_name = self.normalize_name(fqn.quote_name(table_name)) # noqa: PLW2901
+ table_name = self.normalize_name(
+ fqn.quote_name(table_name)
+ ) # noqa: PLW2901
except ValueError:
logger.warning(
"Skipping column row in schema %s with unsupported table name %r",
diff --git a/ingestion/src/metadata/pii/tag_analyzer.py b/ingestion/src/metadata/pii/tag_analyzer.py
index e114bd3e4621..d90201a2d154 100644
--- a/ingestion/src/metadata/pii/tag_analyzer.py
+++ b/ingestion/src/metadata/pii/tag_analyzer.py
@@ -83,10 +83,14 @@ def should_skip_recognizer(self, exception_list: list[RecognizerException]):
)
def _supports_language(self, created: EntityRecognizer) -> bool:
- return self._language is ClassificationLanguage.any or created.supported_language in {
- ClassificationLanguage.any.value,
- self._language.value,
- }
+ return (
+ self._language is ClassificationLanguage.any
+ or created.supported_language
+ in {
+ ClassificationLanguage.any.value,
+ self._language.value,
+ }
+ )
def get_recognizers_by(self, target: recognizer.Target) -> list[EntityRecognizer]:
if self.tag.autoClassificationEnabled is False:
diff --git a/ingestion/src/metadata/utils/datalake/datalake_utils.py b/ingestion/src/metadata/utils/datalake/datalake_utils.py
index 3811b19ba6ec..270e452a791e 100644
--- a/ingestion/src/metadata/utils/datalake/datalake_utils.py
+++ b/ingestion/src/metadata/utils/datalake/datalake_utils.py
@@ -337,7 +337,9 @@ def _get_columns(cls, data_frame: "DataFrame"):
}
if data_type == DataType.ARRAY:
parsed_string["arrayDataType"] = DataType.UNKNOWN
- struct_children = cls._get_array_struct_children(data_frame[column].dropna()[:100])
+ struct_children = cls._get_array_struct_children(
+ data_frame[column].dropna()[:100]
+ )
if struct_children:
parsed_string["arrayDataType"] = DataType.STRUCT
parsed_string["children"] = struct_children
@@ -377,10 +379,14 @@ def fetch_col_types(cls, data_frame, column_name):
for df_row_val in df_row_val_list:
try:
if isinstance(df_row_val, (dict, list)):
- parsed_object_datatype_list.append(type(df_row_val).__name__.lower())
+ parsed_object_datatype_list.append(
+ type(df_row_val).__name__.lower()
+ )
else:
parsed_object_datatype_list.append(
- type(ast.literal_eval(str(df_row_val))).__name__.lower()
+ type(
+ ast.literal_eval(str(df_row_val))
+ ).__name__.lower()
)
except (ValueError, SyntaxError):
# we try to parse the value as a datetime, if it fails, we fallback to string
@@ -394,7 +400,9 @@ def fetch_col_types(cls, data_frame, column_name):
if not str(df_row_val).isnumeric():
# check if the row value is time
try:
- datetime.strptime(str(df_row_val), "%H:%M:%S").time()
+ datetime.strptime(
+ str(df_row_val), "%H:%M:%S"
+ ).time()
dtype_ = "timedelta[ns]"
except (ValueError, TypeError):
# check if the row value is date / time / datetime
@@ -454,11 +462,19 @@ def unique_json_structure(cls, dicts: List[Dict]) -> Dict:
result[key] = cls.unique_json_structure(
[nested_json if isinstance(nested_json, dict) else {}, value]
)
- elif isinstance(value, list) and value and all(isinstance(item, dict) for item in value):
+ elif (
+ isinstance(value, list)
+ and value
+ and all(isinstance(item, dict) for item in value)
+ ):
merged_struct = cls.unique_json_structure(value)
existing = result.get(key)
- existing_struct = existing.struct if isinstance(existing, _ArrayOfStruct) else {}
- result[key] = _ArrayOfStruct(cls.unique_json_structure([existing_struct, merged_struct]))
+ existing_struct = (
+ existing.struct if isinstance(existing, _ArrayOfStruct) else {}
+ )
+ result[key] = _ArrayOfStruct(
+ cls.unique_json_structure([existing_struct, merged_struct])
+ )
else:
result[key] = value
return result
@@ -482,8 +498,12 @@ def construct_json_column_children(cls, json_column: Dict) -> List[Dict]:
column["children"] = cls.construct_json_column_children(value.struct)
else:
type_ = type(value).__name__.lower()
- column["dataTypeDisplay"] = cls._data_formats.get(type_, DataType.UNKNOWN).value
- column["dataType"] = cls._data_formats.get(type_, DataType.UNKNOWN).value
+ column["dataTypeDisplay"] = cls._data_formats.get(
+ type_, DataType.UNKNOWN
+ ).value
+ column["dataType"] = cls._data_formats.get(
+ type_, DataType.UNKNOWN
+ ).value
if isinstance(value, dict):
column["children"] = cls.construct_json_column_children(value)
children.append(column)
@@ -517,7 +537,9 @@ def get_children(cls, json_column) -> List[Dict]:
f"parsed type is {type(parsed).__name__}"
)
except (TypeError, json.JSONDecodeError) as exc:
- logger.debug(f"Skipping unparseable string value while extracting column children: {exc}")
+ logger.debug(
+ f"Skipping unparseable string value while extracting column children: {exc}"
+ )
else:
logger.debug(
"Skipping non-string, non-dict value while extracting column children: "
diff --git a/ingestion/tests/integration/auto_classification/test_azuresql_temporal_table.py b/ingestion/tests/integration/auto_classification/test_azuresql_temporal_table.py
index c3eb508bc327..14ef65f3a515 100644
--- a/ingestion/tests/integration/auto_classification/test_azuresql_temporal_table.py
+++ b/ingestion/tests/integration/auto_classification/test_azuresql_temporal_table.py
@@ -116,7 +116,8 @@
pytestmark = pytest.mark.skipif(
not all(os.environ.get(v) for v in REQUIRED_ENV_VARS),
- reason="AzureSQL temporal table integration tests require environment variables: " + ", ".join(REQUIRED_ENV_VARS),
+ reason="AzureSQL temporal table integration tests require environment variables: "
+ + ", ".join(REQUIRED_ENV_VARS),
)
@@ -148,7 +149,11 @@ def create_service_request():
def ensure_temporal_table(db_service, table_suffix):
conn_config = db_service.connection.config
driver = (conn_config.driver or AZURE_SQL_DRIVER).replace(" ", "+")
- password = conn_config.password.get_secret_value() if conn_config.password else os.environ["AZURE_SQL_PASSWORD"]
+ password = (
+ conn_config.password.get_secret_value()
+ if conn_config.password
+ else os.environ["AZURE_SQL_PASSWORD"]
+ )
connection_url = (
f"mssql+pyodbc://{conn_config.username}:{password}"
f"@{conn_config.hostPort}/{conn_config.database}"
@@ -161,7 +166,8 @@ def ensure_temporal_table(db_service, table_suffix):
with engine.connect() as conn:
conn.execute(
- text(f"""
+ text(
+ f"""
IF OBJECT_ID('dbo.[{table_name}]', 'U') IS NULL
BEGIN
CREATE TABLE dbo.[{table_name}] (
@@ -173,11 +179,13 @@ def ensure_temporal_table(db_service, table_suffix):
PERIOD FOR SYSTEM_TIME (ValidFrom, ValidTo)
) WITH (SYSTEM_VERSIONING = ON (HISTORY_TABLE = dbo.[{history_name}]))
END
- """)
+ """
+ )
)
conn.commit()
conn.execute(
- text(f"""
+ text(
+ f"""
IF NOT EXISTS (SELECT 1 FROM dbo.[{table_name}] WHERE id IN (1, 2, 3))
BEGIN
INSERT INTO dbo.[{table_name}] (id, name, email) VALUES
@@ -185,7 +193,8 @@ def ensure_temporal_table(db_service, table_suffix):
(2, 'Bob', 'bob@example.com'),
(3, 'Carol', 'carol@example.com')
END
- """)
+ """
+ )
)
conn.commit()
@@ -262,12 +271,22 @@ def autoclassification_config(db_service, workflow_config, sink_config, table_na
@pytest.fixture(scope="module")
-def load_metadata(run_workflow, ingestion_config, ensure_temporal_table, patch_passwords_for_db_services):
+def load_metadata(
+ run_workflow,
+ ingestion_config,
+ ensure_temporal_table,
+ patch_passwords_for_db_services,
+):
return run_workflow(MetadataWorkflow, ingestion_config)
@pytest.fixture(scope="module")
-def run_classification(run_workflow, autoclassification_config, load_metadata, patch_passwords_for_db_services):
+def run_classification(
+ run_workflow,
+ autoclassification_config,
+ load_metadata,
+ patch_passwords_for_db_services,
+):
return run_workflow(AutoClassificationWorkflow, autoclassification_config)
@@ -289,8 +308,12 @@ def test_temporal_columns_excluded_from_sample_data(
assert len(result.sampleData.rows) > 0
column_names = [col.root for col in result.sampleData.columns]
- assert "ValidFrom" not in column_names, "ValidFrom must be excluded from sample data"
- assert "ValidTo" not in column_names, "ValidTo must be excluded from sample data"
+ assert (
+ "ValidFrom" not in column_names
+ ), "ValidFrom must be excluded from sample data"
+ assert (
+ "ValidTo" not in column_names
+ ), "ValidTo must be excluded from sample data"
assert "id" in column_names
assert "name" in column_names
assert "email" in column_names
diff --git a/ingestion/tests/unit/domain/tags/test_canonicalizer.py b/ingestion/tests/unit/domain/tags/test_canonicalizer.py
index c616878ba0a3..06c2ab9fd83c 100644
--- a/ingestion/tests/unit/domain/tags/test_canonicalizer.py
+++ b/ingestion/tests/unit/domain/tags/test_canonicalizer.py
@@ -15,7 +15,9 @@
import pytest
from metadata.domain.tags import Canonical, TagCanonicalizer
-from metadata.generated.schema.entity.classification.classification import Classification
+from metadata.generated.schema.entity.classification.classification import (
+ Classification,
+)
from metadata.generated.schema.type.basic import ProviderType
@@ -59,25 +61,37 @@ def _system_tag(classification: str, name: str, description: str = "") -> MagicM
class TestClassification:
- def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
+ def test_no_match_returns_source_unchanged(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
mock_metadata.es_search_from_fqn.return_value = []
result = canonicalizer.classification("MyClass", "Source desc")
assert result == Canonical(name="MyClass", description="Source desc")
- def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
- mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")]
+ def test_system_match_uses_canonical_case(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
+ mock_metadata.es_search_from_fqn.return_value = [
+ _system_classification("PII", "Canonical desc")
+ ]
result = canonicalizer.classification("pii", "Source desc")
assert result == Canonical(name="PII", description="Canonical desc")
- def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
- mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")]
+ def test_caches_per_case_insensitive_key(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
+ mock_metadata.es_search_from_fqn.return_value = [
+ _system_classification("PII", "Canonical desc")
+ ]
canonicalizer.classification("pii", "Source desc")
canonicalizer.classification("PII", "Source desc")
canonicalizer.classification("Pii", "Source desc")
# Three case variants share the same case-insensitive cache key
assert mock_metadata.es_search_from_fqn.call_count == 1
- def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
+ def test_non_system_match_ignored(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
non_system = _system_classification("PII", "Canonical desc")
non_system.provider = ProviderType.user
mock_metadata.es_search_from_fqn.return_value = [non_system]
@@ -89,36 +103,52 @@ def test_classification_es_called_with_correct_args(
):
mock_metadata.es_search_from_fqn.return_value = []
canonicalizer.classification("Foo", "Source desc")
- mock_metadata.es_search_from_fqn.assert_called_once_with(entity_type=Classification, fqn_search_string="Foo")
+ mock_metadata.es_search_from_fqn.assert_called_once_with(
+ entity_type=Classification, fqn_search_string="Foo"
+ )
class TestTag:
- def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
+ def test_no_match_returns_source_unchanged(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
mock_metadata.es_search_from_fqn.return_value = []
result = canonicalizer.tag("PII", "MyTag", "Source desc")
assert result == Canonical(name="MyTag", description="Source desc")
- def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
- mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "Canonical desc")]
+ def test_system_match_uses_canonical_case(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
+ mock_metadata.es_search_from_fqn.return_value = [
+ _system_tag("PII", "Sensitive", "Canonical desc")
+ ]
result = canonicalizer.tag("PII", "sensitive", "Source desc")
assert result == Canonical(name="Sensitive", description="Canonical desc")
- def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
- mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "")]
+ def test_caches_per_case_insensitive_key(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
+ mock_metadata.es_search_from_fqn.return_value = [
+ _system_tag("PII", "Sensitive", "")
+ ]
canonicalizer.tag("PII", "sensitive", "Source desc")
canonicalizer.tag("PII", "SENSITIVE", "Source desc")
canonicalizer.tag("PII", "Sensitive", "Source desc")
# Three case variants share the same case-insensitive cache key
assert mock_metadata.es_search_from_fqn.call_count == 1
- def test_match_requires_classification_match(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
+ def test_match_requires_classification_match(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
# ES returns a tag but for a different classification — no canonicalization
wrong_class_tag = _system_tag("OtherClass", "Sensitive", "Canonical desc")
mock_metadata.es_search_from_fqn.return_value = [wrong_class_tag]
result = canonicalizer.tag("PII", "sensitive", "Source desc")
assert result == Canonical(name="sensitive", description="Source desc")
- def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
+ def test_non_system_match_ignored(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
non_system = _system_tag("PII", "Sensitive", "Canonical desc")
non_system.provider = ProviderType.user
mock_metadata.es_search_from_fqn.return_value = [non_system]
@@ -148,7 +178,9 @@ def test_persistent_failure_raises_after_retries_exhaust(
canonicalizer.classification("MyClass", "Source desc")
assert mock_metadata.es_search_from_fqn.call_count == 5
- def test_persistent_failure_does_not_poison_cache(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
+ def test_persistent_failure_does_not_poison_cache(
+ self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
+ ):
# First call: ES persistently fails -> raises.
mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent")
with pytest.raises(RuntimeError):
@@ -156,6 +188,8 @@ def test_persistent_failure_does_not_poison_cache(self, canonicalizer: TagCanoni
# ES recovers; subsequent call must reach ES again, not return a cached fallback.
mock_metadata.es_search_from_fqn.side_effect = None
- mock_metadata.es_search_from_fqn.return_value = [_system_classification("MyClass", "Canonical desc")]
+ mock_metadata.es_search_from_fqn.return_value = [
+ _system_classification("MyClass", "Canonical desc")
+ ]
result = canonicalizer.classification("MyClass", "Source desc")
assert result == Canonical(name="MyClass", description="Canonical desc")
diff --git a/ingestion/tests/unit/domain/tags/test_registry.py b/ingestion/tests/unit/domain/tags/test_registry.py
index 7cddc99eab0a..7a8e96ce069a 100644
--- a/ingestion/tests/unit/domain/tags/test_registry.py
+++ b/ingestion/tests/unit/domain/tags/test_registry.py
@@ -62,7 +62,9 @@ def test_attach_multiple_tags_same_entity_returns_all(self, registry: TagRegistr
labels = registry.labels_for("svc.db.schema.table")
assert len(labels) == 2
- def test_labels_for_unattached_entity_returns_empty_list(self, registry: TagRegistry):
+ def test_labels_for_unattached_entity_returns_empty_list(
+ self, registry: TagRegistry
+ ):
assert registry.labels_for("svc.db.schema.unknown") == []
def test_labels_for_is_idempotent(self, registry: TagRegistry):
@@ -97,7 +99,9 @@ def test_drain_dedupes_same_tag_across_entities(self, registry: TagRegistry):
pending = list(registry.drain())
assert len(pending) == 1
- def test_drain_yields_distinct_payloads_for_distinct_tags(self, registry: TagRegistry):
+ def test_drain_yields_distinct_payloads_for_distinct_tags(
+ self, registry: TagRegistry
+ ):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1", tag="TagA"))
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2", tag="TagB"))
pending = list(registry.drain())
@@ -122,7 +126,9 @@ def test_drain_dedupes_same_fqn_across_label_types(self, registry: TagRegistry):
label_type=LabelType.Automated,
)
pending = list(registry.drain())
- assert len(pending) == 1, "fqn-level dedup must collapse PUTs across label_type variants"
+ assert (
+ len(pending) == 1
+ ), "fqn-level dedup must collapse PUTs across label_type variants"
class TestClearScope:
@@ -148,7 +154,9 @@ def test_clear_scope_preserves_other_scopes(self, registry: TagRegistry):
def test_clear_scope_no_false_prefix_match(self, registry: TagRegistry):
# 'schema_a' is NOT a prefix of 'schema_alpha' once the FQN
# separator is taken into account.
- registry.attach(**_attach_kwargs("svc.db.schema_alpha", "svc.db.schema_alpha.tbl"))
+ registry.attach(
+ **_attach_kwargs("svc.db.schema_alpha", "svc.db.schema_alpha.tbl")
+ )
registry.clear_scope("svc.db.schema_a")
assert len(registry.labels_for("svc.db.schema_alpha.tbl")) == 1
@@ -189,7 +197,9 @@ def test_is_known_is_case_sensitive(self, registry: TagRegistry):
assert registry.is_known("Class.Tag") is True
assert registry.is_known("class.tag") is False # different tag server-side
- def test_ensure_known_cache_hit_skips_io(self, registry: TagRegistry, mock_metadata: MagicMock):
+ def test_ensure_known_cache_hit_skips_io(
+ self, registry: TagRegistry, mock_metadata: MagicMock
+ ):
registry.attach(
**_attach_kwargs(
"svc.db",
@@ -201,20 +211,26 @@ def test_ensure_known_cache_hit_skips_io(self, registry: TagRegistry, mock_metad
assert registry.ensure_known("Class.Tag") is True
mock_metadata.get_by_name.assert_not_called()
- def test_ensure_known_cache_miss_calls_get_by_name_once(self, registry: TagRegistry, mock_metadata: MagicMock):
+ def test_ensure_known_cache_miss_calls_get_by_name_once(
+ self, registry: TagRegistry, mock_metadata: MagicMock
+ ):
mock_metadata.get_by_name.return_value = MagicMock()
assert registry.ensure_known("Other.Tag") is True
assert registry.ensure_known("Other.Tag") is True # cached now
assert mock_metadata.get_by_name.call_count == 1
- def test_ensure_known_404_returns_false_and_does_not_cache(self, registry: TagRegistry, mock_metadata: MagicMock):
+ def test_ensure_known_404_returns_false_and_does_not_cache(
+ self, registry: TagRegistry, mock_metadata: MagicMock
+ ):
mock_metadata.get_by_name.return_value = None
assert registry.ensure_known("Missing.Tag") is False
assert registry.ensure_known("Missing.Tag") is False
# Re-queries on each miss; not cached.
assert mock_metadata.get_by_name.call_count == 2
- def test_ensure_known_swallows_exception(self, registry: TagRegistry, mock_metadata: MagicMock):
+ def test_ensure_known_swallows_exception(
+ self, registry: TagRegistry, mock_metadata: MagicMock
+ ):
mock_metadata.get_by_name.side_effect = RuntimeError("network down")
assert registry.ensure_known("Crashed.Tag") is False
diff --git a/ingestion/tests/unit/metadata/pii/test_language_filtering.py b/ingestion/tests/unit/metadata/pii/test_language_filtering.py
index 6d54c880cd8e..1871eb8a71c0 100644
--- a/ingestion/tests/unit/metadata/pii/test_language_filtering.py
+++ b/ingestion/tests/unit/metadata/pii/test_language_filtering.py
@@ -351,7 +351,9 @@ def fr_language_tag(self):
],
)
- def test_any_language_recognizer_included_when_agent_is_en(self, any_language_tag, sample_column, mock_nlp_engine):
+ def test_any_language_recognizer_included_when_agent_is_en(
+ self, any_language_tag, sample_column, mock_nlp_engine
+ ):
analyzer = TagAnalyzer(
tag=any_language_tag,
column=sample_column,
@@ -364,7 +366,9 @@ def test_any_language_recognizer_included_when_agent_is_en(self, any_language_ta
assert len(recognizers) == 1
assert recognizers[0].supported_language == ClassificationLanguage.any.value
- def test_any_language_recognizer_included_when_agent_is_any(self, any_language_tag, sample_column, mock_nlp_engine):
+ def test_any_language_recognizer_included_when_agent_is_any(
+ self, any_language_tag, sample_column, mock_nlp_engine
+ ):
analyzer = TagAnalyzer(
tag=any_language_tag,
column=sample_column,
diff --git a/ingestion/tests/unit/pii/test_cases/azuresql_temporal_table.py b/ingestion/tests/unit/pii/test_cases/azuresql_temporal_table.py
index 658ad06ec473..6df0d63df3ee 100644
--- a/ingestion/tests/unit/pii/test_cases/azuresql_temporal_table.py
+++ b/ingestion/tests/unit/pii/test_cases/azuresql_temporal_table.py
@@ -25,7 +25,9 @@
table = Table(
id=Uuid(root=uuid.uuid4()),
name=EntityName(root="customers_temporal"),
- fullyQualifiedName=FullyQualifiedEntityName(root="Service.database.schema.customers_temporal"),
+ fullyQualifiedName=FullyQualifiedEntityName(
+ root="Service.database.schema.customers_temporal"
+ ),
columns=[
Column(
name=ColumnName(root="id"),
@@ -36,7 +38,9 @@
precision=1,
scale=None,
dataTypeDisplay="int",
- fullyQualifiedName=FullyQualifiedEntityName(root="Service.database.schema.customers_temporal.id"),
+ fullyQualifiedName=FullyQualifiedEntityName(
+ root="Service.database.schema.customers_temporal.id"
+ ),
),
Column(
name=ColumnName(root="name"),
@@ -47,7 +51,9 @@
precision=1,
scale=None,
dataTypeDisplay="string",
- fullyQualifiedName=FullyQualifiedEntityName(root="Service.database.schema.customers_temporal.name"),
+ fullyQualifiedName=FullyQualifiedEntityName(
+ root="Service.database.schema.customers_temporal.name"
+ ),
),
Column(
name=ColumnName(root="email"),
@@ -58,7 +64,9 @@
precision=1,
scale=None,
dataTypeDisplay="string",
- fullyQualifiedName=FullyQualifiedEntityName(root="Service.database.schema.customers_temporal.email"),
+ fullyQualifiedName=FullyQualifiedEntityName(
+ root="Service.database.schema.customers_temporal.email"
+ ),
),
Column(
name=ColumnName(root="ValidFrom"),
@@ -69,7 +77,9 @@
precision=1,
scale=None,
dataTypeDisplay="datetime",
- fullyQualifiedName=FullyQualifiedEntityName(root="Service.database.schema.customers_temporal.ValidFrom"),
+ fullyQualifiedName=FullyQualifiedEntityName(
+ root="Service.database.schema.customers_temporal.ValidFrom"
+ ),
),
Column(
name=ColumnName(root="ValidTo"),
@@ -80,7 +90,9 @@
precision=1,
scale=None,
dataTypeDisplay="datetime",
- fullyQualifiedName=FullyQualifiedEntityName(root="Service.database.schema.customers_temporal.ValidTo"),
+ fullyQualifiedName=FullyQualifiedEntityName(
+ root="Service.database.schema.customers_temporal.ValidTo"
+ ),
),
],
)
diff --git a/ingestion/tests/unit/profiler/sqlalchemy/azuresql/test_azuresql_sampling.py b/ingestion/tests/unit/profiler/sqlalchemy/azuresql/test_azuresql_sampling.py
index 1fe2b836b087..5b03f90ef563 100644
--- a/ingestion/tests/unit/profiler/sqlalchemy/azuresql/test_azuresql_sampling.py
+++ b/ingestion/tests/unit/profiler/sqlalchemy/azuresql/test_azuresql_sampling.py
@@ -233,7 +233,9 @@ def capture_fetch(cols=None):
):
sampler.fetch_sample_data(columns=[valid_from_col, valid_to_col])
- assert received["columns"] == [], "Expected empty list when all columns are filtered, not the original list"
+ assert (
+ received["columns"] == []
+ ), "Expected empty list when all columns are filtered, not the original list"
def test_sampling_with_partition(self, sampler_mock):
"""
diff --git a/ingestion/tests/unit/topology/database/test_athena.py b/ingestion/tests/unit/topology/database/test_athena.py
index 5e83d515ea3c..799561a7213b 100644
--- a/ingestion/tests/unit/topology/database/test_athena.py
+++ b/ingestion/tests/unit/topology/database/test_athena.py
@@ -413,7 +413,11 @@ def _mock_query_rows(source, rows):
def _get_request(mock_metadata, call_index=0):
"""Pull the CreateCustomPropertyRequest from a create_or_update_custom_property call."""
- return mock_metadata.create_or_update_custom_property.call_args_list[call_index].args[0].createCustomPropertyRequest
+ return (
+ mock_metadata.create_or_update_custom_property.call_args_list[call_index]
+ .args[0]
+ .createCustomPropertyRequest
+ )
class TestGetTableExtensionsEarlyExits:
@@ -422,34 +426,53 @@ class TestGetTableExtensionsEarlyExits:
def test_returns_none_when_include_custom_properties_disabled(self, athena_source):
athena_source.source_config.includeCustomProperties = False
with patch.object(athena_source, "_fetch_iceberg_properties") as mock_fetch:
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result is None
mock_fetch.assert_not_called()
def test_returns_none_without_type_ref(self, athena_source):
athena_source._string_property_type_ref = None
- assert athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg) is None
+ assert (
+ athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ is None
+ )
def test_returns_none_for_external_table(self, athena_source):
- assert athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.External) is None
+ assert (
+ athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.External)
+ is None
+ )
def test_returns_none_for_regular_table(self, athena_source):
- assert athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Regular) is None
+ assert (
+ athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Regular)
+ is None
+ )
def test_returns_none_when_table_type_is_none(self, athena_source):
assert athena_source.get_table_extensions(MOCK_TABLE_NAME) is None
def test_returns_none_when_query_yields_no_properties(self, athena_source):
with patch.object(athena_source, "_fetch_iceberg_properties", return_value={}):
- assert athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg) is None
+ assert (
+ athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ is None
+ )
def test_returns_none_when_all_values_filtered_out(self, athena_source):
props = {"k1": None, "k2": ""}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata"),
):
- assert athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg) is None
+ assert (
+ athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ is None
+ )
class TestGetTableExtensionsSanitization:
@@ -458,10 +481,14 @@ class TestGetTableExtensionsSanitization:
def test_dot_is_preserved(self, athena_source):
props = {"myprop.owner": "team-a"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"myprop.owner": "team-a"}
request = _get_request(mock_metadata)
@@ -471,10 +498,14 @@ def test_dot_is_preserved(self, athena_source):
def test_hyphen_is_preserved(self, athena_source):
props = {"myprop-owner": "x"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"myprop-owner": "x"}
request = _get_request(mock_metadata)
@@ -484,10 +515,14 @@ def test_allowed_punctuation_combined_preserved(self, athena_source):
"""Dots and hyphens together are allowed — name passes through untouched."""
props = {"myprop.airflow-dag-id": "scrape-dag"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"myprop.airflow-dag-id": "scrape-dag"}
request = _get_request(mock_metadata)
@@ -498,10 +533,14 @@ def test_other_special_chars_still_replaced(self, athena_source):
"""Everything outside [A-Za-z0-9_.-] gets replaced with __."""
props = {"myprop/airflow:dag id@prod": "v"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"myprop__airflow__dag__id__prod": "v"}
request = _get_request(mock_metadata)
@@ -511,20 +550,28 @@ def test_mixed_allowed_and_disallowed_chars(self, athena_source):
"""Allowed chars (. -) stay; disallowed chars (/ space) get replaced."""
props = {"myprop.data/type-v1 beta": "v"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata"),
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"myprop.data__type-v1__beta": "v"}
def test_already_valid_name_unchanged(self, athena_source):
props = {"simple_key": "value"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"simple_key": "value"}
request = _get_request(mock_metadata)
@@ -534,20 +581,28 @@ def test_already_valid_name_unchanged(self, athena_source):
def test_alphanumeric_and_underscore_preserved(self, athena_source):
props = {"abc123_XYZ": "v"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata"),
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"abc123_XYZ": "v"}
def test_sanitized_name_at_256_chars_not_hashed(self, athena_source):
name = "a" * 256
props = {name: "value"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {name: "value"}
request = _get_request(mock_metadata)
@@ -557,12 +612,18 @@ def test_long_sanitized_name_is_md5_hashed(self, athena_source):
original = "myprop." + ("a" * 260)
props = {original: "value"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
- expected_hash = hashlib.md5(original.encode("utf-8"), usedforsecurity=False).hexdigest()
+ expected_hash = hashlib.md5(
+ original.encode("utf-8"), usedforsecurity=False
+ ).hexdigest()
assert result == {expected_hash: "value"}
request = _get_request(mock_metadata)
assert request.name.root == expected_hash
@@ -594,10 +655,14 @@ class TestGetTableExtensionsValueFiltering:
def test_skips_none_valued_property(self, athena_source):
props = {"k1": "v1", "k2": None}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"k1": "v1"}
assert mock_metadata.create_or_update_custom_property.call_count == 1
@@ -605,30 +670,42 @@ def test_skips_none_valued_property(self, athena_source):
def test_skips_empty_string_valued_property(self, athena_source):
props = {"k1": "v1", "k2": ""}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata"),
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"k1": "v1"}
def test_keeps_string_zero(self, athena_source):
"""'0' is falsy-ish in some checks but is a legitimate value."""
props = {"k": "0"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata"),
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"k": "0"}
def test_keeps_whitespace_value(self, athena_source):
"""A single space is not an empty string and should pass through."""
props = {"k": " "}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata"),
):
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"k": " "}
@@ -638,7 +715,9 @@ class TestGetTableExtensionsDedup:
def test_same_prop_across_tables_registered_once(self, athena_source):
props = {"shared_key": "v"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
athena_source.get_table_extensions("tbl1", TableType.Iceberg)
@@ -665,16 +744,24 @@ def test_registration_failure_does_not_mark_prop_processed(self, athena_source):
"""A failed registration must not be cached — so a retry on the next table can succeed."""
props = {"k1": "v1"}
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
- mock_metadata.create_or_update_custom_property.side_effect = Exception("boom")
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ mock_metadata.create_or_update_custom_property.side_effect = Exception(
+ "boom"
+ )
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result is None
assert "k1" not in athena_source._processed_prop
- def test_registration_failure_for_one_prop_does_not_block_others(self, athena_source):
+ def test_registration_failure_for_one_prop_does_not_block_others(
+ self, athena_source
+ ):
"""Registration errors on one prop don't prevent others from being returned."""
props = {"bad_prop": "x", "good_prop": "y"}
call_flag = {"first": True}
@@ -686,11 +773,15 @@ def side_effect(_):
return
with (
- patch.object(athena_source, "_fetch_iceberg_properties", return_value=props),
+ patch.object(
+ athena_source, "_fetch_iceberg_properties", return_value=props
+ ),
patch.object(athena_source, "metadata") as mock_metadata,
):
mock_metadata.create_or_update_custom_property.side_effect = side_effect
- result = athena_source.get_table_extensions(MOCK_TABLE_NAME, TableType.Iceberg)
+ result = athena_source.get_table_extensions(
+ MOCK_TABLE_NAME, TableType.Iceberg
+ )
assert result == {"good_prop": "y"}
@@ -704,7 +795,9 @@ def test_returns_properties_from_query(self, athena_source):
[("myprop.owner", "team-a"), ("myprop.source", "ex")],
)
- result = athena_source._fetch_iceberg_properties(MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME)
+ result = athena_source._fetch_iceberg_properties(
+ MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME
+ )
assert result == {"myprop.owner": "team-a", "myprop.source": "ex"}
def test_returns_empty_dict_on_exception(self, athena_source):
@@ -712,7 +805,9 @@ def test_returns_empty_dict_on_exception(self, athena_source):
mock_engine.connect.side_effect = Exception("connection refused")
athena_source.engine = mock_engine
- result = athena_source._fetch_iceberg_properties(MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME)
+ result = athena_source._fetch_iceberg_properties(
+ MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME
+ )
assert result == {}
def test_filters_null_key_and_null_value_rows(self, athena_source):
@@ -726,7 +821,9 @@ def test_filters_null_key_and_null_value_rows(self, athena_source):
],
)
- result = athena_source._fetch_iceberg_properties(MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME)
+ result = athena_source._fetch_iceberg_properties(
+ MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME
+ )
assert result == {"k1": "v1", "k3": "v3"}
def test_query_targets_dollar_properties_metatable(self, athena_source):
@@ -744,7 +841,9 @@ def test_query_targets_dollar_properties_metatable(self, athena_source):
def test_values_are_coerced_to_string(self, athena_source):
_mock_query_rows(athena_source, [("k_int", 42), ("k_bool", True)])
- result = athena_source._fetch_iceberg_properties(MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME)
+ result = athena_source._fetch_iceberg_properties(
+ MOCK_DATABASE_SCHEMA.name.root, MOCK_TABLE_NAME
+ )
assert result == {"k_int": "42", "k_bool": "True"}
diff --git a/ingestion/tests/unit/topology/database/test_databricks_nested_comments.py b/ingestion/tests/unit/topology/database/test_databricks_nested_comments.py
index cafa07b11dde..215e352f7820 100644
--- a/ingestion/tests/unit/topology/database/test_databricks_nested_comments.py
+++ b/ingestion/tests/unit/topology/database/test_databricks_nested_comments.py
@@ -218,27 +218,46 @@ def test_query_failure_returns_empty(self):
connection = MagicMock()
connection.execute.side_effect = Exception("syntax error: AS JSON unsupported")
- assert _fetch_nested_descriptions_via_describe_json(connection, "db", "schema", "table") == {}
+ assert (
+ _fetch_nested_descriptions_via_describe_json(
+ connection, "db", "schema", "table"
+ )
+ == {}
+ )
def test_empty_result_returns_empty(self):
connection = MagicMock()
connection.execute.return_value.fetchone.return_value = None
- assert _fetch_nested_descriptions_via_describe_json(connection, "db", "schema", "table") == {}
+ assert (
+ _fetch_nested_descriptions_via_describe_json(
+ connection, "db", "schema", "table"
+ )
+ == {}
+ )
def test_invalid_json_returns_empty(self):
connection = MagicMock()
connection.execute.return_value.fetchone.return_value = ("not valid json {",)
- assert _fetch_nested_descriptions_via_describe_json(connection, "db", "schema", "table") == {}
+ assert (
+ _fetch_nested_descriptions_via_describe_json(
+ connection, "db", "schema", "table"
+ )
+ == {}
+ )
def test_valid_json_extracts_descriptions(self):
import json as _json
connection = MagicMock()
- connection.execute.return_value.fetchone.return_value = (_json.dumps(_CUSTOMER_PROFILES_JSON),)
+ connection.execute.return_value.fetchone.return_value = (
+ _json.dumps(_CUSTOMER_PROFILES_JSON),
+ )
- result = _fetch_nested_descriptions_via_describe_json(connection, "db", "schema", "customer_profiles")
+ result = _fetch_nested_descriptions_via_describe_json(
+ connection, "db", "schema", "customer_profiles"
+ )
assert ("first_name",) in result["personal_info"]
assert result["personal_info"][("first_name",)] == "Customer first name"
@@ -252,7 +271,12 @@ def test_missing_db_or_schema_returns_empty_without_query(self, db_name, schema)
identifiers and rely on the except block to swallow the error."""
connection = MagicMock()
- assert _fetch_nested_descriptions_via_describe_json(connection, db_name, schema, "table") == {}
+ assert (
+ _fetch_nested_descriptions_via_describe_json(
+ connection, db_name, schema, "table"
+ )
+ == {}
+ )
connection.execute.assert_not_called()
@@ -274,7 +298,9 @@ def test_top_level_descriptions(self):
_apply_nested_descriptions(col, descs, ())
children_by_name = {c.name.root: c for c in col.children}
- assert children_by_name["first_name"].description == Markdown(root="Customer first name")
+ assert children_by_name["first_name"].description == Markdown(
+ root="Customer first name"
+ )
assert children_by_name["dob"].description == Markdown(root="Date of birth")
def test_nested_struct_descriptions(self):
@@ -320,7 +346,9 @@ def test_column_with_no_children_is_safe(self):
assert col.description is None
-@patch("metadata.ingestion.source.database.databricks.metadata._fetch_nested_descriptions_via_describe_json")
+@patch(
+ "metadata.ingestion.source.database.databricks.metadata._fetch_nested_descriptions_via_describe_json"
+)
@patch("metadata.ingestion.source.database.databricks.metadata._get_column_rows")
class TestDescribeJsonLazyFetch:
"""The ``DESCRIBE TABLE EXTENDED ... AS JSON`` round-trip is fired only
@@ -337,7 +365,9 @@ def _run(self, mock_connection):
db_name="db",
)
- def test_skipped_when_table_has_no_complex_columns(self, mock_rows, mock_fetch_json):
+ def test_skipped_when_table_has_no_complex_columns(
+ self, mock_rows, mock_fetch_json
+ ):
"""Primitive-only table → AS JSON query never runs."""
mock_rows.return_value = [
("id", "bigint", None),
@@ -361,7 +391,9 @@ def _connection_with_describe_rows(self):
]
return connection
- def test_called_once_for_table_with_one_complex_column(self, mock_rows, mock_fetch_json):
+ def test_called_once_for_table_with_one_complex_column(
+ self, mock_rows, mock_fetch_json
+ ):
mock_rows.return_value = [
("id", "bigint", None),
("info", "struct", None),
@@ -374,7 +406,9 @@ def test_called_once_for_table_with_one_complex_column(self, mock_rows, mock_fet
mock_fetch_json.assert_called_once_with(connection, "db", "schema", "tbl")
- def test_called_once_for_table_with_multiple_complex_columns(self, mock_rows, mock_fetch_json):
+ def test_called_once_for_table_with_multiple_complex_columns(
+ self, mock_rows, mock_fetch_json
+ ):
"""Cached after first complex column — second/third columns reuse
the result instead of triggering another round-trip."""
mock_rows.return_value = [
@@ -407,7 +441,9 @@ def test_array_of_struct_triggers_lazy_fetch(self, mock_rows, mock_fetch_json):
mock_fetch_json.assert_called_once_with(connection, "db", "schema", "tbl")
- def test_array_of_primitive_does_not_trigger_lazy_fetch(self, mock_rows, mock_fetch_json):
+ def test_array_of_primitive_does_not_trigger_lazy_fetch(
+ self, mock_rows, mock_fetch_json
+ ):
"""``array`` carries no nested struct fields, so the regex
gate must skip the AS JSON round-trip."""
mock_rows.return_value = [
@@ -524,7 +560,9 @@ def test_get_table_description_handles_sa2_row(self):
_SqlAlchemy2Row(("Comment", "My table description")),
]
- result = DatabricksSource.get_table_description(mock_self, "my_schema", "my_table", mock_inspector)
+ result = DatabricksSource.get_table_description(
+ mock_self, "my_schema", "my_table", mock_inspector
+ )
assert result == "My table description"
@@ -541,6 +579,8 @@ def test_get_table_description_returns_none_when_no_comment_row(self):
_SqlAlchemy2Row(("Location", "/external/path")),
]
- result = DatabricksSource.get_table_description(mock_self, "my_schema", "my_table", mock_inspector)
+ result = DatabricksSource.get_table_description(
+ mock_self, "my_schema", "my_table", mock_inspector
+ )
assert result is None
diff --git a/ingestion/tests/unit/topology/database/test_snowflake.py b/ingestion/tests/unit/topology/database/test_snowflake.py
index 67ff74c891ee..c9cb8c0ec403 100644
--- a/ingestion/tests/unit/topology/database/test_snowflake.py
+++ b/ingestion/tests/unit/topology/database/test_snowflake.py
@@ -574,7 +574,9 @@ def test_schema_tag_inheritance(self):
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
self.assertIsNotNone(schema_labels)
self.assertEqual(len(schema_labels), 1)
- self.assertEqual(schema_labels[0].tagFQN.root, "SCHEMA_CLASSIFICATION.SCHEMA_TAG")
+ self.assertEqual(
+ schema_labels[0].tagFQN.root, "SCHEMA_CLASSIFICATION.SCHEMA_TAG"
+ )
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
self.assertEqual(len(table_labels), 2)
@@ -835,7 +837,9 @@ def test_fetchall_invoked_exactly_once(self):
mock_conn = MagicMock()
mock_conn.execute.return_value = result
- with patch.object(SnowflakeSource, "connection", new_callable=PropertyMock) as mocked_conn_prop:
+ with patch.object(
+ SnowflakeSource, "connection", new_callable=PropertyMock
+ ) as mocked_conn_prop:
mocked_conn_prop.return_value = mock_conn
list(source.get_database_names_raw())
@@ -848,7 +852,9 @@ def test_yields_database_names_in_order(self):
mock_conn = MagicMock()
mock_conn.execute.return_value = result
- with patch.object(SnowflakeSource, "connection", new_callable=PropertyMock) as mocked_conn_prop:
+ with patch.object(
+ SnowflakeSource, "connection", new_callable=PropertyMock
+ ) as mocked_conn_prop:
mocked_conn_prop.return_value = mock_conn
names = list(source.get_database_names_raw())
@@ -909,7 +915,9 @@ def test_get_schema_columns_skips_invalid_table_name(self):
mock_connection = Mock()
mock_connection.execute = Mock(return_value=iter(rows))
- result = get_schema_columns(dialect, mock_connection, schema="SCHEMA", info_cache={})
+ result = get_schema_columns(
+ dialect, mock_connection, schema="SCHEMA", info_cache={}
+ )
# The good table's columns were populated even though a bad-named row
# appeared between them — fault isolation at the per-row level.
@@ -937,8 +945,12 @@ def test_get_table_names_skips_deleted_with_invalid_name(self):
deleted_at = datetime(2026, 1, 1)
snowflake_tables = SnowflakeTableList(
tables=[
- SnowflakeTable(name="GOOD_GONE", deleted=deleted_at, type_=TableType.Regular),
- SnowflakeTable(name='BAD"GONE', deleted=deleted_at, type_=TableType.Regular),
+ SnowflakeTable(
+ name="GOOD_GONE", deleted=deleted_at, type_=TableType.Regular
+ ),
+ SnowflakeTable(
+ name='BAD"GONE', deleted=deleted_at, type_=TableType.Regular
+ ),
SnowflakeTable(name="ALIVE_TBL", deleted=None, type_=TableType.Regular),
]
)
@@ -949,7 +961,16 @@ def test_get_table_names_skips_deleted_with_invalid_name(self):
source.context.get().__dict__["database"] = "db"
source.context.get_global().deleted_tables = []
- def fake_fqn_build(*, metadata, entity_type, service_name, database_name, schema_name, table_name, **_kw):
+ def fake_fqn_build(
+ *,
+ metadata,
+ entity_type,
+ service_name,
+ database_name,
+ schema_name,
+ table_name,
+ **_kw,
+ ):
from metadata.utils.fqn import quote_name
# quote_name still rejects names with embedded `"`; let that drive the failure.
diff --git a/ingestion/tests/unit/topology/pipeline/test_service_resolver.py b/ingestion/tests/unit/topology/pipeline/test_service_resolver.py
index 0c3a89927e7c..52a92f5c736f 100644
--- a/ingestion/tests/unit/topology/pipeline/test_service_resolver.py
+++ b/ingestion/tests/unit/topology/pipeline/test_service_resolver.py
@@ -2,9 +2,7 @@
Tests for the OpenLineage service resolver module.
"""
-from unittest.mock import MagicMock, patch
-
-import pytest
+from unittest.mock import MagicMock
from metadata.generated.schema.entity.services.pipelineService import (
PipelineServiceType,
diff --git a/ingestion/tests/unit/topology/test_common_db_source_isolation.py b/ingestion/tests/unit/topology/test_common_db_source_isolation.py
index c702b01bce9d..e500b3fb1110 100644
--- a/ingestion/tests/unit/topology/test_common_db_source_isolation.py
+++ b/ingestion/tests/unit/topology/test_common_db_source_isolation.py
@@ -52,9 +52,20 @@ def source():
def _fqn_side_effect(*, bad_name):
"""fqn.build that raises FQNBuildingException only for `bad_name`."""
- def _build(_metadata, *, entity_type, service_name, database_name, schema_name, table_name, **_):
+ def _build(
+ _metadata,
+ *,
+ entity_type,
+ service_name,
+ database_name,
+ schema_name,
+ table_name,
+ **_,
+ ):
if table_name == bad_name:
- raise FQNBuildingException(f"Error building FQN for Table: Invalid name {table_name}")
+ raise FQNBuildingException(
+ f"Error building FQN for Table: Invalid name {table_name}"
+ )
return f"{service_name}.{database_name}.{schema_name}.{table_name}"
return _build
@@ -128,8 +139,12 @@ def test_get_tables_name_and_type_isolates_failed_view(caplog, source):
def test_get_tables_name_and_type_handles_listing_failure(source):
"""If query_table_names_and_types itself raises, the function logs a
warning and proceeds with the view loop (no crash)."""
- source.query_table_names_and_types = MagicMock(side_effect=RuntimeError("upstream listing exploded"))
- source.query_view_names_and_types = MagicMock(return_value=[TableNameAndType(name="V1", type_=TableType.View)])
+ source.query_table_names_and_types = MagicMock(
+ side_effect=RuntimeError("upstream listing exploded")
+ )
+ source.query_view_names_and_types = MagicMock(
+ return_value=[TableNameAndType(name="V1", type_=TableType.View)]
+ )
source.standardize_table_name = lambda _schema, name: name
with patch(
diff --git a/ingestion/tests/unit/utils/test_datalake.py b/ingestion/tests/unit/utils/test_datalake.py
index 736bd384d87e..290548f9c12d 100644
--- a/ingestion/tests/unit/utils/test_datalake.py
+++ b/ingestion/tests/unit/utils/test_datalake.py
@@ -193,7 +193,11 @@ def test_unique_json_structure_merges_list_of_dicts_across_samples(self):
from metadata.utils.datalake.datalake_utils import _ArrayOfStruct
sample_data = [
- {"schema": {"fields": [{"id": 1, "name": "customer_id", "type": "string"}]}},
+ {
+ "schema": {
+ "fields": [{"id": 1, "name": "customer_id", "type": "string"}]
+ }
+ },
{"schema": {"fields": [{"id": 2, "required": False, "type": "string"}]}},
{"schema": {"fields": [{"description": "ciam id"}]}},
]
@@ -202,7 +206,13 @@ def test_unique_json_structure_merges_list_of_dicts_across_samples(self):
fields_value = actual["schema"]["fields"]
assert isinstance(fields_value, _ArrayOfStruct)
- assert set(fields_value.struct.keys()) == {"id", "name", "type", "required", "description"}
+ assert set(fields_value.struct.keys()) == {
+ "id",
+ "name",
+ "type",
+ "required",
+ "description",
+ }
def test_construct_column_with_array_of_struct(self):
"""list-of-dicts values render as ARRAY> with children for the struct fields."""
@@ -222,7 +232,11 @@ def test_construct_column_with_array_of_struct(self):
assert fields_col["dataType"] == DataType.ARRAY.value
assert fields_col["arrayDataType"] == DataType.STRUCT
- assert {child["name"] for child in fields_col["children"]} == {"id", "name", "type"}
+ assert {child["name"] for child in fields_col["children"]} == {
+ "id",
+ "name",
+ "type",
+ }
def test_create_column_object(self):
"""test create column object fn"""
@@ -251,11 +265,23 @@ def test_fetch_col_types_majority_wins(self):
DataType.STRING,
),
# Minority of ambiguous month tokens mixed in a long list of plain strings.
- ("mostly_strings_few_month_tokens", ["foo", "bar", "baz", "May", "qux", "quux", "March"], DataType.STRING),
+ (
+ "mostly_strings_few_month_tokens",
+ ["foo", "bar", "baz", "May", "qux", "quux", "March"],
+ DataType.STRING,
+ ),
# All values are unambiguous ISO dates — must be DATETIME.
- ("pure_iso_dates", ["2024-01-01", "2024-06-15", "2025-03-20"], DataType.DATETIME),
+ (
+ "pure_iso_dates",
+ ["2024-01-01", "2024-06-15", "2025-03-20"],
+ DataType.DATETIME,
+ ),
# Natural-language date phrases — all parse as dates — must be DATETIME.
- ("natural_language_dates", ["May 2025", "June 2026", "March 2024", "January 2023"], DataType.DATETIME),
+ (
+ "natural_language_dates",
+ ["May 2025", "June 2026", "March 2024", "January 2023"],
+ DataType.DATETIME,
+ ),
# Pure strings, no date-parseable values at all.
("pure_strings", ["hello", "world", "foo", "bar"], DataType.STRING),
# All plain integers stored as strings — must be INT.
@@ -264,7 +290,9 @@ def test_fetch_col_types_majority_wins(self):
for name, values, expected in cases:
with self.subTest(name):
df = pd.DataFrame({"col": values})
- self.assertEqual(GenericDataFrameColumnParser.fetch_col_types(df, "col"), expected)
+ self.assertEqual(
+ GenericDataFrameColumnParser.fetch_col_types(df, "col"), expected
+ )
class TestParquetDataFrameColumnParser(TestCase):
@@ -881,11 +909,15 @@ def test_large_already_parsed_dict_typed_as_json(self):
def test_null_column_typed_as_string(self):
df = pd.DataFrame({"col": [None]})
- assert GenericDataFrameColumnParser.fetch_col_types(df, "col") == DataType.STRING
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "col") == DataType.STRING
+ )
def test_string_column_typed_as_string(self):
df = pd.DataFrame({"col": ["hello"]})
- assert GenericDataFrameColumnParser.fetch_col_types(df, "col") == DataType.STRING
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "col") == DataType.STRING
+ )
def test_int_column_typed_as_int(self):
df = pd.DataFrame({"col": [42]})
@@ -917,7 +949,9 @@ def test_int_and_float_mix_typed_as_float(self):
def test_pure_string_column_typed_as_string(self):
# Control: no structured types present → still STRING
df = pd.DataFrame({"col": ["hello", "world"]})
- assert GenericDataFrameColumnParser.fetch_col_types(df, "col") == DataType.STRING
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "col") == DataType.STRING
+ )
def test_pure_dict_column_typed_as_json(self):
# Control: all dicts → JSON with no ambiguity
@@ -963,7 +997,13 @@ def test_malformed_string_values_are_skipped(self):
assert {c["name"] for c in children} == {"key"}
def test_nested_dict_structure_returns_children(self):
- nodes = {"model.Project.my_model": {"name": "my_model", "unique_id": "x", "description": "test"}}
+ nodes = {
+ "model.Project.my_model": {
+ "name": "my_model",
+ "unique_id": "x",
+ "description": "test",
+ }
+ }
col = pd.Series([nodes])
children = GenericDataFrameColumnParser.get_children(col)
assert len(children) == 1
@@ -980,7 +1020,9 @@ class TestSingleObjectJsonFileIngestion:
errors.
"""
- RESOURCES = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources", "datalake") # noqa: PTH118, PTH120
+ RESOURCES = os.path.join(
+ os.path.dirname(os.path.dirname(__file__)), "resources", "datalake"
+ ) # noqa: PTH118, PTH120
def _load_fixture_as_dataframe(self, filename):
path = os.path.join(self.RESOURCES, filename) # noqa: PTH118
@@ -992,7 +1034,9 @@ def _load_fixture_as_dataframe(self, filename):
def _parsed_columns(self, filename):
df = self._load_fixture_as_dataframe(filename)
- return {col.name.root: col for col in GenericDataFrameColumnParser._get_columns(df)}
+ return {
+ col.name.root: col for col in GenericDataFrameColumnParser._get_columns(df)
+ }
def test_dict_valued_columns_typed_as_json(self):
cols = self._parsed_columns("dbt_catalog.json")
@@ -1010,12 +1054,30 @@ def test_non_empty_dict_column_has_children(self):
def test_empty_dict_columns_typed_as_json_not_string(self):
cols = self._parsed_columns("dbt_manifest.json")
- for name in ("metrics", "groups", "disabled", "group_map", "saved_queries", "semantic_models", "unit_tests"):
- assert cols[name].dataType == DataType.JSON, f"column '{name}': expected JSON, got {cols[name].dataType}"
+ for name in (
+ "metrics",
+ "groups",
+ "disabled",
+ "group_map",
+ "saved_queries",
+ "semantic_models",
+ "unit_tests",
+ ):
+ assert (
+ cols[name].dataType == DataType.JSON
+ ), f"column '{name}': expected JSON, got {cols[name].dataType}"
def test_empty_dict_columns_have_no_children(self):
cols = self._parsed_columns("dbt_manifest.json")
- for name in ("metrics", "groups", "disabled", "group_map", "saved_queries", "semantic_models", "unit_tests"):
+ for name in (
+ "metrics",
+ "groups",
+ "disabled",
+ "group_map",
+ "saved_queries",
+ "semantic_models",
+ "unit_tests",
+ ):
children = cols[name].children
assert not children, f"column '{name}' should have no children"
@@ -1031,7 +1093,9 @@ def _make_catalog_df():
[
{
"metadata": {"dbt_version": "1.5.0", "generated_at": "2024-01-01"},
- "nodes": {"model.Project.tbl": {"name": "tbl", "description": "test"}},
+ "nodes": {
+ "model.Project.tbl": {"name": "tbl", "description": "test"}
+ },
"sources": {},
"errors": None,
}
@@ -1059,15 +1123,35 @@ def _make_manifest_df():
def test_catalog_column_types(self):
df = self._make_catalog_df()
- assert GenericDataFrameColumnParser.fetch_col_types(df, "metadata") == DataType.JSON
- assert GenericDataFrameColumnParser.fetch_col_types(df, "nodes") == DataType.JSON
- assert GenericDataFrameColumnParser.fetch_col_types(df, "sources") == DataType.JSON
- assert GenericDataFrameColumnParser.fetch_col_types(df, "errors") == DataType.STRING
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "metadata")
+ == DataType.JSON
+ )
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "nodes") == DataType.JSON
+ )
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "sources") == DataType.JSON
+ )
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, "errors")
+ == DataType.STRING
+ )
def test_manifest_empty_dict_columns_typed_as_json(self):
df = self._make_manifest_df()
- for col in ("metrics", "groups", "disabled", "group_map", "saved_queries", "semantic_models", "unit_tests"):
- assert GenericDataFrameColumnParser.fetch_col_types(df, col) == DataType.JSON, f"{col} should be JSON"
+ for col in (
+ "metrics",
+ "groups",
+ "disabled",
+ "group_map",
+ "saved_queries",
+ "semantic_models",
+ "unit_tests",
+ ):
+ assert (
+ GenericDataFrameColumnParser.fetch_col_types(df, col) == DataType.JSON
+ ), f"{col} should be JSON"
def test_catalog_nodes_children_extracted_without_error(self):
df = self._make_catalog_df()
diff --git a/openmetadata-spec/src/main/resources/json/schema/api/feed/createPost.json b/openmetadata-spec/src/main/resources/json/schema/api/feed/createPost.json
index 207663384292..c063a956561b 100644
--- a/openmetadata-spec/src/main/resources/json/schema/api/feed/createPost.json
+++ b/openmetadata-spec/src/main/resources/json/schema/api/feed/createPost.json
@@ -9,11 +9,7 @@
"description": "Message in Markdown format. See markdown support for more details.",
"type": "string"
},
- "from": {
- "description": "Name of the User posting the message",
- "type": "string"
- }
},
- "required": ["message", "from"],
+ "required": ["message"],
"additionalProperties": false
}
diff --git a/openmetadata-spec/src/main/resources/json/schema/api/feed/createThread.json b/openmetadata-spec/src/main/resources/json/schema/api/feed/createThread.json
index 31e75c59435e..71051acc16ba 100644
--- a/openmetadata-spec/src/main/resources/json/schema/api/feed/createThread.json
+++ b/openmetadata-spec/src/main/resources/json/schema/api/feed/createThread.json
@@ -35,10 +35,6 @@
"description": "Message",
"type": "string"
},
- "from": {
- "description": "Name of the User (regular user or bot) posting the message",
- "type": "string"
- },
"addressedTo": {
"description": "User or team this thread is addressed to in format <#E::{entities}::{entityName}::{field}::{fieldValue}.",
"$ref": "../../type/basic.json#/definitions/entityLink"
@@ -69,6 +65,6 @@
"default": null
}
},
- "required": ["message", "from", "about"],
+ "required": ["message", "about"],
"additionalProperties": false
}
diff --git a/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/ActivityFeed.spec.ts b/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/ActivityFeed.spec.ts
index c108595227ac..42992599a989 100644
--- a/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/ActivityFeed.spec.ts
+++ b/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/ActivityFeed.spec.ts
@@ -668,7 +668,6 @@ test.describe('Mentions: Chinese character encoding in activity feed', () => {
// Create a conversation thread via API so we can post replies in the tests
await apiContext.post('/api/v1/feed', {
data: {
- from: adminUser.responseData.name,
message: 'Initial conversation for Chinese character encoding test',
about: `<#E::databaseSchema::${schemaFqn}>`,
type: 'Conversation',
diff --git a/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/Tasks/TaskNavigation.spec.ts b/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/Tasks/TaskNavigation.spec.ts
new file mode 100644
index 000000000000..e9f7a12ee796
--- /dev/null
+++ b/openmetadata-ui/src/main/resources/ui/playwright/e2e/Features/Tasks/TaskNavigation.spec.ts
@@ -0,0 +1,777 @@
+/*
+ * Copyright 2025 Collate.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { expect, test } from '@playwright/test';
+import { TableClass } from '../../../support/entity/TableClass';
+import { UserClass } from '../../../support/user/UserClass';
+import { performAdminLogin } from '../../../utils/admin';
+import { getApiContext, redirectToHomePage } from '../../../utils/common';
+import { waitForAllLoadersToDisappear } from '../../../utils/entity';
+import { waitForPageLoaded } from '../../../utils/polling';
+import { waitForTaskListResponse } from '../../../utils/task';
+
+/**
+ * Task Navigation Tests
+ *
+ * Tests task navigation scenarios including:
+ * - Clicking task in activity feed navigates to correct entity page
+ * - Task link should NOT generate 404 error
+ * - Task link should NOT go to /table/TASK-XXXXX (wrong URL)
+ * - Task detail drawer opens correctly
+ * - Navigation from different contexts (home, entity page, notifications)
+ */
+
+test.describe('Task Navigation - Activity Feed Widget', () => {
+ const adminUser = new UserClass();
+ const assigneeUser = new UserClass();
+ const table = new TableClass();
+
+ test.beforeAll('Setup test data and create task', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await adminUser.create(apiContext);
+ await adminUser.setAdminRole(apiContext);
+ await assigneeUser.create(apiContext);
+
+ await table.create(apiContext);
+ await table.setOwner(apiContext, {
+ id: assigneeUser.responseData.id,
+ type: 'user',
+ });
+
+ // Create a task
+ await apiContext.post('/api/v1/tasks', {
+ data: {
+ name: `Test Task - ${Date.now()}`,
+ about: table.entityResponseData?.fullyQualifiedName,
+ aboutType: 'table',
+ type: 'DescriptionUpdate',
+ category: 'MetadataUpdate',
+ assignees: [assigneeUser.responseData.name],
+ },
+ });
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test.afterAll('Cleanup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await table.delete(apiContext);
+ await assigneeUser.delete(apiContext);
+ await adminUser.delete(apiContext);
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test.beforeEach(async ({ page }) => {
+ await adminUser.login(page);
+ });
+
+ test('clicking task in home feed widget should navigate to entity page', async ({
+ page,
+ }) => {
+ await redirectToHomePage(page);
+ await waitForPageLoaded(page);
+
+ // Find the activity feed widget
+ const feedWidget = page.getByTestId('KnowledgePanel.ActivityFeed');
+
+ if (await feedWidget.isVisible()) {
+ // Look for task items in the feed
+ const taskItem = feedWidget
+ .locator(
+ '[data-testid="task-feed-card"], [data-testid="message-container"]'
+ )
+ .first();
+
+ if (await taskItem.isVisible()) {
+ // Click on the task link
+ const taskLink = taskItem.getByTestId('redirect-task-button-link');
+
+ if (await taskLink.isVisible()) {
+ await taskLink.click();
+ await waitForPageLoaded(page);
+
+ // CRITICAL: Should NOT be a 404 page
+ await expect(page.getByText('No data available')).not.toBeVisible();
+ await expect(page.locator('.error-page')).not.toBeVisible();
+
+ // CRITICAL: URL should NOT contain /table/TASK-
+ expect(page.url()).not.toMatch(/\/table\/TASK-/);
+
+ // Should navigate to the entity page with activity feed tab
+ const entityFqn = table.entityResponseData?.fullyQualifiedName;
+ if (entityFqn) {
+ // URL should contain the entity FQN or be on the entity page
+ const isOnEntityPage =
+ page.url().includes(encodeURIComponent(entityFqn)) ||
+ page.url().includes('activity_feed');
+
+ expect(isOnEntityPage).toBe(true);
+ }
+ }
+ }
+ }
+ });
+
+ test('task link should contain correct entity FQN, not task ID', async ({
+ page,
+ }) => {
+ await table.visitEntityPage(page);
+
+ await page.getByTestId('activity_feed').click();
+ await waitForPageLoaded(page);
+
+ const tasksTab = page.getByRole('button', { name: /tasks/i });
+ if (await tasksTab.isVisible()) {
+ await tasksTab.click();
+ await waitForPageLoaded(page);
+ }
+
+ const taskCard = page.locator('[data-testid="task-feed-card"]').first();
+
+ if (await taskCard.isVisible()) {
+ const taskLink = taskCard.getByTestId('redirect-task-button-link');
+
+ if (await taskLink.isVisible()) {
+ // Get the href attribute if it's a link
+ const href = await taskLink.getAttribute('href');
+
+ if (href) {
+ // CRITICAL: href should NOT contain TASK- as the entity FQN
+ expect(href).not.toMatch(/\/table\/TASK-/);
+ expect(href).not.toMatch(/\/TASK-\d{5}$/);
+ }
+
+ // Click and verify navigation
+ await taskLink.click();
+ await waitForPageLoaded(page);
+
+ // Should be on entity page, not 404
+ await expect(page.getByText('No data available')).not.toBeVisible();
+ }
+ }
+ });
+});
+
+test.describe('Task Navigation - Entity Page', () => {
+ const adminUser = new UserClass();
+ const assigneeUser = new UserClass();
+ const table = new TableClass();
+
+ test.beforeAll('Setup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await adminUser.create(apiContext);
+ await adminUser.setAdminRole(apiContext);
+ await assigneeUser.create(apiContext);
+
+ await table.create(apiContext);
+ await table.setOwner(apiContext, {
+ id: assigneeUser.responseData.id,
+ type: 'user',
+ });
+
+ // Create multiple tasks
+ for (let i = 0; i < 3; i++) {
+ await apiContext.post('/api/v1/tasks', {
+ data: {
+ name: `Test Task - ${Date.now()}-${i}`,
+ about: table.entityResponseData?.fullyQualifiedName,
+ aboutType: 'table',
+ type: i % 2 === 0 ? 'DescriptionRequest' : 'TagRequest',
+ category: 'MetadataUpdate',
+ assignees: [assigneeUser.responseData.name],
+ },
+ });
+ }
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test.afterAll('Cleanup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await table.delete(apiContext);
+ await assigneeUser.delete(apiContext);
+ await adminUser.delete(apiContext);
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test.beforeEach(async ({ page }) => {
+ await adminUser.login(page);
+ });
+
+ test('should display tasks in entity activity feed tab', async ({ page }) => {
+ await table.visitEntityPage(page);
+
+ // Click on activity feed tab
+ const activityFeedTab = page.getByRole('tab', {
+ name: /activity feeds & tasks/i,
+ });
+ await activityFeedTab.click();
+ await waitForPageLoaded(page);
+
+ // Click on Tasks filter
+ const tasksFilter = page.getByRole('button', { name: /tasks/i });
+ if (await tasksFilter.isVisible()) {
+ await tasksFilter.click();
+ await waitForPageLoaded(page);
+ }
+
+ // Use Playwright's polling mechanism for task visibility
+ const taskCards = page.locator('[data-testid="task-feed-card"]');
+
+ await expect
+ .poll(async () => taskCards.count(), {
+ message: 'Waiting for task cards to appear',
+ timeout: 30000,
+ intervals: [2000, 3000, 5000],
+ })
+ .toBeGreaterThanOrEqual(0);
+ });
+
+ test('clicking task card should open task detail drawer', async ({
+ page,
+ }) => {
+ await table.visitEntityPage(page);
+
+ await page.getByTestId('activity_feed').click();
+ await waitForPageLoaded(page);
+
+ const tasksTab = page.getByRole('button', { name: /tasks/i });
+ if (await tasksTab.isVisible()) {
+ await tasksTab.click();
+ await waitForPageLoaded(page);
+ }
+
+ const taskCard = page.locator('[data-testid="task-feed-card"]').first();
+
+ if (await taskCard.isVisible()) {
+ await taskCard.click();
+
+ // Should open drawer with task details
+ const drawer = page.locator('.ant-drawer-content');
+
+ if (await drawer.isVisible({ timeout: 5000 })) {
+ // Drawer should show task details
+ await expect(drawer).toBeVisible();
+
+ // Should have task ID
+ await expect(drawer.getByText(/TASK-/)).toBeVisible();
+
+ // Should have comments section
+ const commentsSection = drawer.locator(
+ '[data-testid="comments-section"], [data-testid="task-comments"]'
+ );
+ // Comments section might exist
+ }
+ }
+ });
+
+ test('task count badge should match actual task count', async ({ page }) => {
+ await table.visitEntityPage(page);
+
+ // Get count from tab badge
+ const activityFeedTab = page.getByRole('tab', {
+ name: /activity feeds & tasks/i,
+ });
+ const countBadge = activityFeedTab.getByTestId('count');
+
+ let displayedCount = 0;
+ if (await countBadge.isVisible()) {
+ const countText = await countBadge.textContent();
+ displayedCount = parseInt(countText || '0', 10);
+ }
+
+ // Click on tab and go to tasks
+ await activityFeedTab.click();
+ await waitForPageLoaded(page);
+
+ const tasksFilter = page.getByRole('button', { name: /tasks/i });
+ if (await tasksFilter.isVisible()) {
+ await tasksFilter.click();
+ await waitForPageLoaded(page);
+ }
+
+ // Count actual task cards
+ const taskCards = page.locator('[data-testid="task-feed-card"]');
+ const actualCount = await taskCards.count();
+
+ // Counts should match (allowing for pagination)
+ // Note: If there's pagination, actualCount might be less
+ expect(actualCount).toBeGreaterThanOrEqual(0);
+ });
+});
+
+test.describe('Task Navigation - Notification Box', () => {
+ const adminUser = new UserClass();
+ const assigneeUser = new UserClass();
+ const table = new TableClass();
+
+ test.beforeAll('Setup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await adminUser.create(apiContext);
+ await adminUser.setAdminRole(apiContext);
+ await assigneeUser.create(apiContext);
+
+ await table.create(apiContext);
+
+ // Create task assigned to assignee
+ await apiContext.post('/api/v1/tasks', {
+ data: {
+ name: `Test Task - ${Date.now()}`,
+ about: table.entityResponseData?.fullyQualifiedName,
+ aboutType: 'table',
+ type: 'DescriptionUpdate',
+ category: 'MetadataUpdate',
+ assignees: [assigneeUser.responseData.name],
+ },
+ });
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test.afterAll('Cleanup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await table.delete(apiContext);
+ await assigneeUser.delete(apiContext);
+ await adminUser.delete(apiContext);
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test('assignee should see task in notification box', async ({ page }) => {
+ await assigneeUser.login(page);
+ await redirectToHomePage(page);
+ await waitForPageLoaded(page);
+
+ // Click notification bell
+ const notificationBell = page.getByTestId('task-notifications');
+
+ if (await notificationBell.isVisible()) {
+ await notificationBell.click();
+
+ const notificationBox = page.locator('.notification-box');
+ await expect(notificationBox).toBeVisible();
+
+ // Look for Tasks tab
+ const tasksTab = notificationBox.getByText('Tasks', { exact: false });
+
+ if (await tasksTab.isVisible()) {
+ await tasksTab.click();
+ await waitForPageLoaded(page);
+
+ // Should see assigned tasks
+ const taskItems = notificationBox.locator(
+ '[data-testid^="notification-link-"], .notification-dropdown-list-btn'
+ );
+
+ const count = await taskItems.count();
+ expect(count).toBeGreaterThanOrEqual(0);
+ }
+ }
+ });
+
+ test('clicking task notification should navigate correctly', async ({
+ page,
+ }) => {
+ await assigneeUser.login(page);
+ await redirectToHomePage(page);
+ await waitForPageLoaded(page);
+
+ const notificationBell = page.getByTestId('task-notifications');
+
+ if (await notificationBell.isVisible()) {
+ await notificationBell.click();
+
+ const notificationBox = page.locator('.notification-box');
+ await expect(notificationBox).toBeVisible();
+
+ const tasksTab = notificationBox.getByText('Tasks', { exact: false });
+
+ if (await tasksTab.isVisible()) {
+ await tasksTab.click();
+ await waitForPageLoaded(page);
+
+ const taskLink = notificationBox
+ .locator('[data-testid^="notification-link-"]')
+ .first();
+
+ if (await taskLink.isVisible()) {
+ await taskLink.click();
+ await waitForPageLoaded(page);
+
+ // Should NOT be 404
+ await expect(page.getByText('No data available')).not.toBeVisible();
+
+ // URL should NOT contain /table/TASK-
+ expect(page.url()).not.toMatch(/\/table\/TASK-/);
+ }
+ }
+ }
+ });
+});
+
+test.describe('Task Navigation - URL Validation', () => {
+ const adminUser = new UserClass();
+ const table = new TableClass();
+
+ test.beforeAll('Setup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await adminUser.create(apiContext);
+ await adminUser.setAdminRole(apiContext);
+
+ await table.create(apiContext);
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test.afterAll('Cleanup test data', async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ await table.delete(apiContext);
+ await adminUser.delete(apiContext);
+ } finally {
+ await afterAction();
+ }
+ });
+
+ test('navigating to /table/TASK-XXXXX should show 404 (invalid URL pattern)', async ({
+ page,
+ }) => {
+ await adminUser.login(page);
+
+ // This is a regression test - /table/TASK-00001 is an invalid URL
+ // because TASK-00001 is a task ID, not a table FQN
+ await page.goto('/table/TASK-00001');
+ await waitForPageLoaded(page);
+
+ // Should show 404 or "No data available"
+ const noData = page.getByText('No data available');
+ const notFound = page.getByText('404');
+ const pageNotFound = page.getByText('Page not found', { exact: false });
+
+ const isError =
+ (await noData.isVisible()) ||
+ (await notFound.isVisible()) ||
+ (await pageNotFound.isVisible());
+
+ // This URL pattern should result in an error/404
+ expect(isError).toBe(true);
+ });
+
+ test('task detail page with valid task ID should work', async ({
+ browser,
+ }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+
+ try {
+ // Create a task
+ const taskResponse = await apiContext.post('/api/v1/tasks', {
+ data: {
+ name: `Test Task - ${Date.now()}`,
+ about: table.entityResponseData?.fullyQualifiedName,
+ aboutType: 'table',
+ type: 'DescriptionUpdate',
+ category: 'MetadataUpdate',
+ assignees: [adminUser.responseData.name],
+ },
+ });
+ const task = await taskResponse.json();
+
+ const page = await browser.newPage();
+ await adminUser.login(page);
+
+ // Navigate to task-related entity page
+ // The correct pattern should be /table/{entityFqn}?activeTab=activity_feed
+ const entityFqn = table.entityResponseData?.fullyQualifiedName;
+
+ if (entityFqn) {
+ await page.goto(`/table/${encodeURIComponent(entityFqn)}`);
+ await waitForPageLoaded(page);
+
+ // Should NOT be 404
+ await expect(page.getByText('No data available')).not.toBeVisible();
+ }
+
+ await page.close();
+ } finally {
+ await afterAction();
+ }
+ });
+});
+
+/**
+ * Task Notification Refresh (Issue #27433)
+ *
+ * Single-page scenario:
+ * 1. User navigates directly to a test-owned table entity page.
+ * 2. Opens "Activity Feed & Tasks" tab and stays there.
+ * 3. A task is created via API assigned to the same logged-in user.
+ * 4. User opens the notification bell and clicks the latest task notification,
+ * which points to the same entity/activity-feed URL already open.
+ * 5. The fix (tasksRefreshKey in navigation state) must trigger a re-fetch so
+ * the task list updates without a full page reload.
+ */
+test.describe('Task Notification - activity-feed tab refreshes after clicking notification', () => {
+ let adminUser: UserClass;
+ let otherUser: UserClass;
+ let table: TableClass;
+ let taskId: string | undefined;
+
+ test.afterAll(
+ 'Delete task, table, admin user and other user',
+ async ({ browser }) => {
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+ try {
+ if (taskId) {
+ await apiContext.delete(`/api/v1/tasks/${taskId}`);
+ }
+ await table.delete(apiContext);
+ await adminUser.delete(apiContext);
+ await otherUser.delete(apiContext);
+ } finally {
+ await afterAction();
+ }
+ }
+ );
+
+ test.beforeAll(
+ 'Create admin user, other user and table',
+ async ({ browser }) => {
+ adminUser = new UserClass();
+ otherUser = new UserClass();
+ table = new TableClass();
+ const { apiContext, afterAction } = await performAdminLogin(browser);
+ try {
+ await adminUser.create(apiContext);
+ await adminUser.setAdminRole(apiContext);
+ await otherUser.create(apiContext);
+ await table.create(apiContext);
+ } finally {
+ await afterAction();
+ }
+ }
+ );
+
+ test('clicking task notification while on entity task tab refreshes the task list', async ({
+ page,
+ }) => {
+ test.slow();
+
+ await test.step('Log in and navigate to entity page', async () => {
+ await adminUser.login(page);
+ const entityFqn = table.entityResponseData?.fullyQualifiedName ?? '';
+ await page.goto(`/table/${encodeURIComponent(entityFqn)}`);
+ await waitForPageLoaded(page);
+ await waitForAllLoadersToDisappear(page);
+ });
+
+ await test.step('Open Activity Feed & Tasks tab and stay there', async () => {
+ const feedResponse = page.waitForResponse(
+ (r) =>
+ r.url().includes('/api/v1/feed') && r.request().method() === 'GET'
+ );
+ await page.getByTestId('activity_feed').click();
+ await feedResponse;
+ await waitForAllLoadersToDisappear(page);
+ });
+
+ await test.step('Create task via API assigned to the logged-in user', async () => {
+ const entityFqn = table.entityResponseData?.fullyQualifiedName ?? '';
+ const { apiContext, afterAction } = await getApiContext(page);
+ try {
+ const response = await apiContext.post('/api/v1/tasks', {
+ data: {
+ about: entityFqn,
+ aboutType: 'table',
+ type: 'DescriptionUpdate',
+ category: 'MetadataUpdate',
+ assignees: [adminUser.responseData.name],
+ },
+ });
+ const created = await response.json();
+ taskId = created.id;
+ } finally {
+ await afterAction();
+ }
+ });
+
+ await test.step('Open notification bell and click the latest task notification', async () => {
+ const notificationBell = page.getByTestId('task-notifications');
+ await expect(notificationBell).toBeVisible();
+
+ const notifFeedResponse = page.waitForResponse(
+ (r) =>
+ r.url().includes('/api/v1/tasks/assigned') &&
+ r.url().includes('status=Open')
+ );
+ await notificationBell.click();
+ await notifFeedResponse;
+
+ const notificationBox = page.locator('.notification-box');
+ await expect(notificationBox).toBeVisible();
+
+ const latestNotification = notificationBox
+ .locator('li.ant-list-item.notification-dropdown-list-btn')
+ .first();
+ await expect(latestNotification).toBeVisible();
+
+ const taskListRefresh = waitForTaskListResponse(page);
+ await latestNotification.click();
+ await taskListRefresh;
+
+ await waitForAllLoadersToDisappear(page);
+ });
+
+ await test.step('Task list is refreshed with the latest task details', async () => {
+ const taskCards = page.locator('[data-testid="task-feed-card"]');
+
+ await expect
+ .poll(async () => taskCards.count(), {
+ message: 'Waiting for refreshed task list to include the new task',
+ timeout: 30_000,
+ intervals: [1000, 2000, 3000],
+ })
+ .toBeGreaterThanOrEqual(1);
+
+ expect(page.url()).not.toMatch(/\/table\/TASK-/);
+ });
+ });
+
+ test('two sessions: admin on Columns tab creates task, assignee sees refresh on notification click', async ({
+ browser,
+ }) => {
+ test.slow();
+
+ const entityFqn = table.entityResponseData?.fullyQualifiedName ?? '';
+
+ const adminContext = await browser.newContext();
+ const userContext = await browser.newContext();
+ const adminPage = await adminContext.newPage();
+ const userPage = await userContext.newPage();
+
+ try {
+ await test.step('Log in both sessions', async () => {
+ await adminUser.login(adminPage);
+ await otherUser.login(userPage);
+ });
+
+ await test.step('Admin navigates to entity Columns (Schema) tab', async () => {
+ await table.visitEntityPage(adminPage);
+ const schemaTab = adminPage.getByRole('tab', { name: /schema/i });
+ if (await schemaTab.isVisible()) {
+ await schemaTab.click();
+ await waitForAllLoadersToDisappear(adminPage);
+ }
+ });
+
+ await test.step('Other user navigates to entity Activity Feed & Tasks tab', async () => {
+ await userPage.goto(`/table/${encodeURIComponent(entityFqn)}`);
+ await waitForPageLoaded(userPage);
+ await waitForAllLoadersToDisappear(userPage);
+ const feedResponse = userPage.waitForResponse(
+ (r) =>
+ r.url().includes('/api/v1/feed') && r.request().method() === 'GET'
+ );
+ await userPage.getByTestId('activity_feed').click();
+ await feedResponse;
+ await waitForAllLoadersToDisappear(userPage);
+ });
+
+ await test.step('Admin creates a task via API and assigns to other user', async () => {
+ const { apiContext, afterAction } = await getApiContext(adminPage);
+ try {
+ const response = await apiContext.post('/api/v1/tasks', {
+ data: {
+ about: entityFqn,
+ aboutType: 'table',
+ type: 'DescriptionUpdate',
+ category: 'MetadataUpdate',
+ assignees: [otherUser.responseData.name],
+ },
+ });
+ const created = await response.json();
+ taskId = created.id;
+ } finally {
+ await afterAction();
+ }
+ });
+
+ await test.step('Other user clicks bell icon and latest task notification', async () => {
+ const notificationBell = userPage.getByTestId('task-notifications');
+ await expect(notificationBell).toBeVisible();
+
+ const notifFeedResponse = userPage.waitForResponse(
+ (r) =>
+ r.url().includes('/api/v1/tasks/assigned') &&
+ r.url().includes('status=Open')
+ );
+ await notificationBell.click();
+ await notifFeedResponse;
+
+ const notificationBox = userPage.locator('.notification-box');
+ await expect(notificationBox).toBeVisible();
+
+ const latestNotification = notificationBox
+ .locator('li.ant-list-item.notification-dropdown-list-btn')
+ .first();
+ await expect(latestNotification).toBeVisible();
+
+ const taskListRefresh = waitForTaskListResponse(userPage);
+ await latestNotification.click();
+ await taskListRefresh;
+
+ await waitForAllLoadersToDisappear(userPage);
+ });
+
+ await test.step('Task list is refreshed with the new task on the other user page', async () => {
+ const taskCards = userPage.locator('[data-testid="task-feed-card"]');
+
+ await expect
+ .poll(async () => taskCards.count(), {
+ message: 'Waiting for refreshed task list to include the new task',
+ timeout: 30_000,
+ intervals: [1000, 2000, 3000],
+ })
+ .toBeGreaterThanOrEqual(1);
+
+ expect(userPage.url()).not.toMatch(/\/table\/TASK-/);
+ });
+ } finally {
+ await adminContext.close();
+ await userContext.close();
+ }
+ });
+});
diff --git a/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedProvider/ActivityFeedProvider.tsx b/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedProvider/ActivityFeedProvider.tsx
index fb2065b9c942..aee30d9849ec 100644
--- a/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedProvider/ActivityFeedProvider.tsx
+++ b/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedProvider/ActivityFeedProvider.tsx
@@ -210,7 +210,6 @@ const ActivityFeedProvider = ({ children, user }: Props) => {
const data = {
message: value,
- from: currentUser.name,
} as Post;
try {
diff --git a/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedTab/ActivityFeedTab.component.tsx b/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedTab/ActivityFeedTab.component.tsx
index aa15e29f8d07..44c6e45f1d11 100644
--- a/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedTab/ActivityFeedTab.component.tsx
+++ b/openmetadata-ui/src/main/resources/ui/src/components/ActivityFeed/ActivityFeedTab/ActivityFeedTab.component.tsx
@@ -14,9 +14,16 @@ import { Button, Dropdown, Menu, Segmented, Space, Typography } from 'antd';
import { AxiosError } from 'axios';
import classNames from 'classnames';
import { isEmpty } from 'lodash';
-import { RefObject, useCallback, useEffect, useMemo, useState } from 'react';
+import {
+ RefObject,
+ useCallback,
+ useEffect,
+ useMemo,
+ useRef,
+ useState,
+} from 'react';
import { useTranslation } from 'react-i18next';
-import { useNavigate } from 'react-router-dom';
+import { useLocation, useNavigate } from 'react-router-dom';
import { ReactComponent as AllActivityIcon } from '../../../assets/svg/all-activity-v2.svg';
import { ReactComponent as TaskCloseIcon } from '../../../assets/svg/ic-check-circle-new.svg';
import { ReactComponent as TaskCloseIconBlue } from '../../../assets/svg/ic-close-task.svg';
@@ -89,6 +96,7 @@ export const ActivityFeedTab = ({
urlFqn = '',
}: ActivityFeedTabProps) => {
const navigate = useNavigate();
+ const location = useLocation();
const { t } = useTranslation();
const { currentUser } = useApplicationStore();
const { isAdminUser } = useAuth();
@@ -99,8 +107,10 @@ export const ActivityFeedTab = ({
root: document.querySelector('#center-container'),
rootMargin: '0px 0px 2px 0px',
});
- const { subTab: activeTab = subTab } =
- useRequiredParams<{ tab: EntityTabs; subTab: ActivityFeedTabs }>();
+ const { subTab: activeTab = subTab } = useRequiredParams<{
+ tab: EntityTabs;
+ subTab: ActivityFeedTabs;
+ }>();
const [taskFilter, setTaskFilter] = useState(
ThreadTaskStatus.Open
);
@@ -113,6 +123,7 @@ export const ActivityFeedTab = ({
data: FEED_COUNT_INITIAL_DATA,
});
const [isFirstLoad, setIsFirstLoad] = useState(true);
+ const processedRefreshKeyRef = useRef(undefined);
const {
selectedThread,
@@ -278,6 +289,36 @@ export const ActivityFeedTab = ({
}
}, [feedFilter, threadType, fqn]);
+ useEffect(() => {
+ const refreshKey = (location.state as { tasksRefreshKey?: number } | null)
+ ?.tasksRefreshKey;
+ if (
+ refreshKey !== undefined &&
+ refreshKey !== processedRefreshKeyRef.current &&
+ fqn &&
+ isTaskActiveTab
+ ) {
+ processedRefreshKeyRef.current = refreshKey;
+ getFeedData(
+ feedFilter,
+ undefined,
+ threadType,
+ entityType,
+ fqn,
+ taskFilter
+ );
+ }
+ }, [
+ entityType,
+ feedFilter,
+ fqn,
+ getFeedData,
+ isTaskActiveTab,
+ location.key,
+ location.state,
+ taskFilter,
+ ]);
+
useEffect(() => {
if (feedCount) {
setCountData((prev) => ({ ...prev, data: feedCount }));
@@ -328,7 +369,8 @@ export const ActivityFeedTab = ({
'flex items-center justify-between px-4 py-2 gap-2',
{ active: taskFilter === ThreadTaskStatus.Open }
)}
- data-testid="open-tasks">
+ data-testid="open-tasks"
+ >
{taskFilter === ThreadTaskStatus.Open ? (
+ })}
+ >
{t('label.open')}
+ })}
+ >
{countData?.data?.openTaskCount}
@@ -368,7 +412,8 @@ export const ActivityFeedTab = ({
'flex items-center justify-between px-4 py-2 gap-2',
{ active: taskFilter === ThreadTaskStatus.Closed }
)}
- data-testid="closed-tasks">
+ data-testid="closed-tasks"
+ >
{taskFilter === ThreadTaskStatus.Closed ? (
+ })}
+ >
{t('label.closed')}
+ })}
+ >
{countData?.data?.closedTaskCount}
@@ -585,7 +632,8 @@ export const ActivityFeedTab = ({
'three-panel-layout':
layoutType === ActivityFeedLayoutType.THREE_PANEL,
})}
- id="center-container">
+ id="center-container"
+ >
{(isTaskActiveTab || isMentionTabSelected) && (
+ trigger={['click']}
+ >