diff --git a/deploy/build_mcp_docs.py b/deploy/build_mcp_docs.py index 2040820d..0fd01c19 100644 --- a/deploy/build_mcp_docs.py +++ b/deploy/build_mcp_docs.py @@ -31,6 +31,8 @@ DocGroup.INVESTIGATE, DocGroup.BROWSE_PROFILING, DocGroup.TRIGGER, + DocGroup.SCORING, + DocGroup.MANAGE, ] _FALLBACK_GROUP = "Other tools" diff --git a/deploy/testgen.dockerfile b/deploy/testgen.dockerfile index 743f9edf..800cfa4f 100644 --- a/deploy/testgen.dockerfile +++ b/deploy/testgen.dockerfile @@ -1,4 +1,4 @@ -ARG TESTGEN_BASE_LABEL=v15 +ARG TESTGEN_BASE_LABEL=v16 FROM datakitchen/dataops-testgen-base:${TESTGEN_BASE_LABEL} AS release-image diff --git a/pyproject.toml b/pyproject.toml index b7d263a2..7af55506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataops-testgen" -version = "5.33.3" +version = "5.48.0" description = "DataKitchen's Data Quality DataOps TestGen" authors = [ { "name" = "DataKitchen, Inc.", "email" = "info@datakitchen.io" }, @@ -41,6 +41,7 @@ dependencies = [ "oracledb==3.4.0", "hdbcli==2.25.31", "sqlalchemy-hana==4.4.0", + "salesforce-cdp-connector>=1.0.19", "pyodbc==5.2.0", "psycopg2-binary==2.9.11", "pycryptodome==3.21", @@ -117,6 +118,9 @@ release = [ testgen = "testgen.__main__:cli" tg-patch-streamlit = "testgen.ui.scripts.patch_streamlit:patch" +[project.entry-points."sqlalchemy.dialects"] +salesforce_data360 = "testgen.common.database.salesforce_data360_dialect:SalesforceData360Dialect" + [project.urls] "Source Code" = "https://github.com/DataKitchen/dataops-testgen" "Bug Tracker" = "https://github.com/DataKitchen/dataops-testgen/issues" @@ -397,3 +401,7 @@ asset_dir = "ui/components/frontend/js" [[tool.streamlit.component.components]] name = "sidebar" asset_dir = "ui/components/frontend/js" + +[[tool.streamlit.component.components]] +name = "feedback_widget" +asset_dir = "ui/components/frontend/js" diff --git a/testgen/__main__.py b/testgen/__main__.py index a1ea67a1..ceb396c5 
100644 --- a/testgen/__main__.py +++ b/testgen/__main__.py @@ -982,7 +982,7 @@ def init_ui(): "run", app_file, "--browser.gatherUsageStats=false", - f"--logger.level={'debug' if settings.IS_DEBUG else 'error'}", + "--logger.level=error", "--client.showErrorDetails=none", "--client.toolbarMode=minimal", "--server.enableStaticServing=true", diff --git a/testgen/api/__init__.py b/testgen/api/__init__.py index e69de29b..5d82c173 100644 --- a/testgen/api/__init__.py +++ b/testgen/api/__init__.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter + +from testgen.api.app import router as _app_router +from testgen.api.jobs import router as _jobs_router +from testgen.api.runs import router as _runs_router +from testgen.api.test_definitions import router as _test_definitions_router + +router = APIRouter(prefix="/api/v1") +router.include_router(_app_router) +router.include_router(_jobs_router) +router.include_router(_runs_router) +router.include_router(_test_definitions_router) diff --git a/testgen/api/app.py b/testgen/api/app.py index 8111916a..2f81b1fb 100644 --- a/testgen/api/app.py +++ b/testgen/api/app.py @@ -5,7 +5,7 @@ from testgen.api.deps import db_session from testgen.common import version_service -router = APIRouter(prefix="/api/v1", tags=["API"], dependencies=[Depends(db_session)]) +router = APIRouter(tags=["API"], dependencies=[Depends(db_session)]) @router.get("/health") diff --git a/testgen/api/deps.py b/testgen/api/deps.py index 8daac06f..1807d68d 100644 --- a/testgen/api/deps.py +++ b/testgen/api/deps.py @@ -4,10 +4,14 @@ from fastapi import Depends, HTTPException, Security, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy import select from testgen.common.auth import authorize_token, decode_jwt_token from testgen.common.models import Session, _current_session_wrapper, get_current_session +from testgen.common.models.job_execution import PUBLIC_JOB_KEYS, JobExecution from testgen.common.models.project_membership import 
ProjectMembership +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_suite import TestSuite from testgen.common.models.user import User from testgen.utils.plugins import PluginHook @@ -73,14 +77,25 @@ def has_project_permission(user: User, project_code: str, permission: str) -> bo # --- Resolver dependency factories --- # Each factory takes a permission string and returns Depends(). The entity ID -# comes from a URL path parameter (FastAPI resolves it natively). -# Entity not found and insufficient permission both raise the same 404 -# with a stable code/message — no variation that could leak the cause. +# comes from a URL path parameter (FastAPI resolves it natively, including +# UUID validation that yields a 422 for malformed inputs). _require_user = Depends(get_authorized_user) _not_found = api_error(404, "not_found", "Not found") +def _check_access(entity, user: User, permission: str): + """Return ``entity`` if the user has ``permission`` on its project, else raise 404. + + Entity-not-found and insufficient-permission both surface as the same 404 + with a stable code/message — no variation that could leak the cause to an + unauthorized caller. 
+ """ + if entity and has_project_permission(user, entity.project_code, permission): + return entity + raise _not_found + + def resolve_project_code(permission: str): """Verify the user has ``permission`` on the project identified by ``project_code`` path param.""" def dependency(project_code: str, user: User = _require_user) -> str: @@ -92,45 +107,31 @@ def dependency(project_code: str, user: User = _require_user) -> str: def resolve_table_group(permission: str): """Resolve a TableGroup by ``table_group_id`` path param and verify project permission.""" - from testgen.common.models.table_group import TableGroup - def dependency(table_group_id: UUID, user: User = _require_user) -> TableGroup: - if (table_group := TableGroup.get(table_group_id)) and has_project_permission(user, table_group.project_code, permission): - return table_group - raise _not_found + return _check_access(TableGroup.get(table_group_id), user, permission) return Depends(dependency) def resolve_test_suite(permission: str): """Resolve a non-monitor TestSuite by ``test_suite_id`` path param and verify project permission.""" - from testgen.common.models.test_suite import TestSuite - def dependency(test_suite_id: UUID, user: User = _require_user) -> TestSuite: - if (test_suite := TestSuite.get_regular(test_suite_id)) and has_project_permission(user, test_suite.project_code, permission): - return test_suite - raise _not_found + return _check_access(TestSuite.get_regular(test_suite_id), user, permission) return Depends(dependency) def resolve_job(permission: str, *extra_filters): """Resolve a JobExecution by ``job_id`` path param and verify project permission. - Internally-submitted jobs (source='system') are never exposed via the API. - Extra ORM clauses are appended to the WHERE clause, e.g. to restrict by job_key. - Mismatches surface as the same 404 — no information leakage. + Only jobs whose ``job_key`` is in ``PUBLIC_JOB_KEYS`` are exposed via the API. 
+ Internal kinds (score rollups, recalculations, monitor runs) are filtered out + by construction. Extra ORM clauses are appended to the WHERE clause to further + restrict by job_key when a caller wants a single kind. """ - from sqlalchemy import select - - from testgen.common.models.job_execution import JobExecution - def dependency(job_id: UUID, user: User = _require_user) -> JobExecution: query = select(JobExecution).where( JobExecution.id == job_id, - JobExecution.source != "system", + JobExecution.job_key.in_(PUBLIC_JOB_KEYS), *extra_filters, ) - job = get_current_session().scalars(query).first() - if job and has_project_permission(user, job.project_code, permission): - return job - raise _not_found + return _check_access(get_current_session().scalars(query).first(), user, permission) return Depends(dependency) diff --git a/testgen/api/jobs.py b/testgen/api/jobs.py index 7131fdf1..3ab291cf 100644 --- a/testgen/api/jobs.py +++ b/testgen/api/jobs.py @@ -10,8 +10,9 @@ resolve_table_group, resolve_test_suite, ) -from testgen.api.schemas import ErrorResponse, JobKey, JobListResponse, JobResponse, JobSource, JobSubmittedResponse -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.api.schemas import ErrorResponse, JobListResponse, JobResponse, JobSubmittedResponse +from testgen.common.enums import JobKey, JobSource, JobStatus +from testgen.common.models.job_execution import PUBLIC_JOB_KEYS, JobExecution from testgen.common.models.table_group import TableGroup from testgen.common.models.test_suite import TestSuite @@ -19,7 +20,7 @@ 404: {"model": ErrorResponse, "description": "Not found"}, } -router = APIRouter(prefix="/api/v1", tags=["Jobs"], dependencies=[Depends(db_session)], responses=_error_responses) +router = APIRouter(tags=["Jobs"], dependencies=[Depends(db_session)], responses=_error_responses) @router.post( @@ -105,7 +106,7 @@ def list_jobs( """List job executions for a project, with optional filters and pagination.""" 
items, total = JobExecution.list_for_project( project_code, - JobExecution.source != "system", + JobExecution.job_key.in_(PUBLIC_JOB_KEYS), job_key=job_key, status=status, page=page, diff --git a/testgen/api/runs.py b/testgen/api/runs.py index 64110721..fcbe0445 100644 --- a/testgen/api/runs.py +++ b/testgen/api/runs.py @@ -26,7 +26,7 @@ 404: {"model": ErrorResponse, "description": "Not found"}, } -router = APIRouter(prefix="/api/v1", tags=["runs"], dependencies=[Depends(db_session)], responses=_error_responses) +router = APIRouter(tags=["runs"], dependencies=[Depends(db_session)], responses=_error_responses) @router.get( diff --git a/testgen/api/schemas.py b/testgen/api/schemas.py index 753152d8..2543f227 100644 --- a/testgen/api/schemas.py +++ b/testgen/api/schemas.py @@ -6,27 +6,11 @@ from pydantic import BaseModel, field_validator -from testgen.common.models.job_execution import JobStatus +from testgen.common.enums import JobKey, JobSource, JobStatus # --- Jobs --- -class JobKey(StrEnum): - run_profile = "run-profile" - run_tests = "run-tests" - run_monitors = "run-monitors" - run_test_generation = "run-test-generation" - - -class JobSource(StrEnum): - api = "api" - ui = "ui" - scheduler = "scheduler" - mcp = "mcp" - cli = "cli" - backfill = "backfill" - - class JobSubmittedResponse(BaseModel): """Returned on 202 Accepted after successful job submission.""" diff --git a/testgen/api/test_definitions.py b/testgen/api/test_definitions.py index 6b205cc0..6d2d0e41 100644 --- a/testgen/api/test_definitions.py +++ b/testgen/api/test_definitions.py @@ -21,7 +21,6 @@ } router = APIRouter( - prefix="/api/v1", tags=["Test Definitions"], dependencies=[Depends(db_session)], responses=_error_responses, diff --git a/testgen/commands/exec_job.py b/testgen/commands/exec_job.py index 4f63a494..48421b0e 100644 --- a/testgen/commands/exec_job.py +++ b/testgen/commands/exec_job.py @@ -11,9 +11,10 @@ from uuid import UUID from testgen.commands.job_registry import JOB_DISPATCH, 
run_final_callbacks +from testgen.common.enums import JobStatus from testgen.common.job_context import JobContext, job_context from testgen.common.models import database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.utils import get_exception_message LOG = logging.getLogger("testgen") @@ -35,8 +36,8 @@ def exec_job(job_execution_id: UUID) -> None: LOG.error("Job execution %s not found", job_execution_id) sys.exit(1) - handler = JOB_DISPATCH.get(job_exec.job_key) - if not handler: + job_config = JOB_DISPATCH.get(job_exec.job_key) + if not job_config: job_exec.mark_interrupted(f"Unknown job key: {job_exec.job_key}") return @@ -48,7 +49,7 @@ def exec_job(job_execution_id: UUID) -> None: with database_session(): job_exec = JobExecution.get(job_execution_id) job_context.set(JobContext(job_id=job_execution_id, source=job_exec.source)) - handler(**job_exec.kwargs) + job_config.handler(**job_exec.kwargs) with database_session(): job_exec = JobExecution.get(job_execution_id) diff --git a/testgen/commands/job_registry.py b/testgen/commands/job_registry.py index 45d5bfe7..3fa3dceb 100644 --- a/testgen/commands/job_registry.py +++ b/testgen/commands/job_registry.py @@ -1,7 +1,8 @@ """Wiring between the JobExecution engine and the concrete job handlers. Two registries keyed by `job_key`: - - `JOB_DISPATCH`: maps a job to its handler (`exec_job` resolves this). + - `JOB_DISPATCH`: maps a job to its `JobConfig` (handler + per-job metadata). + `exec_job` and the scheduler resolve this. - `JOB_FINAL_CALLBACKS`: maps a job to post-terminal-transition callbacks (notifications, follow-up job submissions). `run_final_callbacks` iterates. 
@@ -12,16 +13,19 @@ import logging from collections.abc import Callable +from dataclasses import dataclass from sqlalchemy import select +from testgen.commands.run_data_cleanup import run_data_cleanup from testgen.commands.run_profiling import run_profiling from testgen.commands.run_recalculate_project_scores import run_recalculate_project_scores from testgen.commands.run_score_update import run_score_update from testgen.commands.run_test_execution import run_test_execution from testgen.commands.test_generation import run_test_generation +from testgen.common.enums import JobKey, JobSource, JobStatus from testgen.common.models import database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.test_run import TestRun from testgen.common.notifications.monitor_run import send_monitor_notifications @@ -32,13 +36,31 @@ FinalCallback = Callable[[JobExecution], None] -JOB_DISPATCH: dict[str, Callable] = { - "run-profile": run_profiling, - "run-tests": run_test_execution, - "run-monitors": run_test_execution, - "run-test-generation": run_test_generation, - "run-score-update": run_score_update, - "recalculate-project-scores": run_recalculate_project_scores, + +@dataclass(frozen=True) +class JobConfig: + """Per-job-key registration metadata. + + `scheduler_source` is the value the scheduler tags `JobExecution.source` + with when it spawns this job key — ``"scheduler"`` for user-facing jobs, + ``"system"`` for system-internal jobs (e.g., retention cleanup). Read + only by the scheduler; direct `JobExecution.submit(source=...)` callers + (UI, follow-up enqueues, CLI) set their own source independently and do + not consult this field. 
+ """ + + handler: Callable + scheduler_source: JobSource = JobSource.scheduler + + +JOB_DISPATCH: dict[JobKey, JobConfig] = { + JobKey.run_profile: JobConfig(handler=run_profiling), + JobKey.run_tests: JobConfig(handler=run_test_execution), + JobKey.run_monitors: JobConfig(handler=run_test_execution), + JobKey.run_test_generation: JobConfig(handler=run_test_generation), + JobKey.run_score_update: JobConfig(handler=run_score_update, scheduler_source=JobSource.system), + JobKey.recalculate_project_scores: JobConfig(handler=run_recalculate_project_scores, scheduler_source=JobSource.system), + JobKey.run_data_cleanup: JobConfig(handler=run_data_cleanup, scheduler_source=JobSource.system), } @@ -91,18 +113,18 @@ def _enqueue_score_update(job_exec: JobExecution) -> None: with database_session(): JobExecution.submit( - job_key="run-score-update", + job_key=JobKey.run_score_update, kwargs={ "parent_job_id": str(job_exec.id), "parent_job_key": job_exec.job_key, }, - source="system", + source=JobSource.system, project_code=job_exec.project_code, ) -JOB_FINAL_CALLBACKS: dict[str, list[FinalCallback]] = { - "run-profile": [_notify_profiling_run, _enqueue_score_update], - "run-tests": [_notify_test_run, _enqueue_score_update], - "run-monitors": [_notify_monitor_run], +JOB_FINAL_CALLBACKS: dict[JobKey, list[FinalCallback]] = { + JobKey.run_profile: [_notify_profiling_run, _enqueue_score_update], + JobKey.run_tests: [_notify_test_run, _enqueue_score_update], + JobKey.run_monitors: [_notify_monitor_run], } diff --git a/testgen/commands/job_runner.py b/testgen/commands/job_runner.py index 37a96dc5..f95584e6 100644 --- a/testgen/commands/job_runner.py +++ b/testgen/commands/job_runner.py @@ -11,8 +11,9 @@ from sqlalchemy import select from testgen.commands.exec_job import FINAL_STATUSES, POLL_INTERVAL +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, get_current_session -from testgen.common.models.job_execution import 
JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.test_run import TestRun @@ -34,7 +35,7 @@ def submit_and_wait( job_exec = JobExecution.submit( job_key=job_key, kwargs=kwargs, - source="cli", + source=JobSource.cli, project_code=project_code, ) job_id = job_exec.id diff --git a/testgen/commands/queries/execute_tests_query.py b/testgen/commands/queries/execute_tests_query.py index e81d95a9..03eab489 100644 --- a/testgen/commands/queries/execute_tests_query.py +++ b/testgen/commands/queries/execute_tests_query.py @@ -8,9 +8,15 @@ from testgen.common import read_template_sql_file from testgen.common.clean_sql import concat_columns -from testgen.common.database.database_service import get_flavor_service, get_tg_schema, replace_params +from testgen.common.database.database_service import ( + fetch_dict_from_db, + get_flavor_service, + get_tg_schema, + replace_params, +) from testgen.common.freshness_service import ( count_excluded_minutes, + get_freshness_gated_baseline, get_schedule_params, is_excluded_day, resolve_holiday_dates, @@ -264,6 +270,52 @@ def __init__(self, connection: Connection, table_group: TableGroup, test_suite: test_suite.holiday_codes_list, pd.DatetimeIndex([datetime(self.run_date.year - 1, 1, 1), datetime(self.run_date.year + 1, 12, 31)]), ) + # Cache of (schema, table) -> "did Freshness_Trend detect a fingerprint change in + # this run?". True / False / None (no Freshness_Trend result). Populated lazily + # per table; reused across all Volume/Metric defs for the same table. + self._freshness_changed_cache: dict[tuple[str, str], bool | None] = {} + + def _freshness_changed_for_table(self, test_def: TestExecutionDef) -> bool | None: + """Did Freshness_Trend detect a fingerprint change for the test's table in this run? + + Reads the latest Freshness_Trend result_signal written during the current run. 
+ Freshness_Trend emits `result_signal = '0'` when the table fingerprint differs + from the previous run's baseline (i.e., an update was detected). Any other value + (the interval since last update) means no change. + + Returns True / False per the signal, or None if no Freshness_Trend result exists + for this table in this run. + """ + cache_key = (test_def.schema_name, test_def.table_name) + if cache_key in self._freshness_changed_cache: + return self._freshness_changed_cache[cache_key] + + rows = fetch_dict_from_db(*self._get_query("get_current_freshness_signal.sql", test_def=test_def)) + changed: bool | None = None + if rows and rows[0].get("result_signal") is not None: + changed = str(rows[0]["result_signal"]) == "0" + self._freshness_changed_cache[cache_key] = changed + return changed + + def _resolve_cat_operator_and_condition(self, test_def: TestExecutionDef) -> tuple[str, str]: + """Pick the operator / condition pair to feed into build_cat_expressions. + + For Volume_Trend / Metric_Trend with freshness-gating enabled, when Freshness_Trend + detected no change in this run the table is in a "stale period" — the measure must + equal baseline_value, and any deviation is a silent-write anomaly. In that case, + override the test definition's `NOT BETWEEN` band check with a strict equality + check against baseline_value. All other cases (band check, refresh detected, + non-monitor test types, or no Freshness_Trend result) keep the test definition's + own operator and condition. 
+ """ + if ( + test_def.test_type in ("Volume_Trend", "Metric_Trend") + and (baseline := get_freshness_gated_baseline(test_def.prediction)) is not None + and self._freshness_changed_for_table(test_def) is False + ): + return "<>", str(baseline) + + return test_def.test_operator, test_def.test_condition def _get_input_parameters(self, test_def: TestExecutionDef) -> str: return "; ".join( @@ -458,21 +510,21 @@ def aggregate_cat_tests( ) -> tuple[list[tuple[str, None]], list[list[TestExecutionDef]]]: varchar_type = self.flavor_service.varchar_type concat_operator = self.flavor_service.concat_operator - quote = self.flavor_service.quote_character for td in test_defs: # Don't recalculate expressions if it was already done before if not td.measure_expression or not td.condition_expression: params = self._get_params(td) + operator, condition_template = self._resolve_cat_operator_and_condition(td) measure = replace_params(td.measure, params) measure = replace_templated_functions(measure, self.flavor) - condition = replace_params(td.test_condition, params) + condition = replace_params(condition_template, params) condition = replace_templated_functions(condition, self.flavor) td.measure_expression, td.condition_expression = build_cat_expressions( measure=measure, - test_operator=td.test_operator, + test_operator=operator, test_condition=condition, history_calculation=td.history_calculation, lower_tolerance=td.lower_tolerance, @@ -492,7 +544,7 @@ def aggregate_cat_tests( f"SELECT {len(aggregate_queries)} AS query_index, " f"{concat_operator.join([td.measure_expression for td in group])} AS result_measures, " f"{concat_operator.join([td.condition_expression for td in group])} AS result_codes " - f"FROM {quote}{group[0].schema_name}{quote}.{quote}{group[0].table_name}{quote}" + f"FROM {self.flavor_service.get_table_ref(group[0].schema_name, group[0].table_name)}" ) query = query.replace(":", "\\:") diff --git a/testgen/commands/queries/profiling_query.py 
b/testgen/commands/queries/profiling_query.py index d3f02a16..f7a72aac 100644 --- a/testgen/commands/queries/profiling_query.py +++ b/testgen/commands/queries/profiling_query.py @@ -1,8 +1,8 @@ import dataclasses from uuid import UUID -from testgen.commands.queries.refresh_data_chars_query import ColumnChars from testgen.common import read_template_sql_file +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import process_conditionals, replace_params from testgen.common.models.connection import Connection from testgen.common.models.profiling_run import ProfilingRun diff --git a/testgen/commands/queries/refresh_data_chars_query.py b/testgen/commands/queries/refresh_data_chars_query.py index e5d72fa5..762e7261 100644 --- a/testgen/commands/queries/refresh_data_chars_query.py +++ b/testgen/commands/queries/refresh_data_chars_query.py @@ -1,28 +1,20 @@ -import dataclasses +import re from collections.abc import Iterable from datetime import datetime from testgen.common import read_template_sql_file +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import get_flavor_service, replace_params from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from testgen.utils import chunk_queries, to_sql_timestamp -@dataclasses.dataclass -class ColumnChars: - schema_name: str - table_name: str - column_name: str - ordinal_position: int = None - general_type: str = None - column_type: str = None - db_data_type: str = None - is_decimal: bool = False - approx_record_ct: int = None - # This should not default to 0 since we don't always retrieve actual row counts - # UI relies on the null value to know that the approx_record_ct should be displayed instead - record_ct: int = None +def _like_to_regex(pattern: str) -> re.Pattern[str]: + # Mirrors SQL LIKE semantics used in _get_table_criteria: `%` is the only + # wildcard; 
`_` is treated as a literal character (escaped to `\_` in the + # SQL path). Anything else is literal. + return re.compile("^" + re.escape(pattern.strip()).replace("%", ".*") + "$") class RefreshDataCharsSQL: @@ -100,6 +92,32 @@ def _get_table_criteria(self) -> str: return table_criteria + def filter_schema_columns(self, columns: list[ColumnChars]) -> list[ColumnChars]: + """Apply the table group's filters (table set, include/exclude masks) to a column list. + + Mirrors `_get_table_criteria` for flavors that bypass the SQL template path + (e.g., Salesforce Data 360, where columns come from the metadata API). + """ + result = columns + + if self.table_group.profiling_table_set: + allowed = {item.strip() for item in self.table_group.profiling_table_set.split(",")} + result = [c for c in result if c.table_name in allowed] + + if self.table_group.profiling_include_mask: + include_patterns = [ + _like_to_regex(item) for item in self.table_group.profiling_include_mask.split(",") + ] + result = [c for c in result if any(p.match(c.table_name) for p in include_patterns)] + + if self.table_group.profiling_exclude_mask: + exclude_patterns = [ + _like_to_regex(item) for item in self.table_group.profiling_exclude_mask.split(",") + ] + result = [c for c in result if not any(p.match(c.table_name) for p in exclude_patterns)] + + return result + def get_schema_ddf(self) -> tuple[str, dict]: # Runs on Target database return self._get_query( @@ -111,9 +129,8 @@ def get_schema_ddf(self) -> tuple[str, dict]: def get_row_counts(self, table_names: Iterable[str]) -> list[tuple[str, None]]: # Runs on Target database schema = self.table_group.table_group_schema - quote = self.flavor_service.quote_character count_queries = [ - f"SELECT '{table}' AS table_name, COUNT(*) AS row_count FROM {quote}{schema}{quote}.{quote}{table}{quote}" + f"SELECT '{table}' AS table_name, COUNT(*) AS row_count FROM {self.flavor_service.get_table_ref(schema, table)}" for table in table_names ] chunked_queries = 
chunk_queries(count_queries, " UNION ALL ", self.connection.max_query_chars) @@ -122,14 +139,9 @@ def get_row_counts(self, table_names: Iterable[str]) -> list[tuple[str, None]]: def verify_access(self, table_name: str) -> tuple[str, None]: # Runs on Target database schema = self.table_group.table_group_schema - quote = self.flavor_service.quote_character - table_ref = f"{quote}{schema}{quote}.{quote}{table_name}{quote}" - if (row_limiting := self.flavor_service.row_limiting_clause) == "top": - query = f"SELECT TOP 1 * FROM {table_ref}" - elif row_limiting == "fetch": - query = f"SELECT 1 FROM {table_ref} FETCH FIRST 1 ROWS ONLY" - else: - query = f"SELECT 1 FROM {table_ref} LIMIT 1" + table_ref = self.flavor_service.get_table_ref(schema, table_name) + prefix, suffix = self.flavor_service.row_limit_clauses(1) + query = f"SELECT {prefix} 1 FROM {table_ref} {suffix}".strip() return (query, None) def get_staging_data_chars(self, data_chars: list[ColumnChars], run_date: datetime) -> list[list[str | bool | int]]: diff --git a/testgen/commands/run_data_cleanup.py b/testgen/commands/run_data_cleanup.py new file mode 100644 index 00000000..08e6341c --- /dev/null +++ b/testgen/commands/run_data_cleanup.py @@ -0,0 +1,121 @@ +"""Per-project data retention cleanup. + +Deletes profiling runs, test runs, and their child results older than the +project's retention period, plus aged-out staging, score history, and +job_execution records. + +Always preserves the most recent profiling run per table group and the most +recent test run per test suite (including monitor suites). Profiling is +expensive and tends to run infrequently; downstream features — test +generation, freshness monitor generation, data catalog, and MCP analysis +tools — depend on the most recent profiling result for a table group, so +the project must always retain a baseline regardless of retention period +or run cadence. 
+""" + +import logging +from datetime import UTC, datetime, timedelta + +from testgen.common.models import database_session +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.scores import ScoreDefinitionResultHistoryEntry, ScoreHistoryLatestRun +from testgen.common.models.stg_data_chars_update import StgDataCharsUpdate +from testgen.common.models.stg_functional_table_update import StgFunctionalTableUpdate +from testgen.common.models.stg_secondary_profile_update import StgSecondaryProfileUpdate +from testgen.common.models.stg_test_definition_update import StgTestDefinitionUpdate +from testgen.common.models.test_run import TestRun + +LOG = logging.getLogger("testgen") + +BATCH_SIZE = 1000 + + +def run_data_cleanup(project_code: str, retention_days: int) -> None: + started_at = datetime.now(UTC) + cutoff = started_at - timedelta(days=retention_days) + LOG.info( + "Data retention cleanup started: project=%s retention_days=%d cutoff=%s", + project_code, retention_days, cutoff.isoformat(), + ) + + with database_session(): + protected_profiling_ids = ProfilingRun.find_latest_per_table_group(project_code) + protected_test_run_ids = TestRun.find_latest_per_test_suite(project_code) + # Translate protected run ids → their job_execution_ids so the JE sweep + # can carve them out. Nulls (older runs without a JE) are filtered here. 
+ je_map = { + **ProfilingRun.get_job_execution_ids(list(protected_profiling_ids)), + **TestRun.get_job_execution_ids(list(protected_test_run_ids)), + } + protected_job_execution_ids = {je for je in je_map.values() if je is not None} + + LOG.info( + "Protected latest runs: profiling=%d test=%d job_executions=%d", + len(protected_profiling_ids), len(protected_test_run_ids), len(protected_job_execution_ids), + ) + + # Each delete owns its per-batch transactions internally — committing + # between batches releases locks and bounds WAL growth for large sweeps. + deleted_profiling = ProfilingRun.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_ids=protected_profiling_ids, + batch_size=BATCH_SIZE, + ) + + deleted_tests = TestRun.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_ids=protected_test_run_ids, + batch_size=BATCH_SIZE, + ) + + deleted_job_executions = JobExecution.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_ids=protected_job_execution_ids, + batch_size=BATCH_SIZE, + ) + + # Score history: read protected mapping keys BEFORE deleting from either + # table — we need score_history_latest_runs intact to compute the carve-out + # for score_definition_results_history. + with database_session(): + protected_history_keys = ScoreHistoryLatestRun.find_protected_keys( + protected_profiling_ids=protected_profiling_ids, + protected_test_run_ids=protected_test_run_ids, + ) + + deleted_score_history = ScoreDefinitionResultHistoryEntry.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_keys=protected_history_keys, + batch_size=BATCH_SIZE, + ) + + deleted_score_latest = ScoreHistoryLatestRun.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_keys=protected_history_keys, + batch_size=BATCH_SIZE, + ) + + # Staging tables: defensive cleanup of orphans left behind by failed jobs. 
+ # No carve-out — these are transient operational rows with no run linkage. + with database_session(): + deleted_stg = ( + StgSecondaryProfileUpdate.delete_older_than(cutoff, project_code) + + StgFunctionalTableUpdate.delete_older_than(cutoff, project_code) + + StgDataCharsUpdate.delete_older_than(cutoff, project_code) + + StgTestDefinitionUpdate.delete_older_than(cutoff, project_code) + ) + + elapsed = (datetime.now(UTC) - started_at).total_seconds() + LOG.info( + "Data retention cleanup complete: project=%s " + "deleted_profiling=%d deleted_tests=%d deleted_job_executions=%d " + "deleted_score_history=%d deleted_score_latest=%d deleted_staging=%d elapsed=%.1fs", + project_code, deleted_profiling, deleted_tests, deleted_job_executions, + deleted_score_history, deleted_score_latest, deleted_stg, elapsed, + ) diff --git a/testgen/commands/run_launch_db_config.py b/testgen/commands/run_launch_db_config.py index f7b69bcb..824d200d 100644 --- a/testgen/commands/run_launch_db_config.py +++ b/testgen/commands/run_launch_db_config.py @@ -7,6 +7,7 @@ from testgen.common.database.database_service import get_queries_for_command from testgen.common.encrypt import EncryptText, encrypt_ui_password from testgen.common.models import with_database_session +from testgen.common.models.user import initial_feedback_popup_seed from testgen.common.read_file import get_template_files from testgen.common.read_yaml_metadata_records import import_metadata_records_from_yaml from testgen.common.standalone_postgres import EMBEDDED_HOST_SENTINEL, is_standalone_mode @@ -40,6 +41,7 @@ def _get_params_mapping() -> dict: "UI_USER_USERNAME": settings.USERNAME, "UI_USER_EMAIL": "", "UI_USER_ENCRYPTED_PASSWORD": ui_user_encrypted_password, + "LAST_FEEDBACK_POPUP_SEED": initial_feedback_popup_seed(), "SCHEMA_NAME": get_tg_schema(), "PROJECT_CODE": settings.PROJECT_KEY, "CONNECTION_ID": 1, diff --git a/testgen/commands/run_observability_exporter.py b/testgen/commands/run_observability_exporter.py index 
71179e9d..e0339026 100644 --- a/testgen/commands/run_observability_exporter.py +++ b/testgen/commands/run_observability_exporter.py @@ -318,6 +318,7 @@ def run_observability_exporter(project_code, test_suite): test_suites = TestSuite.select_minimal_where( TestSuite.project_code == project_code, TestSuite.test_suite == test_suite, + TestSuite.is_monitor.isnot(True), ) qty_of_exported_events = export_test_results(test_suites[0].id) click.echo(f"{qty_of_exported_events} events have been exported.") diff --git a/testgen/commands/run_profiling.py b/testgen/commands/run_profiling.py index 2125defc..ffd1a58f 100644 --- a/testgen/commands/run_profiling.py +++ b/testgen/commands/run_profiling.py @@ -9,7 +9,6 @@ TableSampling, calculate_sampling_params, ) -from testgen.commands.queries.refresh_data_chars_query import ColumnChars from testgen.commands.run_refresh_data_chars import run_data_chars_refresh from testgen.commands.test_generation import run_monitor_generation, run_test_generation from testgen.common import ( @@ -19,6 +18,7 @@ set_target_db_params, write_to_app_db, ) +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import ThreadedProgress from testgen.common.job_context import job_context from testgen.common.mixpanel_service import MixpanelService diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index 5f7389bc..2d56f312 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -19,9 +19,10 @@ set_target_db_params, ) from testgen.common.database.flavor.flavor_service import ConnectionParams +from testgen.common.enums import JobSource, JobStatus from testgen.common.job_context import JobContext, job_context from testgen.common.models import database_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from 
testgen.common.models.scores import ScoreDefinition from testgen.common.models.settings import PersistedSetting from testgen.common.models.table_group import TableGroup @@ -51,7 +52,7 @@ def run_with_job_execution( effective_date = run_date or datetime.now(UTC) wall_start = datetime.now(UTC) # Match the source a real trigger would use so demo data mirrors production attribution. - source = "scheduler" if job_key == "run-monitors" else "ui" + source = JobSource.scheduler if job_key == "run-monitors" else JobSource.ui with database_session() as session: je = JobExecution( @@ -67,7 +68,7 @@ def run_with_job_execution( je_id = je.id job_context.set(JobContext(job_id=je_id, source=source)) - JOB_DISPATCH[job_key](**handler_kwargs, run_date=run_date) + JOB_DISPATCH[job_key].handler(**handler_kwargs, run_date=run_date) with database_session(): je = JobExecution.get(je_id) diff --git a/testgen/commands/run_refresh_data_chars.py b/testgen/commands/run_refresh_data_chars.py index 94f9b3e0..c4b483ce 100644 --- a/testgen/commands/run_refresh_data_chars.py +++ b/testgen/commands/run_refresh_data_chars.py @@ -1,13 +1,15 @@ import logging from datetime import datetime -from testgen.commands.queries.refresh_data_chars_query import ColumnChars, RefreshDataCharsSQL +from testgen.commands.queries.refresh_data_chars_query import RefreshDataCharsSQL +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import ( execute_db_queries, fetch_dict_from_db, fetch_from_db_threaded, write_to_app_db, ) +from testgen.common.database.flavor.flavor_service import resolve_connection_params from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from testgen.utils import get_exception_message @@ -20,16 +22,25 @@ def run_data_chars_refresh(connection: Connection, table_group: TableGroup, run_ LOG.info("Getting DDF for table group") try: - data_chars = 
fetch_dict_from_db(*sql_generator.get_schema_ddf(), use_target_db=True) + if sql_generator.flavor_service.metadata_via_api: + # Flavor returns column metadata directly (e.g., Salesforce Data 360 + # via the connector's metadata API). These flavors have no information_schema. + # Apply the table-group filters in Python + # since we bypass the SQL {TABLE_CRITERIA} clause. + params = resolve_connection_params(connection.__dict__) + api_columns = sql_generator.flavor_service.get_schema_columns(params, table_group.table_group_schema) or [] + data_chars = sql_generator.filter_schema_columns(api_columns) + else: + rows = fetch_dict_from_db(*sql_generator.get_schema_ddf(), use_target_db=True) + data_chars = [ColumnChars(**row) for row in rows] except Exception as e: raise RuntimeError(f"Error refreshing columns for data catalog. {get_exception_message(e)}") from e - - data_chars = [ColumnChars(**column) for column in data_chars] + if data_chars: distinct_tables = {column.table_name for column in data_chars} LOG.info(f"Tables: {len(distinct_tables)}, Columns: {len(data_chars)}") count_queries = sql_generator.get_row_counts(distinct_tables) - + LOG.info("Getting row counts for table group") count_results, _, error_data = fetch_from_db_threaded( count_queries, use_target_db=True, max_threads=connection.max_threads, diff --git a/testgen/commands/run_refresh_score_cards_results.py b/testgen/commands/run_refresh_score_cards_results.py index 3a6a71f3..ee8800f8 100644 --- a/testgen/commands/run_refresh_score_cards_results.py +++ b/testgen/commands/run_refresh_score_cards_results.py @@ -3,6 +3,7 @@ import time from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scores import ( SCORE_CATEGORIES, ScoreCard, @@ -11,6 +12,7 @@ ScoreDefinitionResult, ScoreDefinitionResultHistoryEntry, ) +from testgen.common.models.test_run import TestRun from 
@with_database_session
def save_and_refresh_score_definition(
    score_definition: ScoreDefinition,
    *,
    is_new: bool,
) -> ScoreDefinition:
    """Persist a scorecard definition and bring its cached scores up to date.

    This is the save-then-refresh sequence shared by the Score Explorer UI and
    the ``update_scorecard`` MCP tool. Presentation concerns (cache clearing,
    navigation, toasts) are the caller's responsibility.

    Args:
        score_definition: The definition to save.
        is_new: ``True`` when the definition is being created. The first
            refresh then also adds a history entry dated at the most recent
            profiling or test run of the project (or ``None`` when neither
            exists). ``False`` for an existing definition, in which case
            ``run_recalculate_score_card`` is run after the refresh.

    Returns:
        The same ``score_definition`` object that was passed in.
    """
    refresh_options: dict = {}
    if is_new:
        project_code = score_definition.project_code
        # Timezone-aware stand-in for a missing run: run_time is stored with a
        # timezone, so comparing against a naive minimum would raise when only
        # one of the two run kinds exists.
        fallback_time = datetime.datetime.min.replace(tzinfo=datetime.UTC)

        def _run_time_or_fallback(run):
            return getattr(run, "run_time", fallback_time)

        candidate_runs = (
            ProfilingRun.get_latest_run(project_code),
            TestRun.get_latest_run(project_code),
        )
        newest_run = max(candidate_runs, key=_run_time_or_fallback)
        refresh_options["add_history_entry"] = True
        refresh_options["refresh_date"] = newest_run.run_time if newest_run else None

    score_definition.save()
    run_refresh_score_cards_results(definition_id=score_definition.id, **refresh_options)

    if is_new:
        return score_definition

    # Existing definition: history entries may have shifted under new filters.
    run_recalculate_score_card(
        project_code=score_definition.project_code,
        definition_id=score_definition.id,
    )
    return score_definition
def get_target_identifiers(
    sql_generator: TestExecutionSQL,
    target_schemas: set[str],
) -> tuple[set[tuple[str, str]], set[tuple[str, str, str]]]:
    """Collect the tables and columns that exist in the target schemas.

    Flavors flagged with ``metadata_via_api`` are asked for their columns
    through ``get_schema_columns()`` once per schema; every other flavor is
    queried with the SQL produced by ``sql_generator.get_target_identifiers``.

    Args:
        sql_generator: Provides the flavor service, the connection and the
            identifier query.
        target_schemas: Schemas whose tables and columns should be listed.

    Returns:
        A pair ``(tables, columns)`` where ``tables`` holds lower-cased
        ``(schema, table)`` tuples and ``columns`` holds lower-cased
        ``(schema, table, column)`` tuples.
    """
    service = sql_generator.flavor_service

    if service.metadata_via_api:
        connection_params = resolve_connection_params(sql_generator.connection.__dict__)
        discovered: list[ColumnChars] = []
        for schema_name in target_schemas:
            discovered.extend(service.get_schema_columns(connection_params, schema_name) or [])
        LOG.info("Got tables and columns from flavor metadata API for validation")
        column_keys = {
            (col.schema_name.lower(), col.table_name.lower(), col.column_name.lower())
            for col in discovered
        }
    else:
        LOG.info("Getting tables and columns in target schemas for validation")
        rows = fetch_dict_from_db(
            *sql_generator.get_target_identifiers(target_schemas),
            use_target_db=True,
        )
        if not rows:
            LOG.info("No tables or columns present in target schemas")
        column_keys = {
            (row["schema_name"].lower(), row["table_name"].lower(), row["column_name"].lower())
            for row in rows
        }

    # Every table key is the (schema, table) prefix of at least one column key.
    table_keys = {(schema, table) for schema, table, _ in column_keys}
    return table_keys, column_keys
item["table_name"].lower(), item["column_name"].lower()) - for item in target_identifiers - } - + target_tables, target_columns = get_target_identifiers(sql_generator, target_schemas) check_errors = check_identifiers(identifiers_to_check, target_tables, target_columns) for test_id, error_list in check_errors.items(): if not test_defs_by_id[test_id].errors: diff --git a/testgen/commands/test_thresholds_prediction.py b/testgen/commands/test_thresholds_prediction.py index 7f6617ee..7c501e98 100644 --- a/testgen/commands/test_thresholds_prediction.py +++ b/testgen/commands/test_thresholds_prediction.py @@ -87,11 +87,18 @@ def run(self) -> None: df = to_dataframe(test_results, coerce_float=True) grouped_dfs = df.groupby("test_definition_id", group_keys=False) + # Freshness update events are fetched as secondary data only when the suite + # is a monitor — Volume_Trend / Metric_Trend in monitor suites couple to the + # Freshness_Trend signal to avoid stairstep false positives. + freshness_updates_by_table: dict[tuple[str, str], list[str]] = ( + self._fetch_freshness_updates_by_table() if self.test_suite.is_monitor else {} + ) + LOG.info(f"Training prediction models for tests: {len(grouped_dfs)}") prediction_results = [] for test_def_id, group in grouped_dfs: test_type = group["test_type"].iloc[0] - history = group[["test_time", "result_signal"]] + history = group[["test_time", "result_signal", "test_run_id"]] history = history.set_index("test_time") test_prediction = [ @@ -99,30 +106,42 @@ def run(self) -> None: test_def_id, to_sql_timestamp(self.run_date), ] - if test_type == "Freshness_Trend": + # Skip prediction if history is smaller than configured lookback + if len(history) < (self.test_suite.predict_min_lookback or 1): + test_prediction.extend([None, None, None, None]) + elif test_type == "Freshness_Trend": lower, upper, staleness, prediction = compute_freshness_threshold( history, sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium, - 
min_lookback=self.test_suite.predict_min_lookback or 1, exclude_weekends=self.test_suite.predict_exclude_weekends, holiday_codes=self.test_suite.holiday_codes_list, schedule_tz=self.tz, ) test_prediction.extend([lower, upper, staleness, prediction]) - else: - lower, upper, prediction = compute_sarimax_threshold( + elif test_type in ("Volume_Trend", "Metric_Trend"): + table_key = (group["schema_name"].iloc[0], group["table_name"].iloc[0]) + lower, upper, baseline, prediction = compute_volume_or_metric_threshold( history, + freshness_updates=freshness_updates_by_table.get(table_key, []), sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium, - min_lookback=self.test_suite.predict_min_lookback or 1, exclude_weekends=self.test_suite.predict_exclude_weekends, holiday_codes=self.test_suite.holiday_codes_list, schedule_tz=self.tz, ) if test_type == "Volume_Trend": - if lower is not None: + if lower is not None: lower = max(lower, 0.0) if upper is not None: upper = max(upper, 0.0) + test_prediction.extend([lower, upper, baseline, prediction]) + else: + lower, upper, prediction = compute_sarimax_threshold( + history, + sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium, + exclude_weekends=self.test_suite.predict_exclude_weekends, + holiday_codes=self.test_suite.holiday_codes_list, + schedule_tz=self.tz, + ) test_prediction.extend([lower, upper, None, prediction]) prediction_results.append(test_prediction) @@ -149,11 +168,21 @@ def _get_query( query = replace_params(query, params) return query, params + def _fetch_freshness_updates_by_table( + self, + ) -> dict[tuple[str, str], list[str]]: + """Fetch test_run_ids of Freshness_Trend fingerprint changes, indexed by table.""" + rows = fetch_dict_from_db(*self._get_query("get_freshness_fingerprint_events.sql")) + events_by_table: dict[tuple[str, str], list[str]] = {} + for row in rows: + key = (row["schema_name"], row["table_name"]) + events_by_table.setdefault(key, 
def compute_volume_or_metric_threshold(
    history: pd.DataFrame,
    freshness_updates: list[str],
    sensitivity: PredictSensitivity,
    num_forecast: int = NUM_FORECAST,
    exclude_weekends: bool = False,
    holiday_codes: list[str] | None = None,
    schedule_tz: str | None = None,
) -> tuple[float | None, float | None, float | None, str | None]:
    """SARIMAX threshold for Volume_Trend / Metric_Trend with freshness-gating.

    The fit is first attempted on the rows of ``history`` whose ``test_run_id``
    appears in ``freshness_updates``. When that fit yields a prediction, the
    prediction JSON is augmented with ``freshness_gated`` and
    ``baseline_value`` and the baseline is returned as the third element.

    When no row matches a freshness update, or the gated fit yields no
    prediction, the fit is repeated on the full ``history`` and the result is
    returned without the freshness-gating markers.

    Args:
        history: Frame indexed by ``test_time`` with ``result_signal`` and
            ``test_run_id`` columns.
        freshness_updates: Run identifiers (as strings) to keep for the gated
            fit.
        sensitivity: Forwarded to ``compute_sarimax_threshold``.
        num_forecast: Forwarded to ``compute_sarimax_threshold``.
        exclude_weekends: Forwarded to ``compute_sarimax_threshold``.
        holiday_codes: Forwarded to ``compute_sarimax_threshold``.
        schedule_tz: Forwarded to ``compute_sarimax_threshold``.

    Returns:
        ``(lower, upper, baseline_value, prediction_json)``. ``baseline_value``
        is ``None`` on the ungated path, and also when the most recent gated
        ``result_signal`` is missing.
    """
    fit_options = {
        "sensitivity": sensitivity,
        "num_forecast": num_forecast,
        "exclude_weekends": exclude_weekends,
        "holiday_codes": holiday_codes,
        "schedule_tz": schedule_tz,
    }

    gated_history = history.loc[history["test_run_id"].astype(str).isin(freshness_updates)]

    # With no matching rows there is nothing to fit; go straight to the
    # ungated fit instead of relying on the helper to reject an empty frame.
    if not gated_history.empty:
        lower, upper, prediction = compute_sarimax_threshold(gated_history, **fit_options)
        if prediction is not None:
            # Select by boolean mask rather than `.loc[label, column]`: the
            # label form returns a Series when the timestamp is duplicated,
            # which breaks the scalar NaN check below.
            is_latest = gated_history.index == gated_history.index.max()
            latest_signals = gated_history.loc[is_latest, "result_signal"]
            raw_baseline = latest_signals.iloc[-1] if len(latest_signals) else None
            baseline_value = None if pd.isna(raw_baseline) else float(raw_baseline)

            prediction_dict = json.loads(prediction)
            prediction_dict.update({
                "freshness_gated": True,
                "baseline_value": baseline_value,
            })
            return lower, upper, baseline_value, json.dumps(prediction_dict)

    lower, upper, prediction = compute_sarimax_threshold(history, **fit_options)
    return lower, upper, None, prediction
@dataclass
class CustomQueryResult:
    """Result of executing a wrapped custom-test SQL query."""

    # Value of COUNT(*) over the wrapped query.
    row_count: int
    # Rows fetched for preview; stays empty unless a preview was requested
    # and the count was positive.
    preview_rows: list[RowMapping] = field(default_factory=list)


def validate_custom_query(
    connection: Connection,
    schema: str,
    custom_sql: str,
    preview_limit: int = 0,
) -> CustomQueryResult:
    """Run a custom-test query, wrapped as a subquery, against the target DB.

    Args:
        connection: Target ``Connection`` the query is executed on.
        schema: Value substituted for ``{DATA_SCHEMA}`` in ``custom_sql``.
        custom_sql: User-supplied query; trailing whitespace and trailing
            semicolons are removed before wrapping.
        preview_limit: When greater than zero and the count is positive, up to
            this many rows are fetched as a preview.

    Returns:
        A ``CustomQueryResult`` with the row count and any preview rows.
        Database errors are not handled here and propagate to the caller.
    """
    resolved_sql = replace_params(custom_sql, {"DATA_SCHEMA": schema}).rstrip().rstrip(";")
    flavor_service = get_flavor_service(connection.sql_flavor)

    # Both statements select from the same wrapped subquery.
    wrapped_source = f"({resolved_sql}) ERR_TABLE"

    count_result = fetch_from_target_db(connection, f"SELECT COUNT(*) AS row_count FROM {wrapped_source}")
    total_rows = int(count_result[0]["row_count"]) if count_result else 0

    if preview_limit <= 0 or total_rows <= 0:
        return CustomQueryResult(row_count=total_rows, preview_rows=[])

    limit_prefix, limit_suffix = flavor_service.row_limit_clauses(preview_limit)
    preview_sql = f"SELECT {limit_prefix} * FROM {wrapped_source} {limit_suffix}".strip()
    return CustomQueryResult(
        row_count=total_rows,
        preview_rows=fetch_from_target_db(connection, preview_sql),
    )
def row_limit_clauses(self, n: int) -> tuple[str, str]:
    """Build the SQL fragments that cap a SELECT at ``n`` rows.

    Returns:
        ``(prefix, suffix)``: ``prefix`` goes right after ``SELECT`` and
        ``suffix`` at the end of the statement. Exactly one of the two is
        non-empty, chosen by ``self.row_limiting_clause`` (``"top"``,
        ``"fetch"``, anything else means ``LIMIT``).
    """
    clause_kind = self.row_limiting_clause
    if clause_kind == "top":
        return f"TOP {n}", ""
    suffix = f"FETCH FIRST {n} ROWS ONLY" if clause_kind == "fetch" else f"LIMIT {n}"
    return "", suffix
def get_table_ref(self, schema: str, table: str) -> str:
    """Return the quoted table reference to use in SQL for this flavor.

    The table name is always wrapped in ``self.quote_character``. The quoted
    schema is prepended with a dot unless
    ``self.qualifies_table_refs_with_schema`` is false.
    """
    quote = self.quote_character
    name_parts = [schema, table] if self.qualifies_table_refs_with_schema else [table]
    return ".".join(f"{quote}{part}{quote}" for part in name_parts)
class SalesforceData360FlavorService(FlavorService):
    """Flavor service for Salesforce Data 360.

    Column metadata comes from the connector's metadata API
    (``metadata_via_api = True``) and table references are not prefixed with
    a schema (``qualifies_table_refs_with_schema = False``).
    """

    concat_operator = "||"
    quote_character = '"'
    escaped_single_quote = "''"
    escaped_underscore = "\\_"
    escape_clause = ""
    varchar_type = "VARCHAR(1000)"
    default_uppercase = False
    test_query = "SELECT 1"
    url_scheme = "salesforce_data360"
    qualifies_table_refs_with_schema = False
    metadata_via_api = True

    def get_connection_string(self, _params: ResolvedConnectionParams) -> str:
        """Return the fixed URL; all settings travel through connect args."""
        return "salesforce_data360://"

    def get_connection_string_from_fields(self, _params: ResolvedConnectionParams) -> str:
        """Return the same fixed URL as ``get_connection_string``."""
        return "salesforce_data360://"

    def get_connect_args(self, params: ResolvedConnectionParams) -> dict:
        """Translate connection fields into salesforce-cdp-connector kwargs.

        Field mapping:
            host           -> login_url     (org My Domain URL)
            username       -> client_id     (Consumer Key from External Client App)
            password       -> client_secret (Client Credentials flow)
            dbname         -> username      (JWT Bearer flow)
            private_key    -> private_key   (JWT Bearer flow)
            connect_by_key -> True = JWT, False = Client Credentials
            dbschema       -> dataspace     (Data 360 Data Space; scopes the CDP token)
        """
        connect_args: dict[str, Any] = {
            "login_url": params.host,
            "client_id": params.username,
        }

        # Without a table group (e.g. Test Connection in the wizard) dbschema
        # is empty and the connector falls back to the org's default Data
        # Space, which is enough to check authentication. With a table group
        # the Data Space is supplied and the CDP token is restricted to it.
        if params.dbschema:
            connect_args["dataspace"] = params.dbschema

        use_jwt_flow = params.connect_by_key and params.private_key
        if use_jwt_flow:
            connect_args.update(username=params.dbname, private_key=params.private_key)
        else:
            connect_args["client_secret"] = params.password

        return connect_args

    def get_engine_args(self, _params: ResolvedConnectionParams) -> dict[str, Any]:
        """Return engine options: ``StaticPool`` with pre-ping disabled."""
        return {
            "pool_pre_ping": False,
            "poolclass": StaticPool,
        }

    def get_pre_connection_queries(self, _params: ResolvedConnectionParams) -> list[tuple[str, dict | None]]:
        """Return no pre-connection queries for this flavor."""
        return []

    def get_schema_columns(self, params: ResolvedConnectionParams, schema: str) -> list[ColumnChars] | None:
        """List columns through the salesforce-cdp-connector metadata API.

        Data 360 has no information_schema, so this stands in for the
        SQL-based schema discovery. Every returned column carries ``schema``
        as its schema name. Fields without a name are skipped but still
        consume an ordinal position. Types missing from ``_TYPE_MAP`` keep
        their lower-cased metadata string as ``column_type`` and get general
        type ``"X"``.
        """
        from salesforcecdpconnector.connection import SalesforceCDPConnection

        cdp_connection = SalesforceCDPConnection(**self.get_connect_args(params))
        try:
            tables = cdp_connection.list_tables()
        finally:
            cdp_connection.close()

        discovered: list[ColumnChars] = []
        for table in tables:
            for position, field in enumerate(table.fields, start=1):
                if not field.name:
                    continue

                raw_type = (field.type or "").upper()
                column_type, general_type, is_decimal = _TYPE_MAP.get(
                    raw_type,
                    (raw_type.lower(), "X", False),
                )
                discovered.append(ColumnChars(
                    schema_name=schema,
                    table_name=table.name,
                    column_name=field.name,
                    column_type=column_type,
                    db_data_type=raw_type,
                    ordinal_position=position,
                    general_type=general_type,
                    is_decimal=is_decimal,
                ))

        return discovered
def _format_oauth_failure(grant_label: str, response) -> str:
    """Build an error message from a failed OAuth token response.

    Reads ``error`` and ``error_description`` from the JSON body when the body
    is a JSON object. If the body is not valid JSON, is not an object, or has
    neither field, the first 300 characters of the raw text are used instead.
    An empty body adds no detail.

    Args:
        grant_label: Human-readable name of the OAuth flow (e.g. ``"JWT Bearer"``).
        response: Object exposing ``json()``, ``text`` and ``status_code``.

    Returns:
        ``"Salesforce <grant_label> authentication failed (HTTP <status>)"``
        followed by ``": <detail>"`` when any detail is available.
    """
    try:
        body = response.json()
    except ValueError:
        body = None

    # A JSON body need not be an object (it may be a list, string or null);
    # only read the OAuth fields when it is, so formatting never raises and
    # hides the authentication failure being reported.
    code = description = None
    if isinstance(body, dict):
        description = body.get("error_description")
        code = body.get("error")

    if description and code:
        detail = f": {code} — {description}"
    elif description:
        detail = f": {description}"
    elif code:
        detail = f": {code}"
    elif response.text:
        detail = f": {response.text[:300]}"
    else:
        detail = ""
    return f"Salesforce {grant_label} authentication failed (HTTP {response.status_code}){detail}"


def _token_by_jwt_bearer_flow(self, login_url, username, client_id, private_key):
    """Request a token with the JWT Bearer grant.

    Signs an RS256 assertion expiring in one hour and posts it to
    ``<login_url>/services/oauth2/token``. On HTTP 200 the result is passed to
    ``self._exchange_token``; any other status raises ``_CdpError`` with the
    message from ``_format_oauth_failure``.
    """
    payload = {
        "iss": client_id,
        "exp": int(time.time()) + 3600,
        "aud": login_url,
        "sub": username,
    }
    encoded = jwt.encode(payload, private_key, algorithm="RS256")
    params = {AUTH_PARAM_GRANT_TYPE: AUTH_PARAM_JWT_GRANT_TYPE, AUTH_PARAM_ASSERTION: encoded}
    response = self.session.post(url=login_url + "/services/oauth2/token", params=params)
    if response.status_code == 200:
        access_code = response.json()
        return self._exchange_token(access_code[AUTH_RESPONSE_INSTANCE_URL], access_code[AUTH_RESPONSE_ACCESS_TOKEN])
    raise _CdpError(_format_oauth_failure("JWT Bearer", response))


def _token_by_client_creds_flow(self, login_url, client_id, client_secret):
    """Request a token with the Client Credentials grant.

    Posts the client id and secret to ``<login_url>/services/oauth2/token``.
    On HTTP 200 the result is passed to ``self._exchange_token``; any other
    status raises ``_CdpError`` with the message from ``_format_oauth_failure``.
    """
    params = {
        AUTH_PARAM_GRANT_TYPE: AUTH_PARAM_CLIENT_CREDENTIALS_GRANT_TYPE,
        AUTH_PARAM_CLIENT_ID: client_id,
        AUTH_PARAM_CLIENT_SECRET: client_secret,
    }
    response = self.session.post(url=login_url + "/services/oauth2/token", params=params)
    if response.status_code == 200:
        access_code = response.json()
        return self._exchange_token(access_code[AUTH_RESPONSE_INSTANCE_URL], access_code[AUTH_RESPONSE_ACCESS_TOKEN])
    raise _CdpError(_format_oauth_failure("Client Credentials", response))
# Install the patched auth methods on the connector at import time. The stock
# methods build the same request but discard the response body on failure; the
# replacements keep Salesforce's ``error_description`` so the cause shows up in
# the Test Connection UI and in application logs.
_auth_helper.AuthenticationHelper._token_by_jwt_bearer_flow = _token_by_jwt_bearer_flow
_auth_helper.AuthenticationHelper._token_by_client_creds_flow = _token_by_client_creds_flow


class _DBAPIShim:
    """Stand-in DB-API module returned by the dialect's ``dbapi()`` / ``import_dbapi()``.

    ``connect(**kwargs)`` delegates to ``SalesforceCDPConnection``.
    """

    # Expose the connector's exception classes as attributes of the shim so
    # SQLAlchemy can catch errors through the usual ``dbapi.Error`` path.
    from salesforcecdpconnector.exceptions import (
        DatabaseError,
        Error,
        InterfaceError,
        InternalError,
        NotSupportedError,
        OperationalError,
        ProgrammingError,
    )

    paramstyle = "format"  # SQLAlchemy needs *some* value; we never actually bind params

    @staticmethod
    def connect(**kwargs):
        """Open a ``SalesforceCDPConnection`` whose cursors always carry
        ``rowcount`` and ``lastrowid`` attributes."""
        from salesforcecdpconnector.connection import SalesforceCDPConnection

        connection = SalesforceCDPConnection(**kwargs)
        make_cursor = connection.cursor

        def cursor_with_dbapi_attrs():
            new_cursor = make_cursor()
            # Fill in only the attributes the connector's cursor lacks.
            for attribute, fallback in (("rowcount", -1), ("lastrowid", None)):
                if not hasattr(new_cursor, attribute):
                    setattr(new_cursor, attribute, fallback)
            return new_cursor

        connection.cursor = cursor_with_dbapi_attrs
        return connection


class SalesforceData360Dialect(DefaultDialect):
    """SQLAlchemy dialect backed by ``_DBAPIShim``, with pinging and
    initialization introspection turned off."""

    name = "salesforce_data360"
    supports_alter = False
    supports_transactions = False
    supports_native_boolean = True
    supports_statement_cache = False
    supports_default_values = False
    supports_empty_insert = False
    postfetch_lastrowid = False
    implicit_returning = False

    @classmethod
    def dbapi(cls):
        """Return the shim module."""
        return _DBAPIShim

    @classmethod
    def import_dbapi(cls):
        """Return the shim module (same object as :meth:`dbapi`)."""
        return _DBAPIShim

    def create_connect_args(self, _url):
        """Return empty positional and keyword args.

        All auth parameters arrive via ``connect_args``; the URL is only the
        ``salesforce_data360://`` placeholder.
        """
        return ([], {})

    def do_ping(self, _dbapi_connection):
        """Report the connection as alive without contacting the service."""
        return True

    def initialize(self, connection):
        """Do nothing, skipping the server-version detection and other
        introspection that ``DefaultDialect.initialize()`` performs."""
Lifecycle states; see + ``job_execution.py`` for the transition rules.""" + PENDING = "pending" + CLAIMED = "claimed" + RUNNING = "running" + COMPLETED = "completed" + ERROR = "error" + CANCEL_REQUESTED = "cancel_requested" + CANCELED = "canceled" + + +class Disposition(StrEnum): + """Stored disposition values for ``profile_anomaly_results.disposition`` and + ``test_results.disposition``. The user-facing label for ``INACTIVE`` is "Muted".""" + CONFIRMED = "Confirmed" + DISMISSED = "Dismissed" + INACTIVE = "Inactive" + + +class IssueLikelihood(StrEnum): + """Stored ``profile_anomaly_types.issue_likelihood`` values.""" + DEFINITE = "Definite" + LIKELY = "Likely" + POSSIBLE = "Possible" + POTENTIAL_PII = "Potential PII" + + +class PiiRisk(StrEnum): + """Risk level extracted from PII issue ``detail`` strings via ``priority`` hybrid.""" + HIGH = "High" + MODERATE = "Moderate" diff --git a/testgen/common/freshness_service.py b/testgen/common/freshness_service.py index f7810787..b87e5a93 100644 --- a/testgen/common/freshness_service.py +++ b/testgen/common/freshness_service.py @@ -143,6 +143,25 @@ def get_schedule_params(prediction: dict | str | None) -> ScheduleParams: return ScheduleParams(excluded_days=excluded_days, window_start=window_start, window_end=window_end) +def get_freshness_gated_baseline(prediction: dict | str | None) -> float | None: + """Extract the freshness-gated baseline value from a Volume_Trend / Metric_Trend + prediction JSON. + + The baseline is the test value at the most recent detected freshness update. Returns + None when the prediction is missing, empty, does not have freshness-gating enabled, + or has no baseline value recorded. 
+ """ + if not prediction: + return None + parsed = prediction if isinstance(prediction, dict) else json.loads(prediction) + if not parsed.get("freshness_gated"): + return None + baseline_value = parsed.get("baseline_value") + if baseline_value is None: + return None + return float(baseline_value) + + def is_excluded_day( dt: pd.Timestamp, exclude_weekends: bool, diff --git a/testgen/common/job_context.py b/testgen/common/job_context.py index e711899e..8d9b2036 100644 --- a/testgen/common/job_context.py +++ b/testgen/common/job_context.py @@ -4,11 +4,13 @@ from dataclasses import dataclass from uuid import UUID +from testgen.common.enums import JobSource + @dataclass(frozen=True) class JobContext: job_id: UUID | None = None - source: str = "CLI" + source: JobSource = JobSource.cli job_context: contextvars.ContextVar[JobContext] = contextvars.ContextVar("job_context", default=JobContext()) diff --git a/testgen/common/mixpanel_service.py b/testgen/common/mixpanel_service.py index dba5c74f..1af2ae4c 100644 --- a/testgen/common/mixpanel_service.py +++ b/testgen/common/mixpanel_service.py @@ -54,6 +54,17 @@ def _hash_value(self, value: bytes | str, digest_size: int = 8) -> str: @safe_method def send_event(self, event_name, include_usage=False, **properties): + self._track(event_name, include_usage=include_usage, **properties) + + def send_feedback(self, **properties): + # User-submitted feedback is content the user explicitly chose to share + # so it is not gated by the TG_ANALYTICS opt-out. 
+ try: + self._track("feedback", **properties) + except Exception: + LOG.exception("Error sending feedback") + + def _track(self, event_name, include_usage=False, **properties): properties.setdefault("instance_id", self.instance_id) properties.setdefault("edition", settings.DOCKER_HUB_REPOSITORY) properties.setdefault("version", settings.VERSION) diff --git a/testgen/common/models/__init__.py b/testgen/common/models/__init__.py index 4fe7211f..023f6478 100644 --- a/testgen/common/models/__init__.py +++ b/testgen/common/models/__init__.py @@ -4,7 +4,7 @@ import threading import urllib.parse -from sqlalchemy import create_engine +from sqlalchemy import create_engine, delete from sqlalchemy.orm import DeclarativeBase, sessionmaker from sqlalchemy.orm import Session as SQLAlchemySession @@ -30,6 +30,14 @@ class Base(DeclarativeBase): # Can be removed once all models use Mapped[] annotations. __allow_unmapped__ = True + @classmethod + def delete_where(cls, *clauses) -> int: + """Single-statement DELETE on this model filtered by ``clauses``; + returns the row count. Callers may ignore the return when not needed. 
+ """ + result = get_current_session().execute(delete(cls).where(*clauses)) + return result.rowcount or 0 + Session = sessionmaker( engine, expire_on_commit=False, diff --git a/testgen/common/models/connection.py b/testgen/common/models/connection.py index dfb36e71..a940edba 100644 --- a/testgen/common/models/connection.py +++ b/testgen/common/models/connection.py @@ -4,7 +4,6 @@ from urllib.parse import parse_qs, urlparse from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import ( BigInteger, Boolean, @@ -23,17 +22,18 @@ from testgen.common.database.flavor.flavor_service import SQLFlavor from testgen.common.models import get_current_session from testgen.common.models.custom_types import JSON_TYPE, EncryptedBytea, EncryptedJson -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.table_group import TableGroup from testgen.utils import is_uuid4 -SQLFlavorCode = Literal["redshift", "redshift_spectrum", "snowflake", "mssql", "azure_mssql", "synapse_mssql", "postgresql", "databricks", "bigquery", "oracle", "sap_hana"] +SQLFlavorCode = Literal["redshift", "redshift_spectrum", "snowflake", "mssql", "azure_mssql", "synapse_mssql", "postgresql", "databricks", "bigquery", "oracle", "sap_hana", "salesforce_data360"] @dataclass class ConnectionMinimal(EntityMinimal): project_code: str connection_id: int + sql_flavor: SQLFlavor sql_flavor_code: SQLFlavorCode connection_name: str @@ -69,7 +69,6 @@ class Connection(Entity): _minimal_columns = ConnectionMinimal.__annotations__.keys() @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, identifier: int) -> ConnectionMinimal | None: result = cls._get_columns(identifier, cls._minimal_columns) return ConnectionMinimal(**result) if result else None @@ -83,7 +82,6 @@ def get_by_table_group(cls, table_group_id: str | UUID) -> Self | None: return 
get_current_session().scalars(query).first() @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[ConnectionMinimal]: diff --git a/testgen/common/models/data_column.py b/testgen/common/models/data_column.py index 0280a28b..cee7d088 100644 --- a/testgen/common/models/data_column.py +++ b/testgen/common/models/data_column.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from datetime import datetime +from enum import StrEnum from uuid import UUID, uuid4 from sqlalchemy import ( @@ -12,14 +13,117 @@ and_, asc, case, + desc, func, select, ) from sqlalchemy.dialects import postgresql +from testgen.common.models import get_current_session from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.hygiene_issue import HygieneIssue from testgen.common.models.profile_result import ProfileResult +from testgen.common.models.profiling_run import ProfilingRun + + +class GeneralType(StrEnum): + """User-facing word values for the column ``general_type``.""" + + ALPHA = "Alpha" + NUMERIC = "Numeric" + DATETIME = "Datetime" + BOOLEAN = "Boolean" + TIME = "Time" + OTHER = "Other" + + +# Translates the user-facing words to the single-letter codes stored on +# ``data_column_chars.general_type`` for WHERE-clause matching. 
+GENERAL_TYPE_TO_CODE: dict[GeneralType, str] = { + GeneralType.ALPHA: "A", + GeneralType.NUMERIC: "N", + GeneralType.DATETIME: "D", + GeneralType.BOOLEAN: "B", + GeneralType.TIME: "T", + GeneralType.OTHER: "X", +} + + +class SuggestedDataType(StrEnum): + """Values accepted for the ``suggested_data_type`` argument.""" + + ANY = "Any" + SMALLINT = "Smallint" + INTEGER = "Integer" + BIGINT = "Bigint" + DECIMAL = "Decimal" + NUMERIC = "Numeric" + VARCHAR = "Varchar" + DATE = "Date" + TIMESTAMP = "Timestamp" + BOOLEAN = "Boolean" + + +# Maps the user-facing word to the SQL-type prefix matched against +# ``datatype_suggestion`` (``Any`` is a sentinel — no prefix, just non-null check). +SUGGESTED_DATA_TYPE_TO_PREFIX: dict[SuggestedDataType, str | None] = { + SuggestedDataType.ANY: None, + SuggestedDataType.SMALLINT: "SMALLINT", + SuggestedDataType.INTEGER: "INTEGER", + SuggestedDataType.BIGINT: "BIGINT", + SuggestedDataType.DECIMAL: "DECIMAL", + SuggestedDataType.NUMERIC: "NUMERIC", + SuggestedDataType.VARCHAR: "VARCHAR", + SuggestedDataType.DATE: "DATE", + SuggestedDataType.TIMESTAMP: "TIMESTAMP", + SuggestedDataType.BOOLEAN: "BOOLEAN", +} + + +class ColumnOrderBy(StrEnum): + """Values accepted for the ``order_by`` argument on column profile listings.""" + + NULL_RATIO = "Null Ratio" + DISTINCT_RATIO = "Distinct Ratio" + FILLED_RATIO = "Filled Ratio" + SCORE_PROFILING = "Profiling Score" + SCORE_TESTING = "Testing Score" + HYGIENE_COUNT = "Hygiene Count" + + +class ProfileMetric(StrEnum): + """Profile-metric vocabulary: linear/arithmetic stats from a profiling run. + + Covers general column ratios (null / distinct / filled), type-specific + statistics (length, numeric range, date range, true count), table-level + row count, and table-group rollups (profiling score, hygiene issues). + + Labels align with the field names in ``column_profile_fields_resource``. 
+ """ + + # Apply to any column + NULL_RATIO = "Null Ratio" + DISTINCT_RATIO = "Distinct Ratio" + FILLED_RATIO = "Filled Ratio" + # Apply to the parent table + RECORD_COUNT = "Row Count" + # Apply to the whole table group + PROFILING_SCORE = "Profiling Score" + HYGIENE_COUNT = "Hygiene Issues" + # Alpha-only + MIN_LENGTH = "Minimum Length" + MAX_LENGTH = "Maximum Length" + AVG_LENGTH = "Average Length" + # Numeric-only + MIN = "Minimum Value" + MAX = "Maximum Value" + AVG = "Average Value" + STDEV = "Standard Deviation" + # Date-only + MIN_DATE = "Minimum Date" + MAX_DATE = "Maximum Date" + # Boolean-only + TRUE_COUNT = "True Count" @dataclass @@ -40,6 +144,98 @@ class ColumnProfileSummary(EntityMinimal): hygiene_issue_count: int +@dataclass +class ColumnProfileDetail(EntityMinimal): + """L2 column profiling detail — header fields plus type-specific stats and run identity.""" + + # Identity + column_name: str + table_name: str + schema_name: str | None + # Types & metadata + general_type: str | None + column_type: str | None + db_data_type: str | None + functional_data_type: str | None + datatype_suggestion: str | None + functional_table_type: str | None + pii_flag: str | None + critical_data_element: bool | None + # Counts + record_ct: int | None + value_ct: int | None + distinct_value_ct: int | None + null_value_ct: int | None + filled_value_ct: int | None + zero_value_ct: int | None + # Alpha + min_length: int | None + max_length: int | None + avg_length: float | None + min_text: str | None + max_text: str | None + top_freq_values: str | None + top_patterns: str | None + distinct_std_value_ct: int | None + distinct_pattern_ct: int | None + std_pattern_match: str | None + mixed_case_ct: int | None + lower_case_ct: int | None + upper_case_ct: int | None + non_alpha_ct: int | None + includes_digit_ct: int | None + numeric_ct: int | None + date_ct: int | None + quoted_value_ct: int | None + lead_space_ct: int | None + embedded_space_ct: int | None + 
avg_embedded_spaces: float | None + zero_length_ct: int | None + # Numeric + min_value: float | None + min_value_over_0: float | None + max_value: float | None + avg_value: float | None + stdev_value: float | None + percentile_25: float | None + percentile_50: float | None + percentile_75: float | None + # Date + min_date: datetime | None + max_date: datetime | None + before_1yr_date_ct: int | None + before_5yr_date_ct: int | None + before_20yr_date_ct: int | None + within_1yr_date_ct: int | None + within_1mo_date_ct: int | None + future_date_ct: int | None + # Boolean + boolean_true_ct: int | None + # Per-column profiling failure + query_error: str | None + # Scores & hygiene + dq_score_profiling: float | None + dq_score_testing: float | None + hygiene_issue_count: int + # Run identity + profile_run_id: UUID | None + profile_run_je_id: UUID | None + profile_run_status: str | None + profile_run_started_at: datetime | None + profile_run_ended_at: datetime | None + profile_run_log_message: str | None + + +@dataclass +class ColumnSearchHit(EntityMinimal): + project_code: str + table_groups_id: UUID + table_groups_name: str + schema_name: str | None + table_name: str + column_name: str + + class DataColumnChars(Entity): __tablename__ = "data_column_chars" @@ -78,6 +274,7 @@ def list_for_table_group( *clauses, table_groups_id: UUID, profiling_run_id: UUID | None = None, + order_by: ColumnOrderBy | None = None, page: int, limit: int, ) -> tuple[list[ColumnProfileSummary], int]: @@ -162,7 +359,246 @@ def list_for_table_group( cls.drop_date.is_(None), *clauses, ) - .order_by(asc(cls.table_name), asc(cls.ordinal_position), asc(cls.column_name)) ) + null_ratio_expr = ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + distinct_ratio_expr = ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + filled_ratio_expr = ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + # Deterministic tiebreaker so 
paginated callers don't see rows skip or duplicate + # across pages when the primary sort has ties. + tiebreaker = (asc(cls.table_name), asc(cls.ordinal_position), asc(cls.column_name)) + order_exprs: tuple + if order_by is ColumnOrderBy.NULL_RATIO: + order_exprs = (desc(null_ratio_expr).nulls_last(), *tiebreaker) + elif order_by is ColumnOrderBy.DISTINCT_RATIO: + order_exprs = (asc(distinct_ratio_expr).nulls_last(), *tiebreaker) + elif order_by is ColumnOrderBy.FILLED_RATIO: + order_exprs = (desc(filled_ratio_expr).nulls_last(), *tiebreaker) + elif order_by is ColumnOrderBy.SCORE_PROFILING: + order_exprs = (asc(cls.dq_score_profiling).nulls_last(), *tiebreaker) + elif order_by is ColumnOrderBy.SCORE_TESTING: + order_exprs = (asc(cls.dq_score_testing).nulls_last(), *tiebreaker) + elif order_by is ColumnOrderBy.HYGIENE_COUNT: + order_exprs = (desc(func.coalesce(hygiene_subq.c.hygiene_issue_count, 0)), *tiebreaker) + else: + order_exprs = tiebreaker + + query = query.order_by(*order_exprs) + return cls._paginate(query, page=page, limit=limit, data_class=ColumnProfileSummary) + + @classmethod + def get_column_detail( + cls, + table_groups_id: UUID, + table_name: str, + column_name: str, + profiling_run_id: UUID | None = None, + ) -> ColumnProfileDetail | None: + """Fetch the L2 profile detail for a single column. + + When ``profiling_run_id`` is None, joins on the column's + ``last_complete_profile_run_id`` so the caller gets the latest run. + Returns None when the column does not exist in the table group. 
+ """ + from testgen.common.models.data_table import DataTable + + profile_run_filter = ( + ProfileResult.profile_run_id == profiling_run_id + if profiling_run_id is not None + else ProfileResult.profile_run_id == cls.last_complete_profile_run_id + ) + + hygiene_subq = ( + select( + HygieneIssue.profile_run_id.label("profile_run_id"), + HygieneIssue.schema_name.label("schema_name"), + HygieneIssue.table_name.label("table_name"), + HygieneIssue.column_name.label("column_name"), + func.count().label("hygiene_issue_count"), + ) + .where( + HygieneIssue.table_groups_id == table_groups_id, + func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + ) + .group_by( + HygieneIssue.profile_run_id, + HygieneIssue.schema_name, + HygieneIssue.table_name, + HygieneIssue.column_name, + ) + .subquery() + ) + + cde_coalesced = case( + (cls.critical_data_element.is_(True), True), + (DataTable.critical_data_element.is_(True), True), + else_=False, + ).label("critical_data_element") + + query = ( + select( + cls.column_name, + cls.table_name, + cls.schema_name, + cls.general_type, + ProfileResult.column_type, + cls.db_data_type, + cls.functional_data_type, + ProfileResult.datatype_suggestion, + ProfileResult.functional_table_type, + cls.pii_flag, + cde_coalesced, + ProfileResult.record_ct, + ProfileResult.value_ct, + ProfileResult.distinct_value_ct, + ProfileResult.null_value_ct, + ProfileResult.filled_value_ct, + ProfileResult.zero_value_ct, + ProfileResult.min_length, + ProfileResult.max_length, + ProfileResult.avg_length, + ProfileResult.min_text, + ProfileResult.max_text, + ProfileResult.top_freq_values, + ProfileResult.top_patterns, + ProfileResult.distinct_std_value_ct, + ProfileResult.distinct_pattern_ct, + ProfileResult.std_pattern_match, + ProfileResult.mixed_case_ct, + ProfileResult.lower_case_ct, + ProfileResult.upper_case_ct, + ProfileResult.non_alpha_ct, + ProfileResult.includes_digit_ct, + ProfileResult.numeric_ct, + ProfileResult.date_ct, + 
ProfileResult.quoted_value_ct, + ProfileResult.lead_space_ct, + ProfileResult.embedded_space_ct, + ProfileResult.avg_embedded_spaces, + ProfileResult.zero_length_ct, + ProfileResult.min_value, + ProfileResult.min_value_over_0, + ProfileResult.max_value, + ProfileResult.avg_value, + ProfileResult.stdev_value, + ProfileResult.percentile_25, + ProfileResult.percentile_50, + ProfileResult.percentile_75, + ProfileResult.min_date, + ProfileResult.max_date, + ProfileResult.before_1yr_date_ct, + ProfileResult.before_5yr_date_ct, + ProfileResult.before_20yr_date_ct, + ProfileResult.within_1yr_date_ct, + ProfileResult.within_1mo_date_ct, + ProfileResult.future_date_ct, + ProfileResult.boolean_true_ct, + ProfileResult.query_error, + cls.dq_score_profiling, + cls.dq_score_testing, + func.coalesce(hygiene_subq.c.hygiene_issue_count, 0).label("hygiene_issue_count"), + ProfilingRun.id.label("profile_run_id"), + ProfilingRun.job_execution_id.label("profile_run_je_id"), + ProfilingRun.status.label("profile_run_status"), + ProfilingRun.profiling_starttime.label("profile_run_started_at"), + ProfilingRun.profiling_endtime.label("profile_run_ended_at"), + ProfilingRun.log_message.label("profile_run_log_message"), + ) + .outerjoin(DataTable, DataTable.id == cls.table_id) + .outerjoin( + ProfileResult, + and_( + profile_run_filter, + ProfileResult.schema_name == cls.schema_name, + ProfileResult.table_name == cls.table_name, + ProfileResult.column_name == cls.column_name, + ), + ) + .outerjoin( + hygiene_subq, + and_( + hygiene_subq.c.profile_run_id == ProfileResult.profile_run_id, + hygiene_subq.c.schema_name == cls.schema_name, + hygiene_subq.c.table_name == cls.table_name, + hygiene_subq.c.column_name == cls.column_name, + ), + ) + .outerjoin(ProfilingRun, ProfilingRun.id == ProfileResult.profile_run_id) + .where( + cls.table_groups_id == table_groups_id, + cls.table_name == table_name, + cls.column_name == column_name, + cls.drop_date.is_(None), + ) + .limit(1) + ) + + row = 
get_current_session().execute(query).mappings().first() + return ColumnProfileDetail(**row) if row else None + + @classmethod + def search_by_name( + cls, + *clauses, + pattern: str, + page: int, + limit: int, + ) -> tuple[list[ColumnSearchHit], int]: + """Cross-table-group column-name search. Scoping clauses are passed in by the caller. + + ``pattern`` is matched with ``ILIKE``. Callers are expected to pre-wrap bare + tokens with ``%`` if substring search is desired; literal ``%`` / ``_`` from + the caller are honored as wildcards. + """ + # Local import: avoid circular dependency with TableGroup. + from testgen.common.models.table_group import TableGroup + + query = ( + select( + TableGroup.project_code, + TableGroup.id.label("table_groups_id"), + TableGroup.table_groups_name, + cls.schema_name, + cls.table_name, + cls.column_name, + ) + .join(TableGroup, TableGroup.id == cls.table_groups_id) + .where( + cls.column_name.ilike(pattern, escape="\\"), + cls.drop_date.is_(None), + *clauses, + ) + .order_by( + asc(TableGroup.project_code), + asc(TableGroup.table_groups_name), + asc(cls.table_name), + asc(cls.column_name), + ) + ) + + return cls._paginate(query, page=page, limit=limit, data_class=ColumnSearchHit) + + @classmethod + def summarize_matches_by_project( + cls, + *clauses, + pattern: str, + ) -> list[tuple[str, int]]: + """Per-project match counts for a column-name search — same WHERE shape as :meth:`search_by_name`.""" + # Local import: avoid circular dependency with TableGroup. 
+ from testgen.common.models.table_group import TableGroup + + query = ( + select(TableGroup.project_code, func.count().label("match_count")) + .select_from(cls) + .join(TableGroup, TableGroup.id == cls.table_groups_id) + .where( + cls.column_name.ilike(pattern, escape="\\"), + cls.drop_date.is_(None), + *clauses, + ) + .group_by(TableGroup.project_code) + .order_by(TableGroup.project_code) + ) + return [(row.project_code, row.match_count) for row in get_current_session().execute(query).all()] diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py index 8f055bda..248d0d2c 100644 --- a/testgen/common/models/entity.py +++ b/testgen/common/models/entity.py @@ -3,8 +3,7 @@ from typing import Any, Self from uuid import UUID -import streamlit as st -from sqlalchemy import delete, func, select +from sqlalchemy import func, select from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList @@ -46,7 +45,6 @@ class Entity(Base): _default_order_by: tuple[str | InstrumentedAttribute] = ("id",) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def get(cls, identifier: str | int | UUID, *clauses) -> Self | None: """Fetch by primary key, optionally narrowed by extra WHERE clauses. 
@@ -89,7 +87,6 @@ def _get_columns( return get_current_session().execute(query).mappings().first() @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | None = None) -> Iterable[Self]: order_by = order_by or cls._default_order_by query = select(cls).where(*clauses).order_by(*order_by) @@ -150,12 +147,6 @@ def _paginate( def has_running_process(cls, ids: list[str]) -> bool: raise NotImplementedError - @classmethod - def delete_where(cls, *clauses) -> None: - query = delete(cls).where(*clauses) - db_session = get_current_session() - db_session.execute(query) - @classmethod def is_in_use(cls, ids: list[str]) -> bool: raise NotImplementedError diff --git a/testgen/common/models/hygiene_issue.py b/testgen/common/models/hygiene_issue.py index 497c8180..c7683479 100644 --- a/testgen/common/models/hygiene_issue.py +++ b/testgen/common/models/hygiene_issue.py @@ -2,7 +2,6 @@ from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime -from enum import StrEnum from typing import Self from uuid import UUID, uuid4 @@ -12,6 +11,7 @@ from sqlalchemy.orm import aliased, relationship from sqlalchemy.sql.functions import func +from testgen.common.enums import Disposition from testgen.common.models import Base, get_current_session from testgen.common.models.entity import Entity from testgen.common.models.job_execution import JobExecution @@ -22,28 +22,6 @@ PII_RISK_RE = re.compile(r"Risk: (MODERATE|HIGH),") -class Disposition(StrEnum): - """Stored disposition values for ``profile_anomaly_results.disposition`` and - ``test_results.disposition``. 
The user-facing label for ``INACTIVE`` is "Muted".""" - CONFIRMED = "Confirmed" - DISMISSED = "Dismissed" - INACTIVE = "Inactive" - - -class IssueLikelihood(StrEnum): - """Stored ``profile_anomaly_types.issue_likelihood`` values.""" - DEFINITE = "Definite" - LIKELY = "Likely" - POSSIBLE = "Possible" - POTENTIAL_PII = "Potential PII" - - -class PiiRisk(StrEnum): - """Risk level extracted from PII issue ``detail`` strings via ``priority`` hybrid.""" - HIGH = "High" - MODERATE = "Moderate" - - @dataclass class IssueLikelihoodCounts: """Counts of hygiene issues by likelihood category, with dismissed/inactive separated.""" diff --git a/testgen/common/models/job_execution.py b/testgen/common/models/job_execution.py index 49aa67b9..8d8160c2 100644 --- a/testgen/common/models/job_execution.py +++ b/testgen/common/models/job_execution.py @@ -1,25 +1,24 @@ import logging from datetime import UTC, datetime -from enum import StrEnum -from typing import Any, Self +from typing import Any, ClassVar, Self from uuid import UUID, uuid4 -from sqlalchemy import Column, String, Text, case, func, select, text, update +from sqlalchemy import Column, String, Text, case, delete, func, select, text, update from sqlalchemy.dialects import postgresql -from testgen.common.models import Base, get_current_session +from testgen.common.enums import JobKey, JobSource, JobStatus +from testgen.common.models import Base, database_session, get_current_session LOG = logging.getLogger("testgen") -class JobStatus(StrEnum): - PENDING = "pending" - CLAIMED = "claimed" - RUNNING = "running" - COMPLETED = "completed" - ERROR = "error" - CANCEL_REQUESTED = "cancel_requested" - CANCELED = "canceled" +# Job kinds that are externally triggerable. Internal kinds (run-score-update, +# recalculate-project-scores, ...) are absent and filtered out by construction. 
+PUBLIC_JOB_KEYS: frozenset[JobKey] = frozenset({ + JobKey.run_profile, + JobKey.run_tests, + JobKey.run_test_generation, +}) _VALID_TRANSITIONS: dict[JobStatus, frozenset[JobStatus]] = { @@ -36,9 +35,8 @@ class JobExecution(Base): id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) job_key: str = Column(String(100), nullable=False) - # args and kwargs are internal dispatch details passed to the job handler. - # Do not query or filter on them — external code should not depend on their structure. - args: list[Any] = Column(postgresql.JSONB, nullable=False, default=list, server_default=text("'[]'::jsonb")) + # kwargs is the internal dispatch payload passed to the job handler. + # Do not query or filter on it — external code should not depend on its structure. kwargs: dict[str, Any] = Column(postgresql.JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb")) source: str = Column(String(20), nullable=False) status: str = Column(String(20), nullable=False, default=JobStatus.PENDING, server_default=text("'pending'")) @@ -55,7 +53,7 @@ def submit( cls, job_key: str, kwargs: dict[str, Any], - source: str, + source: JobSource, project_code: str, job_schedule_id: UUID | None = None, ) -> Self: @@ -101,6 +99,39 @@ def claim_actionable(cls, limit: int = 5) -> list[Self]: LOG.info("Claimed %d pending job execution(s)", claimed) return rows + _ACTIVE_STATUSES: ClassVar[list[JobStatus]] = [ + JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED, + ] + + @classmethod + def select_active_by_kwargs( + cls, + project_code: str, + job_key: str, + kwargs_match: dict[str, str | list[str]], + statuses: list[JobStatus] | None = None, + ) -> list[Self]: + """Find JE rows whose ``kwargs`` JSONB matches the given (key, value) pairs. + + Values may be a single string or a list of strings (which becomes an ``IN`` filter). + Defaults to active (non-terminal) statuses. 
+ """ + statuses = statuses or cls._ACTIVE_STATUSES + query = select(cls).where( + cls.project_code == project_code, + cls.job_key == job_key, + cls.status.in_(statuses), + ) + for k, v in kwargs_match.items(): + if isinstance(v, list): + if not v: + return [] + query = query.where(cls.kwargs[k].astext.in_([str(x) for x in v])) + else: + query = query.where(cls.kwargs[k].astext == str(v)) + query = query.order_by(cls.created_at.desc()) + return list(get_current_session().scalars(query).all()) + @classmethod def find_stale(cls) -> list[Self]: """Return job executions left in non-terminal states from a previous process.""" @@ -117,6 +148,42 @@ def get(cls, execution_id: UUID) -> Self | None: session = get_current_session() return session.get(cls, execution_id) + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_ids: set[UUID], + batch_size: int = 1000, + ) -> int: + """Batched delete of terminal-state job executions older than cutoff for + the given project, excluding protected ids. Returns total rows deleted. + + Skips rows in non-terminal states (pending/claimed/running/cancel_requested) — + those represent live work and must not be removed regardless of age. + + Each batch runs in its own transaction (committed before the next batch + is selected), so locks on job_executions are released between batches + and WAL growth stays bounded for large sweeps. 
+ """ + where_clauses = [ + cls.project_code == project_code, + cls.completed_at < cutoff, + cls.status.in_([JobStatus.COMPLETED, JobStatus.ERROR, JobStatus.CANCELED]), + ] + if protected_ids: + where_clauses.append(cls.id.notin_(protected_ids)) + + total = 0 + while True: + with database_session() as session: + ids = session.scalars(select(cls.id).where(*where_clauses).limit(batch_size)).all() + if not ids: + break + session.execute(delete(cls).where(cls.id.in_(ids))) + total += len(ids) + return total + @classmethod def list_for_project( cls, diff --git a/testgen/common/models/notification_settings.py b/testgen/common/models/notification_settings.py index e15349f4..90063422 100644 --- a/testgen/common/models/notification_settings.py +++ b/testgen/common/models/notification_settings.py @@ -1,6 +1,7 @@ import enum import re from collections.abc import Iterable +from dataclasses import dataclass from decimal import Decimal from typing import ClassVar, Generic, Self, TypeVar from uuid import UUID, uuid4 @@ -8,6 +9,7 @@ from sqlalchemy import Boolean, Column, Enum, ForeignKey, String, and_, or_, select from sqlalchemy.dialects import postgresql from sqlalchemy.sql import Select +from sqlalchemy.sql.elements import ColumnElement from testgen.common.models import get_current_session from testgen.common.models.custom_types import JSON_TYPE @@ -22,6 +24,17 @@ TriggerT = TypeVar("TriggerT", bound=Enum) +_EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + + +def is_valid_email(value: str) -> bool: + """Return whether ``value`` is a well-formed email address. + + Single source of truth for recipient validation, shared by the model's + ``validate()`` and the MCP layer's batch recipient check. 
+ """ + return bool(_EMAIL_REGEX.match(value)) + class TestRunNotificationTrigger(enum.Enum): always = "always" @@ -51,6 +64,27 @@ class NotificationSettingsValidationError(Exception): pass +@dataclass +class NotificationSummary: + """Row shape for paginated ``NotificationSettings.list_for_*`` queries. + + Field order matches the SELECT projection in the ``list_for_*`` methods. + ``settings`` keeps the raw JSONB blob so event-specific values (``trigger``, + ``total_threshold``, ``cde_threshold``, ``table_name``) can be read by the + consumer's format helpers without forking the dataclass per event type. + """ + + id: UUID + project_code: str + event: NotificationEvent + enabled: bool + recipients: list[str] + test_suite_id: UUID | None + table_group_id: UUID | None + score_definition_id: UUID | None + settings: dict + + class NotificationSettings(Entity): __tablename__ = "notification_settings" @@ -87,6 +121,17 @@ class NotificationSettings(Entity): "polymorphic_identity": "base", } + @classmethod + def _scope_subquery(cls, entity, rel_col, id_value) -> ColumnElement[bool]: + """Where-clause: rows scoped to ``entity.id == id_value`` plus project-wide rows + (``rel_col IS NULL``) for the same project. Used by ``_base_select_query`` (and so by + ``select()``); the paginated ``list_for_*`` methods filter on exact id matches instead.
+ """ + return and_( + cls.project_code.in_(select(entity.project_code).where(entity.id == id_value)), + or_(rel_col == id_value, rel_col.is_(None)), + ) + @classmethod def _base_select_query( cls, @@ -94,9 +139,9 @@ def _base_select_query( enabled: bool | SENTINEL_TYPE = SENTINEL, event: NotificationEvent | SENTINEL_TYPE = SENTINEL, project_code: str | SENTINEL_TYPE = SENTINEL, - test_suite_id: UUID | None | SENTINEL_TYPE = SENTINEL, - table_group_id: UUID | None | SENTINEL_TYPE = SENTINEL, - score_definition_id: UUID | None | SENTINEL_TYPE = SENTINEL, + test_suite_id: UUID | SENTINEL_TYPE | None = SENTINEL, + table_group_id: UUID | SENTINEL_TYPE | None = SENTINEL, + score_definition_id: UUID | SENTINEL_TYPE | None = SENTINEL, ) -> Select: fk_count = len([None for fk in (test_suite_id, table_group_id, score_definition_id) if fk is not SENTINEL]) if fk_count > 1: @@ -112,18 +157,12 @@ def _base_select_query( if project_code is not SENTINEL: query = query.where(cls.project_code == project_code) - def _subquery_clauses(entity, rel_col, id_value): - return and_( - cls.project_code.in_(select(entity.project_code).where(entity.id == id_value)), - or_(rel_col == id_value, rel_col.is_(None)), - ) - if test_suite_id is not SENTINEL: - query = query.where(_subquery_clauses(TestSuite, cls.test_suite_id, test_suite_id)) + query = query.where(cls._scope_subquery(TestSuite, cls.test_suite_id, test_suite_id)) elif table_group_id is not SENTINEL: - query = query.where(_subquery_clauses(TableGroup, cls.table_group_id, table_group_id)) + query = query.where(cls._scope_subquery(TableGroup, cls.table_group_id, table_group_id)) elif score_definition_id is not SENTINEL: - query = query.where(_subquery_clauses(ScoreDefinition, cls.score_definition_id, score_definition_id)) + query = query.where(cls._scope_subquery(ScoreDefinition, cls.score_definition_id, score_definition_id)) return query @@ -134,9 +173,9 @@ def select( enabled: bool | SENTINEL_TYPE = SENTINEL, event: NotificationEvent 
| SENTINEL_TYPE = SENTINEL, project_code: str | SENTINEL_TYPE = SENTINEL, - test_suite_id: UUID | None | SENTINEL_TYPE = SENTINEL, - table_group_id: UUID | None | SENTINEL_TYPE = SENTINEL, - score_definition_id: UUID | None | SENTINEL_TYPE = SENTINEL, + test_suite_id: UUID | SENTINEL_TYPE | None = SENTINEL, + table_group_id: UUID | SENTINEL_TYPE | None = SENTINEL, + score_definition_id: UUID | SENTINEL_TYPE | None = SENTINEL, ) -> Iterable[Self]: query = cls._base_select_query( enabled=enabled, @@ -150,6 +189,102 @@ def select( ) return get_current_session().scalars(query) + @classmethod + def _list_query(cls, scope_clause) -> Select: + """Projection + ORDER BY shared by every ``list_for_*`` classmethod. + + ``scope_clause`` is the WHERE expression that narrows to a set of projects or to + one parent entity (exact id match). Caller-supplied filters arrive as + ``*clauses`` in each ``list_for_*`` wrapper, which appends them to this query. + """ + return ( + select( + cls.id.label("id"), + cls.project_code.label("project_code"), + cls.event.label("event"), + cls.enabled.label("enabled"), + cls.recipients.label("recipients"), + cls.test_suite_id.label("test_suite_id"), + cls.table_group_id.label("table_group_id"), + cls.score_definition_id.label("score_definition_id"), + cls.settings.label("settings"), + ) + .where(scope_clause) + .order_by( + cls.project_code, cls.event, cls.test_suite_id, + cls.table_group_id, cls.score_definition_id, cls.id, + ) + ) + + @classmethod + def list_for_projects( + cls, + project_codes: Iterable[str], + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications across one or more projects.""" + query = cls._list_query(cls.project_code.in_(list(project_codes))) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + + @classmethod + def list_for_test_suite( + cls, + test_suite_id: UUID, + *clauses, + page: int
= 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications whose ``test_suite_id`` exactly matches ``test_suite_id``. + + Use ``list_for_projects`` to also surface project-wide notifications (rows + with ``test_suite_id IS NULL``) — they're a different display concern from + narrowing to a specific suite. + """ + query = cls._list_query(cls.test_suite_id == test_suite_id) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + + @classmethod + def list_for_table_group( + cls, + table_group_id: UUID, + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications whose ``table_group_id`` exactly matches ``table_group_id``. + + Use ``list_for_projects`` to also surface project-wide notifications (rows + with ``table_group_id IS NULL``). + """ + query = cls._list_query(cls.table_group_id == table_group_id) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + + @classmethod + def list_for_score_definition( + cls, + score_definition_id: UUID, + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications whose ``score_definition_id`` exactly matches ``score_definition_id``. + + Use ``list_for_projects`` to also surface project-wide notifications (rows + with ``score_definition_id IS NULL``). 
+ """ + query = cls._list_query(cls.score_definition_id == score_definition_id) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + def _validate_settings(self): pass @@ -157,7 +292,7 @@ def validate(self): if len(self.recipients) < 1: raise NotificationSettingsValidationError("At least one recipient must be defined.") for addr in self.recipients: - if not re.match(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", addr): + if not is_valid_email(addr): raise NotificationSettingsValidationError(f"Invalid email address: {addr}.") self._validate_settings() diff --git a/testgen/common/models/profile_result.py b/testgen/common/models/profile_result.py index 5826e63c..046cb015 100644 --- a/testgen/common/models/profile_result.py +++ b/testgen/common/models/profile_result.py @@ -1,6 +1,8 @@ +from collections.abc import Iterable +from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import BigInteger, Column, ForeignKey, Integer, String, asc +from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, Numeric, String, asc, desc from sqlalchemy.dialects import postgresql from testgen.common.models.entity import Entity @@ -18,9 +20,11 @@ class ProfileResult(Entity): position: int = Column(Integer) general_type: str | None = Column(String) + column_type: str | None = Column(String) + db_data_type: str | None = Column(String) functional_data_type: str | None = Column(String) + functional_table_type: str | None = Column(String) datatype_suggestion: str | None = Column(String) - db_data_type: str | None = Column(String) pii_flag: str | None = Column(String(50)) record_ct: int | None = Column(BigInteger) @@ -28,8 +32,118 @@ class ProfileResult(Entity): null_value_ct: int | None = Column(BigInteger) distinct_value_ct: int | None = Column(BigInteger) filled_value_ct: int | None = Column(BigInteger) + zero_value_ct: int | None = Column(BigInteger) + + # Alpha-specific 
+ min_length: int | None = Column(Integer) + max_length: int | None = Column(Integer) + avg_length: float | None = Column(Float) + min_text: str | None = Column(String) + max_text: str | None = Column(String) + top_freq_values: str | None = Column(String) + top_patterns: str | None = Column(String) + distinct_std_value_ct: int | None = Column(BigInteger) + distinct_pattern_ct: int | None = Column(BigInteger) + std_pattern_match: str | None = Column(String) + mixed_case_ct: int | None = Column(BigInteger) + lower_case_ct: int | None = Column(BigInteger) + upper_case_ct: int | None = Column(BigInteger) + non_alpha_ct: int | None = Column(BigInteger) + includes_digit_ct: int | None = Column(BigInteger) + numeric_ct: int | None = Column(BigInteger) + date_ct: int | None = Column(BigInteger) + quoted_value_ct: int | None = Column(BigInteger) + lead_space_ct: int | None = Column(BigInteger) + embedded_space_ct: int | None = Column(BigInteger) + avg_embedded_spaces: float | None = Column(Float) + zero_length_ct: int | None = Column(BigInteger) + + # Numeric-specific + min_value: float | None = Column(Float) + min_value_over_0: float | None = Column(Float) + max_value: float | None = Column(Float) + avg_value: float | None = Column(Float) + stdev_value: float | None = Column(Float) + percentile_25: float | None = Column(Float) + percentile_50: float | None = Column(Float) + percentile_75: float | None = Column(Float) + fractional_sum: float | None = Column(Numeric(38, 6)) + + # Date-specific + min_date: datetime | None = Column(postgresql.TIMESTAMP) + max_date: datetime | None = Column(postgresql.TIMESTAMP) + before_1yr_date_ct: int | None = Column(BigInteger) + before_5yr_date_ct: int | None = Column(BigInteger) + before_20yr_date_ct: int | None = Column(BigInteger) + within_1yr_date_ct: int | None = Column(BigInteger) + within_1mo_date_ct: int | None = Column(BigInteger) + future_date_ct: int | None = Column(BigInteger) + + # Boolean-specific + boolean_true_ct: int | 
None = Column(BigInteger) + + # Per-column profiling failure (independent of run-level status) + query_error: str | None = Column(String) _default_order_by = (asc(position), asc(column_name)) - # Additional columns exist on this table (type-specific profile stats). - # They'll be mapped here as new MCP tools need them (L2+). + @classmethod + def get_for_column( + cls, + table_groups_id: UUID, + table_name: str, + column_name: str, + profiling_run_id: UUID | None = None, + ) -> "ProfileResult | None": + """Fetch the profile-results row for one column. + + Resolves to the explicit ``profiling_run_id`` when given, otherwise to the + column's latest profile run (via ``data_column_chars.last_complete_profile_run_id``). + Returns ``None`` when no row exists. + """ + # Local import: data_column imports ProfileResult at module top. + from testgen.common.models.data_column import DataColumnChars + + clauses = [ + cls.table_groups_id == table_groups_id, + cls.table_name == table_name, + cls.column_name == column_name, + ] + if profiling_run_id is not None: + clauses.append(cls.profile_run_id == profiling_run_id) + else: + latest = list( + DataColumnChars.select_where( + DataColumnChars.table_groups_id == table_groups_id, + DataColumnChars.table_name == table_name, + DataColumnChars.column_name == column_name, + ) + ) + if not latest or latest[0].last_complete_profile_run_id is None: + return None + clauses.append(cls.profile_run_id == latest[0].last_complete_profile_run_id) + + rows = list(cls.select_where(*clauses, order_by=(desc(cls.profile_run_id),))) + return rows[0] if rows else None + + @classmethod + def select_for_runs( + cls, + run_ids: Iterable[UUID], + table_name: str | None = None, + column_name: str | None = None, + ) -> list["ProfileResult"]: + """Fetch profile-results rows for a set of profiling runs in one query. + + Optional ``table_name`` and ``column_name`` filters narrow the result to one + entity (case-sensitive exact match). 
+ """ + run_ids = list(run_ids) + if not run_ids: + return [] + clauses = [cls.profile_run_id.in_(run_ids)] + if table_name is not None: + clauses.append(cls.table_name == table_name) + if column_name is not None: + clauses.append(cls.column_name == column_name) + return list(cls.select_where(*clauses)) diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 05a5a94f..225c9942 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -4,17 +4,18 @@ from typing import ClassVar, Literal, NamedTuple, Self, TypedDict from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Column, Float, Integer, String, desc, func, select, text, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.sql.expression import case -from testgen.common.models import get_current_session +from testgen.common.enums import Disposition, JobStatus +from testgen.common.models import database_session, get_current_session from testgen.common.models.connection import Connection -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.entity import Entity, EntityMinimal +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profile_result import ProfileResult from testgen.common.models.project import Project from testgen.common.models.table_group import TableGroup from testgen.utils import is_uuid4 @@ -46,6 +47,8 @@ class ProfilingRunMinimal(EntityMinimal): class ProfilingRunSummary(EntityMinimal): job_execution_id: UUID profiling_run_id: UUID | None + job_schedule_id: UUID | None + project_code: str status: JobStatus created_at: datetime started_at: datetime | None @@ -83,6 +86,15 @@ def status_label(self) -> str: 
return self.STATUS_LABEL.get(self.status, self.status) +@dataclass +class ProfilingRunTableBreakdown(EntityMinimal): + schema_name: str + table_name: str + record_ct: int | None + column_ct: int + anomaly_ct: int + + class LatestProfilingRun(NamedTuple): id: str run_time: datetime @@ -135,7 +147,6 @@ def get_by_id_or_job(cls, identifier: UUID) -> Self | None: return get_current_session().scalars(query).first() @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, run_id: str | UUID) -> ProfilingRunMinimal | None: if not is_uuid4(run_id): return None @@ -180,7 +191,6 @@ def get_latest_complete_je_id_for_table_group(cls, table_groups_id: UUID) -> UUI return get_current_session().scalar(query) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[ProfilingRunMinimal]: @@ -196,14 +206,21 @@ def select_minimal_where( @classmethod def select_summary( cls, - project_code: str, + project_code: str | None = None, table_group_id: str | UUID | None = None, + job_execution_id: str | UUID | None = None, + statuses: list[JobStatus] | None = None, page: int = 1, page_size: int = 20, ) -> tuple[list[ProfilingRunSummary], int]: - if table_group_id and not is_uuid4(table_group_id): + if ( + (table_group_id and not is_uuid4(table_group_id)) + or (job_execution_id and not is_uuid4(job_execution_id)) + ): return [], 0 + # Pending JEs (no pr row) surface in project-scope queries via the LEFT JOIN, but + # not in table-group-scoped queries, since the WHERE filter requires tg to match. 
query = f""" WITH profile_anomalies AS ( SELECT profile_anomaly_results.profile_run_id, @@ -224,6 +241,8 @@ def select_summary( SELECT je.id AS job_execution_id, pr.id AS profiling_run_id, + je.job_schedule_id, + je.project_code, je.status, je.created_at, je.started_at, @@ -250,14 +269,18 @@ def select_summary( LEFT JOIN table_groups tg ON tg.id = pr.table_groups_id LEFT JOIN profile_anomalies pa ON pa.profile_run_id = pr.id WHERE je.job_key = 'run-profile' - AND je.project_code = :project_code + {" AND je.project_code = :project_code" if project_code else ""} {" AND tg.id = :table_group_id" if table_group_id else ""} + {" AND je.id = :job_execution_id" if job_execution_id else ""} + {" AND je.status IN :statuses" if statuses else ""} ORDER BY je.created_at DESC LIMIT :limit OFFSET :offset; """ params = { "project_code": project_code, - "table_group_id": table_group_id, + "table_group_id": str(table_group_id) if table_group_id else None, + "job_execution_id": str(job_execution_id) if job_execution_id else None, + "statuses": tuple(statuses) if statuses else (), "limit": page_size, "offset": (page - 1) * page_size, } @@ -267,6 +290,55 @@ def select_summary( total = items[0].total_count if items else 0 return items, total + @classmethod + def select_table_breakdown(cls, profiling_run_id: UUID) -> list[ProfilingRunTableBreakdown]: + """Per-table breakdown for a completed profiling run: schema, table, record/column count, anomaly count.""" + # HygieneIssue imports ProfilingRun, so this import has to stay function-local. 
+ from testgen.common.models.hygiene_issue import HygieneIssue + + results_subq = ( + select( + ProfileResult.schema_name.label("schema_name"), + ProfileResult.table_name.label("table_name"), + func.max(ProfileResult.record_ct).label("record_ct"), + func.count(func.distinct(ProfileResult.column_name)).label("column_ct"), + ) + .where(ProfileResult.profile_run_id == profiling_run_id) + .group_by(ProfileResult.schema_name, ProfileResult.table_name) + .subquery() + ) + anomalies_subq = ( + select( + HygieneIssue.schema_name.label("schema_name"), + HygieneIssue.table_name.label("table_name"), + func.count().label("anomaly_ct"), + ) + .where( + HygieneIssue.profile_run_id == profiling_run_id, + func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + ) + .group_by(HygieneIssue.schema_name, HygieneIssue.table_name) + .subquery() + ) + query = ( + select( + results_subq.c.schema_name, + results_subq.c.table_name, + results_subq.c.record_ct, + results_subq.c.column_ct, + func.coalesce(anomalies_subq.c.anomaly_ct, 0).label("anomaly_ct"), + ) + .select_from(results_subq) + .outerjoin( + anomalies_subq, + (anomalies_subq.c.schema_name == results_subq.c.schema_name) + & (anomalies_subq.c.table_name == results_subq.c.table_name), + ) + .order_by(results_subq.c.schema_name, results_subq.c.table_name) + ) + rows = get_current_session().execute(query).mappings().all() + return [ProfilingRunTableBreakdown(**row) for row in rows] + _ACTIVE_JOB_STATUSES = (JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED) @classmethod @@ -317,6 +389,89 @@ def cascade_delete(cls, ids: list[str]) -> None: db_session.execute(text(query), {"profiling_run_ids": tuple(ids)}) cls.delete_where(cls.id.in_(ids)) + @classmethod + def find_latest_per_table_group(cls, project_code: str) -> set[UUID]: + """Return the latest completed profiling run id per table group for the + project. + + Used by data retention to protect at least one run per scope. 
Profiling + is expensive and runs infrequently; downstream features (test + generation, freshness monitor generation, data catalog, MCP analysis + tools) read the most recent profiling result for a table group, so the + latest usable snapshot must survive even when its run_date is past the + retention cutoff. Failed and in-flight runs are skipped because they + don't expose result data for downstream consumers to read. + """ + rows = get_current_session().scalars( + select(cls.id) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + .where( + cls.project_code == project_code, + JobExecution.status == JobStatus.COMPLETED, + ) + .order_by(cls.table_groups_id, cls.profiling_starttime.desc()) + .distinct(cls.table_groups_id) + ).all() + return set(rows) + + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_ids: set[UUID], + batch_size: int = 1000, + dry_run: bool = False, + ) -> int: + """Batched delete of profiling runs (with cascading children) older than + cutoff for the given project, excluding protected ids. Returns total + parent rows deleted across all batches — or, with ``dry_run=True``, + the number that would be deleted (for retention preview, no writes). + + In-flight runs (JE in PENDING/CLAIMED/RUNNING/CANCEL_REQUESTED) are + never deleted — they may still be writing data. + + Each batch runs in its own transaction (committed before the next batch + is selected), so locks on profiling_runs / profile_results / etc. are + released between batches and WAL growth stays bounded for large sweeps. 
+ """ + where_clauses = [ + cls.project_code == project_code, + cls.profiling_starttime < cutoff, + JobExecution.status.in_([JobStatus.COMPLETED, JobStatus.ERROR, JobStatus.CANCELED]), + ] + if protected_ids: + where_clauses.append(cls.id.notin_(protected_ids)) + + base_select = select(cls.id).join(JobExecution, cls.job_execution_id == JobExecution.id) + + if dry_run: + return get_current_session().scalar( + select(func.count()).select_from(base_select.where(*where_clauses).subquery()) + ) or 0 + + total = 0 + while True: + with database_session() as session: + ids = session.scalars(base_select.where(*where_clauses).limit(batch_size)).all() + if not ids: + break + cls.cascade_delete([str(i) for i in ids]) + total += len(ids) + return total + + @classmethod + def get_job_execution_ids(cls, profiling_run_ids: list[UUID]) -> dict[UUID, UUID | None]: + """Map profiling_run PKs to their job_execution_ids (batch lookup). + + Mirrors TestRun.get_job_execution_ids. + """ + if not profiling_run_ids: + return {} + query = select(cls.id, cls.job_execution_id).where(cls.id.in_(profiling_run_ids)) + rows = get_current_session().execute(query).all() + return {row.id: row.job_execution_id for row in rows} + def init_progress(self) -> None: self._progress = { "data_chars": {"label": "Refreshing data catalog"}, @@ -344,9 +499,41 @@ def get_previous(self) -> Self | None: .where( ProfilingRun.table_groups_id == self.table_groups_id, JobExecution.status == JobStatus.COMPLETED, - JobExecution.started_at < self.profiling_starttime, + ProfilingRun.profiling_starttime < self.profiling_starttime, ) - .order_by(desc(JobExecution.started_at)) + .order_by(desc(ProfilingRun.profiling_starttime)) .limit(1) ) return get_current_session().scalar(query) + + @classmethod + def list_recent_complete(cls, table_groups_id: UUID, limit: int) -> list[Self]: + """Return the most recent completed profiling runs for a table group, newest first.""" + query = ( + select(cls) + .join(JobExecution, 
cls.job_execution_id == JobExecution.id) + .where( + cls.table_groups_id == table_groups_id, + JobExecution.status == JobStatus.COMPLETED, + ) + .order_by(desc(JobExecution.started_at)) + .limit(limit) + ) + return list(get_current_session().scalars(query)) + + @classmethod + def count_confirmed_hygiene_issues(cls, run_ids: list[UUID]) -> dict[UUID, int]: + """Count confirmed hygiene issues per profiling run. Missing runs default to zero.""" + if not run_ids: + return {} + from testgen.common.models.hygiene_issue import HygieneIssue + + query = ( + select(HygieneIssue.profile_run_id, func.count()) + .where( + HygieneIssue.profile_run_id.in_(run_ids), + func.coalesce(HygieneIssue.disposition, Disposition.CONFIRMED) == Disposition.CONFIRMED, + ) + .group_by(HygieneIssue.profile_run_id) + ) + return {row[0]: row[1] for row in get_current_session().execute(query)} diff --git a/testgen/common/models/project.py b/testgen/common/models/project.py index 5c54872a..40d04974 100644 --- a/testgen/common/models/project.py +++ b/testgen/common/models/project.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, String, asc, func, select, text +from sqlalchemy import Boolean, Column, Integer, String, asc, func, select, text from sqlalchemy.dialects import postgresql from testgen.common.models import get_current_session @@ -40,6 +40,8 @@ class Project(Entity): observability_api_url: str = Column(NullIfEmptyString) observability_api_key: str = Column(NullIfEmptyString) use_dq_score_weights: bool = Column(Boolean, default=True) + data_retention_enabled: bool = Column(Boolean, nullable=False, default=True) + data_retention_days: int | None = Column(Integer, default=180) _get_by = "project_code" _default_order_by = (asc(func.lower(project_name)),) diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py index f094c4ab..d1c03657 100644 --- a/testgen/common/models/scheduler.py +++ 
b/testgen/common/models/scheduler.py @@ -3,14 +3,13 @@ from typing import Any, Self from uuid import UUID, uuid4 -import streamlit as st from cron_converter import Cron from sqlalchemy import Boolean, Column, String, cast, delete, func, select, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute +from testgen.common.enums import JobKey from testgen.common.models import Base, get_current_session -from testgen.common.models.entity import ENTITY_HASH_FUNCS from testgen.common.models.test_definition import TestDefinition from testgen.common.models.test_suite import TestSuite @@ -18,6 +17,13 @@ RUN_MONITORS_JOB_KEY = "run-monitors" RUN_PROFILE_JOB_KEY = "run-profile" +DEFAULT_DATA_CLEANUP_CRON = "0 1 * * *" +# Non-UI fallback for retention schedule timezone. UI surfaces should instead +# default to the user's browser timezone (resolved client-side). +DEFAULT_RETENTION_CRON_TZ = "UTC" + +SCHEDULABLE_JOB_KEYS: frozenset[JobKey] = frozenset({JobKey.run_profile, JobKey.run_tests}) + class JobSchedule(Base): __tablename__ = "job_schedules" @@ -26,20 +32,28 @@ class JobSchedule(Base): project_code: str = Column(String) key: str = Column(String, nullable=False) - args: list[Any] = Column(postgresql.JSONB, nullable=False, default=[]) kwargs: dict[str, Any] = Column(postgresql.JSONB, nullable=False, default={}) cron_expr: str = Column(String, nullable=False) cron_tz: str = Column(String, nullable=False) active: bool = Column(Boolean, default=True) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def get(cls, *clauses) -> Self | None: query = select(cls).where(*clauses) return get_current_session().scalars(query).first() @classmethod def select_where(cls, *clauses, order_by: str | InstrumentedAttribute | None = None) -> Iterable[Self]: + query = select(cls).where(*clauses) + if order_by is not None: + query = query.order_by(order_by) + return get_current_session().scalars(query).all() + + @classmethod + 
def select_runnable(cls, *clauses, order_by: str | InstrumentedAttribute | None = None) -> Iterable[Self]: + """Schedules the scheduler should dispatch: active rows, and (for test/monitor runs) + only when the linked test suite has at least one test definition. + """ test_job_keys = [RUN_TESTS_JOB_KEY, RUN_MONITORS_JOB_KEY] test_definitions_count = ( select(cls.id) @@ -73,7 +87,98 @@ def update_active(cls, job_id: str | UUID, active: bool) -> None: @classmethod def count(cls): return get_current_session().query(cls).count() - + + @classmethod + def list_for_project( + cls, + project_code: str, + *extra_filters, + key_filter: Iterable[JobKey] | None = None, + page: int = 1, + limit: int = 20, + ) -> tuple[list[Self], int]: + """List schedules for a project with optional key filter and pagination. + + Returns both active and paused rows. Defaults ``key_filter`` to + ``SCHEDULABLE_JOB_KEYS`` (``run_profile``, ``run_tests``); pass an explicit + ``key_filter`` to include other kinds. + """ + session = get_current_session() + keys = list(key_filter) if key_filter is not None else list(SCHEDULABLE_JOB_KEYS) + query = select(cls).where(cls.project_code == project_code, cls.key.in_(keys), *extra_filters) + total = session.scalar(select(func.count()).select_from(query.subquery())) + items = session.scalars(query.order_by(cls.key, cls.id).offset((page - 1) * limit).limit(limit)).all() + return list(items), total or 0 + + @classmethod + def select_active_by_kwargs( + cls, + project_code: str, + key: str, + kwargs_match: dict[str, str | list[str]], + ) -> list[Self]: + """Find active schedules whose ``kwargs`` JSONB matches the given (key, value) pairs. + + Values may be a single string or a list of strings (which becomes an ``IN`` filter). 
+ """ + query = select(cls).where( + cls.project_code == project_code, + cls.key == key, + cls.active.is_(True), + ) + for k, v in kwargs_match.items(): + if isinstance(v, list): + if not v: + return [] + query = query.where(cls.kwargs[k].astext.in_([str(x) for x in v])) + else: + query = query.where(cls.kwargs[k].astext == str(v)) + return list(get_current_session().scalars(query).all()) + + @classmethod + def upsert_for_retention( + cls, + project_code: str, + retention_days: int, + cron_expr: str, + cron_tz: str, + ) -> Self: + """Create or update the data-retention schedule for a project. + + Idempotent — safe to call on project creation and on every retention + settings save. Uniquely keyed by (project_code, JobKey.run_data_cleanup). + """ + session = get_current_session() + schedule = session.scalars( + select(cls).where(cls.project_code == project_code, cls.key == JobKey.run_data_cleanup) + ).first() + kwargs = {"project_code": project_code, "retention_days": retention_days} + if schedule: + schedule.kwargs = kwargs + schedule.cron_expr = cron_expr + schedule.cron_tz = cron_tz + schedule.active = True + else: + schedule = cls( + project_code=project_code, + key=JobKey.run_data_cleanup, + kwargs=kwargs, + cron_expr=cron_expr, + cron_tz=cron_tz, + active=True, + ) + session.add(schedule) + return schedule + + @classmethod + def delete_for_retention(cls, project_code: str) -> None: + """Remove the data-retention schedule for a project (when retention is + disabled or the project is deleted). 
+ """ + get_current_session().execute( + delete(cls).where(cls.project_code == project_code, cls.key == JobKey.run_data_cleanup) + ) + def get_sample_triggering_timestamps(self, n=3) -> list[datetime]: schedule = Cron(cron_string=self.cron_expr).schedule(timezone_str=self.cron_tz) return [schedule.next() for _ in range(n)] @@ -81,7 +186,7 @@ def get_sample_triggering_timestamps(self, n=3) -> list[datetime]: @property def cron_tz_str(self) -> str: return self.cron_tz.replace("_", " ") - + def save(self) -> None: db_session = get_current_session() db_session.add(self) diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index b5fc9545..4dfb4c10 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -12,12 +12,29 @@ from typing import Literal, Self, TypedDict from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String, delete, func, select, text +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + String, + column, + delete, + func, + or_, + select, + table, + text, + tuple_, +) from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import aliased, attributes, relationship +from sqlalchemy.orm import aliased, attributes, joinedload, relationship from testgen.common import read_template_sql_file -from testgen.common.models import Base, get_current_session +from testgen.common.models import Base, database_session, get_current_session from testgen.utils import is_uuid4 SCORE_CATEGORIES = [ @@ -140,6 +157,85 @@ def get(cls, id_: str) -> Self | None: definition = db_session.scalars(query).first() return definition + @classmethod + def names_by_id(cls, ids: Iterable[UUID]) -> dict[UUID, str]: + """Return ``{id: name}`` for the given scorecard IDs in a single query. + + IDs with no matching scorecard are omitted. Empty input yields ``{}`` + without touching the database. 
+ """ + ids = list(ids) + if not ids: + return {} + query = select(cls.id, cls.name).where(cls.id.in_(ids)) + return {row.id: row.name for row in get_current_session().execute(query).all()} + + @classmethod + def list_with_table_group_targets( + cls, + project_code: str, + ) -> list[tuple[UUID, str, list[str]]]: + """Return all scorecards in the project, each paired with the list of + `table_groups_name` values their criteria reference. + + Walks both root filters (`criteria.filters`) and the `next_filter` chain + via a recursive CTE. A scorecard with zero name filters has an empty + list; multiple are returned in chain order. + + Single query. Does NOT eagerly load the criteria/filter ORM objects — + the caller gets only (id, name, target names). Used by the MCP + inventory tool to surface scorecard IDs under each table group. + """ + # Seed: root filters joined through criteria for the project's definitions. + seed = ( + select( + ScoreDefinitionCriteria.definition_id.label("definition_id"), + ScoreDefinitionFilter.field.label("field"), + ScoreDefinitionFilter.value.label("value"), + ScoreDefinitionFilter.next_filter_id.label("next_filter_id"), + ) + .select_from(ScoreDefinitionCriteria) + .join(ScoreDefinitionFilter, ScoreDefinitionFilter.criteria_id == ScoreDefinitionCriteria.id) + .join(ScoreDefinition, ScoreDefinition.id == ScoreDefinitionCriteria.definition_id) + .where(ScoreDefinition.project_code == project_code) + .cte("filter_walk", recursive=True) + ) + # Recursive step: follow next_filter_id to walk the chain. 
+ chain = aliased(ScoreDefinitionFilter) + filter_walk = seed.union_all( + select( + seed.c.definition_id, + chain.field, + chain.value, + chain.next_filter_id, + ) + .select_from(seed) + .join(chain, chain.id == seed.c.next_filter_id) + ) + + tg_names = ( + func.array_agg(filter_walk.c.value) + .filter(filter_walk.c.field == "table_groups_name") + .label("tg_names") + ) + query = ( + select(ScoreDefinition.id, ScoreDefinition.name, tg_names) + .select_from(ScoreDefinition) + .outerjoin(filter_walk, filter_walk.c.definition_id == ScoreDefinition.id) + .where(ScoreDefinition.project_code == project_code) + .group_by(ScoreDefinition.id, ScoreDefinition.name) + .order_by(ScoreDefinition.name) + ) + rows = get_current_session().execute(query).all() + # Dedupe tg_names: a mode-2 scorecard with N chains under the same + # table_groups_name would otherwise list the name N times, causing the + # inventory tool to render the scorecard once per chain. dict.fromkeys + # preserves first-seen order. + return [ + (row.id, row.name, list(dict.fromkeys(row.tg_names)) if row.tg_names else []) + for row in rows + ] + @classmethod def all( cls, @@ -192,6 +288,40 @@ def all( return definitions + @classmethod + def list_for_project( + cls, + project_code: str, + page: int = 1, + limit: int = 20, + ) -> tuple[list[Self], int]: + """Paginated list of scorecards in a project. + + Returns ORM objects with ``criteria`` eager-loaded so callers can walk + the filter chain without firing extra queries. ``results`` is already + ``lazy="joined"`` and rides along automatically — feeds + ``as_cached_score_card()``. 
+ """ + session = get_current_session() + base_filter = ScoreDefinition.project_code == project_code + + total = session.scalar( + select(func.count()).select_from( + select(ScoreDefinition.id).where(base_filter).subquery() + ) + ) or 0 + + query = ( + select(ScoreDefinition) + .options(joinedload(ScoreDefinition.criteria)) + .where(base_filter) + .order_by(ScoreDefinition.name) + .offset((page - 1) * limit) + .limit(limit) + ) + rows = session.scalars(query).unique().all() + return list(rows), total + def save(self) -> None: db_session = get_current_session() db_session.add(self) @@ -259,6 +389,9 @@ def as_score_card(self) -> ScoreCard: ).replace("{filters}", filters)) ).mappings().first() or {} + cde_only_categories = self.cde_score and not self.total_score + category_filters = " AND ".join(self._get_raw_query_filters(cde_only=cde_only_categories)) + categories_scores = [] if (category := self.category): categories_scores = [ @@ -267,7 +400,7 @@ def as_score_card(self) -> ScoreCard: text(read_template_sql_file( categories_query_template_file, sub_directory="score_cards", - ).replace("{category}", category.value).replace("{filters}", filters)) + ).replace("{category}", category.value).replace("{filters}", category_filters)) ).mappings().all() ] @@ -473,6 +606,29 @@ def recalculate_scores_history(self) -> None: self.history = list(current_history.values()) + def get_overall_issue_ct(self) -> int: + """Sum of hygiene + test issue counts under this definition's filters. + + Reuses the same filter machinery as `as_score_card` so the rolled-up + count matches the score that call returns. 
+ """ + if not self.criteria.has_filters(): + return 0 + + where_clause = text(" AND ".join(self._get_raw_query_filters())) + session = get_current_session() + + def _sum_issue_ct(view_name: str) -> int: + view = table(view_name, column("issue_ct")) + return int(session.execute( + select(func.coalesce(func.sum(view.c.issue_ct), 0)).where(where_clause) + ).scalar() or 0) + + return ( + _sum_issue_ct("v_dq_profile_scoring_latest_by_column") + + _sum_issue_ct("v_dq_test_scoring_latest_by_column") + ) + def _get_raw_query_filters(self, cde_only: bool = False, prefix: str | None = None) -> list[str]: extra_filters = [ f"{prefix or ''}project_code = '{self.project_code}'" @@ -728,7 +884,6 @@ def add_as_cutoff(self): Query templates: add_latest_runs.sql """ - # ruff: noqa: RUF027 query = read_template_sql_file("add_latest_runs.sql", sub_directory="score_cards") params = { "project_code": self.definition.project_code, @@ -738,6 +893,142 @@ def add_as_cutoff(self): session = get_current_session() session.execute(text(query), params) + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_keys: set[tuple[UUID, datetime]], + batch_size: int = 1000, + ) -> int: + """Batched delete of score-history entries older than cutoff for the + given project, excluding entries whose (definition_id, last_run_time) + is in protected_keys. Preserves snapshots tied to protected latest + runs so the score-trend chart stays consistent with the run. + + Each batch runs in its own transaction (committed before the next batch + is selected) so locks and WAL growth stay bounded for large sweeps. 
+ """ + project_def_ids = select(ScoreDefinition.id).where( + ScoreDefinition.project_code == project_code + ).scalar_subquery() + + where_clauses = [ + cls.last_run_time < cutoff, + cls.definition_id.in_(project_def_ids), + ] + if protected_keys: + where_clauses.append( + tuple_(cls.definition_id, cls.last_run_time).notin_(list(protected_keys)) + ) + + total = 0 + while True: + with database_session() as session: + keys = session.execute( + select(cls.definition_id, cls.last_run_time) + .where(*where_clauses) + .distinct() + .limit(batch_size) + ).all() + if not keys: + break + result = session.execute( + delete(cls).where( + tuple_(cls.definition_id, cls.last_run_time).in_(list(keys)) + ) + ) + total += result.rowcount or 0 + return total + + +class ScoreHistoryLatestRun(Base): + """Snapshot mapping rows: for a score definition + cutoff time, holds the + latest profiling/test run ids active at that point. Score-trend snapshots + in score_definition_results_history correlate to runs through this table. + + The underlying table has no real primary key — the composite declared here + captures the semantic uniqueness (one row per definition x cutoff x scope). 
+ """ + + __tablename__ = "score_history_latest_runs" + + definition_id: UUID = Column(postgresql.UUID(as_uuid=True), nullable=False, primary_key=True) + score_history_cutoff_time: datetime = Column(DateTime(timezone=False), nullable=False, primary_key=True) + table_groups_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True, primary_key=True) + last_profiling_run_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True) + test_suite_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True, primary_key=True) + last_test_run_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True) + + @classmethod + def find_protected_keys( + cls, + protected_profiling_ids: set[UUID], + protected_test_run_ids: set[UUID], + ) -> set[tuple[UUID, datetime]]: + """Return (definition_id, score_history_cutoff_time) pairs that map to + any protected profiling or test run. Used to preserve score-trend + snapshots tied to runs that retention is keeping alive. + """ + if not protected_profiling_ids and not protected_test_run_ids: + return set() + clauses = [] + if protected_profiling_ids: + clauses.append(cls.last_profiling_run_id.in_(protected_profiling_ids)) + if protected_test_run_ids: + clauses.append(cls.last_test_run_id.in_(protected_test_run_ids)) + rows = get_current_session().execute( + select(cls.definition_id, cls.score_history_cutoff_time).where(or_(*clauses)).distinct() + ).all() + return {tuple(row) for row in rows} + + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_keys: set[tuple[UUID, datetime]], + batch_size: int = 1000, + ) -> int: + """Batched delete of mapping rows older than cutoff for the given + project, excluding rows whose (definition_id, cutoff_time) is in + protected_keys. + + Each batch runs in its own transaction (committed before the next batch + is selected) so locks and WAL growth stay bounded for large sweeps. 
+ """ + project_def_ids = select(ScoreDefinition.id).where( + ScoreDefinition.project_code == project_code + ).scalar_subquery() + + where_clauses = [ + cls.score_history_cutoff_time < cutoff, + cls.definition_id.in_(project_def_ids), + ] + if protected_keys: + where_clauses.append( + tuple_(cls.definition_id, cls.score_history_cutoff_time).notin_(list(protected_keys)) + ) + + total = 0 + while True: + with database_session() as session: + keys = session.execute( + select(cls.definition_id, cls.score_history_cutoff_time) + .where(*where_clauses) + .distinct() + .limit(batch_size) + ).all() + if not keys: + break + result = session.execute( + delete(cls).where( + tuple_(cls.definition_id, cls.score_history_cutoff_time).in_(list(keys)) + ) + ) + total += result.rowcount or 0 + return total + class ScoreCard(TypedDict): id: str diff --git a/testgen/common/models/stg_data_chars_update.py b/testgen/common/models/stg_data_chars_update.py new file mode 100644 index 00000000..4da2c6d1 --- /dev/null +++ b/testgen/common/models/stg_data_chars_update.py @@ -0,0 +1,32 @@ +"""ORM model for the stg_data_chars_updates staging table. + +Cleaned per-run by `data_chars_staging_delete.sql`; this model exists for +data retention to age out orphans left by failed/interrupted profiling runs. +Has no project_code column — project scope is enforced via a subquery on +table_groups. PK declared is cosmetic; only WHERE columns are needed for +bulk DELETE. 
+""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import Column, String, select +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base +from testgen.common.models.table_group import TableGroup + + +class StgDataCharsUpdate(Base): + __tablename__ = "stg_data_chars_updates" + + table_groups_id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + schema_name: str = Column(String(120), primary_key=True) + table_name: str = Column(String(120), primary_key=True) + column_name: str = Column(String(120), primary_key=True) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + project_table_groups = select(TableGroup.id).where(TableGroup.project_code == project_code) + return cls.delete_where(cls.run_date < cutoff, cls.table_groups_id.in_(project_table_groups)) diff --git a/testgen/common/models/stg_functional_table_update.py b/testgen/common/models/stg_functional_table_update.py new file mode 100644 index 00000000..783f949f --- /dev/null +++ b/testgen/common/models/stg_functional_table_update.py @@ -0,0 +1,26 @@ +"""ORM model for the stg_functional_table_updates staging table. + +Unlike the other staging tables, this one has no per-run delete anywhere in +the codebase — rows accumulate indefinitely. Data retention is the primary +cleanup. PK declared is cosmetic; only WHERE columns are needed for bulk DELETE. 
+""" + +from datetime import datetime + +from sqlalchemy import Column, String +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base + + +class StgFunctionalTableUpdate(Base): + __tablename__ = "stg_functional_table_updates" + + project_code: str = Column(String(30), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + schema_name: str = Column(String(50), primary_key=True) + table_name: str = Column(String(120), primary_key=True) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + return cls.delete_where(cls.run_date < cutoff, cls.project_code == project_code) diff --git a/testgen/common/models/stg_secondary_profile_update.py b/testgen/common/models/stg_secondary_profile_update.py new file mode 100644 index 00000000..68705362 --- /dev/null +++ b/testgen/common/models/stg_secondary_profile_update.py @@ -0,0 +1,28 @@ +"""ORM model for the stg_secondary_profile_updates staging table. + +Cleaned per-run by `secondary_profiling_delete.sql`; this model exists for +data retention to age out orphans left by failed/interrupted profiling runs. +The PK declared here is cosmetic — only the WHERE columns are needed for the +bulk DELETE. See `staging` package docs in `run_data_cleanup.py` for context. 
+""" + +from datetime import datetime + +from sqlalchemy import Column, String +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base + + +class StgSecondaryProfileUpdate(Base): + __tablename__ = "stg_secondary_profile_updates" + + project_code: str = Column(String(30), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + schema_name: str = Column(String(50), primary_key=True) + table_name: str = Column(String(120), primary_key=True) + column_name: str = Column(String(120), primary_key=True) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + return cls.delete_where(cls.run_date < cutoff, cls.project_code == project_code) diff --git a/testgen/common/models/stg_test_definition_update.py b/testgen/common/models/stg_test_definition_update.py new file mode 100644 index 00000000..4da09f52 --- /dev/null +++ b/testgen/common/models/stg_test_definition_update.py @@ -0,0 +1,30 @@ +"""ORM model for the stg_test_definition_updates staging table. + +Cleaned per-run by `delete_staging_test_definitions.sql`; this model exists +for data retention to age out orphans left by failed/interrupted prediction +runs. Has no project_code column — project scope is enforced via a subquery +on test_suites. PK declared is cosmetic; only WHERE columns are needed for +bulk DELETE. 
+""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import Column, select +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base +from testgen.common.models.test_suite import TestSuite + + +class StgTestDefinitionUpdate(Base): + __tablename__ = "stg_test_definition_updates" + + test_suite_id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + test_definition_id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + project_test_suites = select(TestSuite.id).where(TestSuite.project_code == project_code) + return cls.delete_where(cls.run_date < cutoff, cls.test_suite_id.in_(project_test_suites)) diff --git a/testgen/common/models/table_group.py b/testgen/common/models/table_group.py index 117e8983..251dcc46 100644 --- a/testgen/common/models/table_group.py +++ b/testgen/common/models/table_group.py @@ -3,14 +3,13 @@ from datetime import datetime from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Boolean, Column, Float, ForeignKey, Integer, String, asc, func, text, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from testgen.common.models import get_current_session from testgen.common.models.custom_types import NullIfEmptyString, YNString -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.scores import ScoreDefinition from testgen.common.models.test_suite import TestSuite from testgen.utils import is_uuid4 @@ -151,13 +150,11 @@ class TableGroup(Entity): ) @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, id_: str | UUID) -> TableGroupMinimal | None: 
result = cls._get_columns(id_, cls._minimal_columns) return TableGroupMinimal(**result) if result else None @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TableGroupMinimal]: diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index 8740203b..ec9ddde0 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -1,11 +1,11 @@ from collections.abc import Iterable from dataclasses import dataclass -from datetime import datetime +from datetime import UTC, datetime +from enum import StrEnum from itertools import zip_longest from typing import ClassVar, Literal from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import ( Boolean, Column, @@ -27,7 +27,7 @@ from testgen.common.models import Base, get_current_session from testgen.common.models.custom_types import NullIfEmptyString, YNString, ZeroIfEmptyInteger -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.utils import is_uuid4 TestRunType = Literal["QUERY", "CAT", "METADATA"] @@ -35,6 +35,24 @@ TestRunStatus = Literal["Running", "Complete", "Error", "Cancelled"] +class Severity(StrEnum): + FAIL = "Fail" + WARNING = "Warning" + + +class InvalidTestDefinitionFields(ValueError): + """Aggregated field-level validation errors. ``errors``: ``dict[field_name, reason]``.""" + + def __init__(self, errors: dict[str, str]) -> None: + self.errors = errors + super().__init__("; ".join(f"{k}: {v}" for k, v in errors.items())) + + +def _is_blank(value: object) -> bool: + # NullIfEmptyString columns turn ``""`` into NULL on write — treat both as cleared. 
+ return value is None or value == "" + + class ParamFieldsMixin: """Parsed access to default_parm_columns/prompts/help metadata. @@ -204,6 +222,28 @@ def select_summary_where(cls, *clauses) -> Iterable[TestTypeSummary]: return [TestTypeSummary(**row) for row in results] +def _required_fields_for(test_type: TestType) -> set[str]: + """Fields that must be present and non-empty for the given test type. + + - Column-scoped tests implicitly require ``column_name``. + - Test types with ``custom_query`` in ``param_columns`` require ``custom_query``. + - ``default_parm_required`` is a CSV of ``Y``/``N`` aligned with ``default_parm_columns``; + positions marked ``Y`` are required. + """ + required: set[str] = set() + if test_type.test_scope == "column": + required.add("column_name") + if "custom_query" in test_type.param_columns: + required.add("custom_query") + if test_type.default_parm_required and test_type.default_parm_columns: + flags = [v.strip().upper() for v in test_type.default_parm_required.split(",")] + columns = [c.strip() for c in test_type.default_parm_columns.split(",")] + for col, flag in zip(columns, flags, strict=False): + if flag == "Y": + required.add(col) + return required + + class TestDefinition(Entity): __tablename__ = "test_definitions" @@ -290,7 +330,6 @@ class TestDefinition(Entity): ) @classmethod - @st.cache_data(show_spinner=False) def get(cls, identifier: str | UUID) -> TestDefinitionSummary | None: if not is_uuid4(identifier): return None @@ -329,7 +368,6 @@ def get_for_project( return TestDefinitionSummary(**result) if result else None @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TestDefinitionSummary]: @@ -343,7 +381,6 @@ def select_where( return [TestDefinitionSummary(**row) for row in results] @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def 
select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TestDefinitionMinimal]: @@ -395,8 +432,84 @@ def list_for_suite( query = query.order_by(*cls._default_order_by) return cls._paginate(query, page=page, limit=limit, data_class=TestDefinitionSummary) + @classmethod + def select_page( + cls, + *clauses, + order_by: tuple[str | InstrumentedAttribute] | None = None, + page: int = 1, + limit: int = 500, + ) -> tuple[list["TestDefinitionSummary"], int]: + select_columns = [ + getattr(cls, col, None) or getattr(TestType, col) if isinstance(col, str) else col + for col in cls._summary_columns + ] + query = ( + select(*select_columns) + .join(TestType, cls.test_type == TestType.test_type) + .where(*clauses) + .order_by(*(order_by or cls._default_order_by)) + ) + return cls._paginate(query, page=page, limit=limit, data_class=TestDefinitionSummary) + _yn_columns: ClassVar = {"test_active", "lock_refresh"} + # Fields editable on every test type regardless of param_columns. + EDITABLE_BASE_FIELDS: ClassVar[frozenset[str]] = frozenset({ + "test_active", "severity", "lock_refresh", "flagged", "test_description", + }) + + def editable_fields(self, test_type: TestType) -> set[str]: + """Fields a caller may set or change on this test definition under the given test type.""" + fields = self.EDITABLE_BASE_FIELDS | test_type.param_columns + # column_name is meaningful for column-scoped tests (the column under test) and + # custom-scoped tests (a "Test Focus" label). Other scopes don't use it. + if test_type.test_scope in ("column", "custom"): + fields = fields | {"column_name"} + # impact_dimension is overridable only for user-defined-semantic scopes + # (custom-scope = user-authored SQL; referential-scope = comparison-based tests). + # Other scopes have baked-in dimensions so the override doesn't apply. 
+ if test_type.test_scope in ("custom", "referential"): + fields = fields | {"impact_dimension"} + return fields + + def validate(self, test_type: TestType) -> None: + """Validate the current state against the given test type. + + Raises :class:`InvalidTestDefinitionFields` with every offending field + and reason — callers see all problems at once. + """ + errors: dict[str, str] = {} + + if self.severity: + try: + Severity(self.severity) + except ValueError: + errors["severity"] = ( + f"must be `{Severity.FAIL.value}` or `{Severity.WARNING.value}` " + f"(got `{self.severity}`)" + ) + + # column_name applies to column-scoped tests (the column under test) and + # custom-scoped tests (a "Test Focus" label). Other scopes don't use it. + if test_type.test_scope not in ("column", "custom") and not _is_blank(self.column_name): + errors["column_name"] = ( + f"test type `{test_type.test_type}` has scope `{test_type.test_scope}`; " + f"column_name does not apply to this scope" + ) + + if not _is_blank(self.custom_query) and "custom_query" not in test_type.param_columns: + errors["custom_query"] = ( + f"test type `{test_type.test_type}` does not accept a custom query" + ) + + for required in _required_fields_for(test_type): + if _is_blank(getattr(self, required, None)): + errors[required] = f"required for test type `{test_type.test_type}`" + + if errors: + raise InvalidTestDefinitionFields(errors) + @classmethod def set_status_attribute( cls, @@ -560,16 +673,25 @@ class TestDefinitionNote(Base): updated_at: datetime = Column(postgresql.TIMESTAMP) @classmethod - def add_note(cls, test_definition_id: str | UUID, detail: str, username: str) -> None: + def add_note(cls, test_definition_id: str | UUID, detail: str, username: str) -> "TestDefinitionNote": + """Insert a note and return the persisted instance with ``id`` and ``created_at`` populated.""" db_session = get_current_session() - db_session.execute( - insert(cls).values(test_definition_id=test_definition_id, detail=detail, 
created_by=username) + note = cls( + test_definition_id=test_definition_id, + detail=detail, + created_by=username, + created_at=datetime.now(UTC).replace(tzinfo=None), ) + db_session.add(note) + db_session.flush() + return note @classmethod def update_note(cls, note_id: str | UUID, detail: str) -> None: db_session = get_current_session() - db_session.execute(update(cls).where(cls.id == note_id).values(detail=detail, updated_at=func.now())) + db_session.execute( + update(cls).where(cls.id == note_id).values(detail=detail, updated_at=datetime.now(UTC).replace(tzinfo=None)) + ) @classmethod def delete_note(cls, note_id: str | UUID) -> None: diff --git a/testgen/common/models/test_result.py b/testgen/common/models/test_result.py index 90538adf..c1d25e73 100644 --- a/testgen/common/models/test_result.py +++ b/testgen/common/models/test_result.py @@ -80,27 +80,27 @@ def failure_rate(self) -> float: @dataclass class DiffRow: - """One test definition's status across two runs for ``get_test_run_diff``.""" + """One test definition's status across two runs for ``compare_test_runs``.""" test_definition_id: UUID test_type: str test_name_short: str | None table_name: str | None column_names: str | None - status_a: TestResultStatus | None - status_b: TestResultStatus | None - measure_a: str | None - measure_b: str | None - threshold_a: str | None - threshold_b: str | None + status_baseline: TestResultStatus | None + status_target: TestResultStatus | None + measure_baseline: str | None + measure_target: str | None + threshold_baseline: str | None + threshold_target: str | None @dataclass class RunDiff: """Categorized diff between two test runs.""" - total_a: int - total_b: int + total_baseline: int + total_target: int regressions: list[DiffRow] = field(default_factory=list) improvements: list[DiffRow] = field(default_factory=list) persistent_failures: list[DiffRow] = field(default_factory=list) @@ -414,7 +414,7 @@ def failure_trend( ] @classmethod - def diff_with_details(cls, 
test_run_id_a: UUID, test_run_id_b: UUID) -> RunDiff: + def diff_with_details(cls, baseline_run_id: UUID, target_run_id: UUID) -> RunDiff: """Compare two runs by ``test_definition_id`` and return categorized diff rows.""" def _fetch(run_id: UUID) -> dict[UUID, dict]: @@ -448,41 +448,41 @@ def _fetch(run_id: UUID) -> dict[UUID, dict]: for row in get_current_session().execute(query) } - def _row(tid: UUID, info_a: dict | None, info_b: dict | None) -> DiffRow: - base = info_b or info_a # prefer B for display fields (test_type, table, column names) + def _row(tid: UUID, baseline_info: dict | None, target_info: dict | None) -> DiffRow: + base = target_info or baseline_info # prefer target for display fields (test_type, table, column names) return DiffRow( test_definition_id=tid, test_type=base["test_type"], test_name_short=base["test_name_short"], table_name=base["table_name"], column_names=base["column_names"], - status_a=info_a["status"] if info_a else None, - status_b=info_b["status"] if info_b else None, - measure_a=info_a["measure"] if info_a else None, - measure_b=info_b["measure"] if info_b else None, - threshold_a=info_a["threshold"] if info_a else None, - threshold_b=info_b["threshold"] if info_b else None, + status_baseline=baseline_info["status"] if baseline_info else None, + status_target=target_info["status"] if target_info else None, + measure_baseline=baseline_info["measure"] if baseline_info else None, + measure_target=target_info["measure"] if target_info else None, + threshold_baseline=baseline_info["threshold"] if baseline_info else None, + threshold_target=target_info["threshold"] if target_info else None, ) - results_a = _fetch(test_run_id_a) - results_b = _fetch(test_run_id_b) + baseline_results = _fetch(baseline_run_id) + target_results = _fetch(target_run_id) failing = {TestResultStatus.Failed, TestResultStatus.Warning} - diff = RunDiff(total_a=len(results_a), total_b=len(results_b)) + diff = RunDiff(total_baseline=len(baseline_results), 
total_target=len(target_results)) - for tid in results_a.keys() & results_b.keys(): - info_a, info_b = results_a[tid], results_b[tid] - row = _row(tid, info_a, info_b) - if info_a["status"] == TestResultStatus.Passed and info_b["status"] in failing: + for tid in baseline_results.keys() & target_results.keys(): + baseline_info, target_info = baseline_results[tid], target_results[tid] + row = _row(tid, baseline_info, target_info) + if baseline_info["status"] == TestResultStatus.Passed and target_info["status"] in failing: diff.regressions.append(row) - elif info_a["status"] in failing and info_b["status"] == TestResultStatus.Passed: + elif baseline_info["status"] in failing and target_info["status"] == TestResultStatus.Passed: diff.improvements.append(row) - elif info_a["status"] in failing and info_b["status"] in failing: + elif baseline_info["status"] in failing and target_info["status"] in failing: diff.persistent_failures.append(row) - for tid in results_b.keys() - results_a.keys(): - diff.new_tests.append(_row(tid, None, results_b[tid])) + for tid in target_results.keys() - baseline_results.keys(): + diff.new_tests.append(_row(tid, None, target_results[tid])) - for tid in results_a.keys() - results_b.keys(): - diff.removed_tests.append(_row(tid, results_a[tid], None)) + for tid in baseline_results.keys() - target_results.keys(): + diff.removed_tests.append(_row(tid, baseline_results[tid], None)) return diff diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index 7653f355..2bb9bc51 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -3,16 +3,16 @@ from typing import ClassVar, Literal, NamedTuple, Self, TypedDict from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, String, Text, desc, func, select, text, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm.attributes import flag_modified from 
sqlalchemy.sql.expression import case -from testgen.common.models import get_current_session +from testgen.common.enums import JobStatus +from testgen.common.models import database_session, get_current_session from testgen.common.models.connection import Connection from testgen.common.models.entity import Entity, EntityMinimal -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.project import Project from testgen.common.models.table_group import TableGroup from testgen.common.models.test_result import TestResult, TestResultStatus @@ -47,6 +47,7 @@ class TestRunMinimal(EntityMinimal): class TestRunSummary(EntityMinimal): job_execution_id: UUID test_run_id: UUID | None + job_schedule_id: UUID | None status: JobStatus created_at: datetime started_at: datetime | None @@ -159,7 +160,6 @@ def get_job_execution_ids(cls, test_run_ids: list[UUID]) -> dict[UUID, UUID | No return {row.id: row.job_execution_id for row in rows} @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, run_id: str | UUID) -> TestRunMinimal | None: if not is_uuid4(run_id): return None @@ -194,9 +194,9 @@ def get_previous(self) -> Self | None: .where( TestRun.test_suite_id == self.test_suite_id, JobExecution.status == JobStatus.COMPLETED, - JobExecution.started_at < self.test_starttime, + TestRun.test_starttime < self.test_starttime, ) - .order_by(desc(JobExecution.started_at)) + .order_by(desc(TestRun.test_starttime)) .limit(1) ) return get_current_session().scalar(query) @@ -218,6 +218,8 @@ def select_summary( table_group_id: str | None = None, test_suite_id: str | None = None, test_run_ids: list[str | UUID] | None = None, + job_execution_id: str | UUID | None = None, + statuses: list[JobStatus] | None = None, page: int = 1, page_size: int = 20, ) -> tuple[list[TestRunSummary], int]: @@ -225,9 +227,13 @@ def select_summary( (table_group_id and not is_uuid4(table_group_id)) 
or (test_suite_id and not is_uuid4(test_suite_id)) or (test_run_ids and not all(is_uuid4(run_id) for run_id in test_run_ids)) + or (job_execution_id and not is_uuid4(job_execution_id)) ): return [], 0 + # Pending JEs (no tr row) surface in project-scope queries — the LEFT JOIN to + # test_suites yields a NULL ts row that the ``ts.id IS NULL`` clause lets through — + # but not in suite/TG-scoped queries, since the WHERE filter requires ts to match. query = f""" WITH run_results AS ( SELECT test_run_id, @@ -249,6 +255,7 @@ def select_summary( SELECT je.id AS job_execution_id, tr.id AS test_run_id, + je.job_schedule_id, je.status, je.created_at, je.started_at, @@ -282,6 +289,8 @@ def select_summary( {" AND ts.table_groups_id = :table_group_id" if table_group_id else ""} {" AND ts.id = :test_suite_id" if test_suite_id else ""} {" AND tr.id IN :test_run_ids" if test_run_ids else ""} + {" AND je.id = :job_execution_id" if job_execution_id else ""} + {" AND je.status IN :statuses" if statuses else ""} ORDER BY je.created_at DESC LIMIT :limit OFFSET :offset; """ @@ -290,6 +299,8 @@ def select_summary( "table_group_id": table_group_id, "test_suite_id": test_suite_id, "test_run_ids": tuple(test_run_ids or []), + "job_execution_id": str(job_execution_id) if job_execution_id else None, + "statuses": tuple(statuses) if statuses else (), "limit": page_size, "offset": (page - 1) * page_size, } @@ -393,6 +404,80 @@ def cascade_delete(cls, ids: list[str]) -> None: db_session.execute(text(query), {"test_run_ids": tuple(ids)}) cls.delete_where(cls.id.in_(ids)) + @classmethod + def find_latest_per_test_suite(cls, project_code: str) -> set[UUID]: + """Return the latest completed test run id per test suite for the + project. + + Includes monitor suites (`is_monitor=True`). Used by data retention to + protect at least one run per scope so each suite keeps a usable + baseline when retention sweeps clear older history. Failed and + in-flight runs are skipped. 
+ """ + rows = get_current_session().scalars( + select(cls.id) + .join(TestSuite, cls.test_suite_id == TestSuite.id) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + .where( + TestSuite.project_code == project_code, + JobExecution.status == JobStatus.COMPLETED, + ) + .order_by(cls.test_suite_id, cls.test_starttime.desc()) + .distinct(cls.test_suite_id) + ).all() + return set(rows) + + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_ids: set[UUID], + batch_size: int = 1000, + dry_run: bool = False, + ) -> int: + """Batched delete of test runs (with cascading children) older than + cutoff for the given project, excluding protected ids. Returns total + parent rows deleted across all batches — or, with ``dry_run=True``, + the number that would be deleted (for retention preview, no writes). + + In-flight runs (JE in PENDING/CLAIMED/RUNNING/CANCEL_REQUESTED) are + never deleted — they may still be writing data. + + Each batch runs in its own transaction (committed before the next batch + is selected), so locks on test_runs / test_results / etc. are released + between batches and WAL growth stays bounded for large sweeps. 
+ """ + where_clauses = [ + TestSuite.project_code == project_code, + cls.test_starttime < cutoff, + JobExecution.status.in_([JobStatus.COMPLETED, JobStatus.ERROR, JobStatus.CANCELED]), + ] + if protected_ids: + where_clauses.append(cls.id.notin_(protected_ids)) + + base_select = ( + select(cls.id) + .join(TestSuite, cls.test_suite_id == TestSuite.id) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + ) + + if dry_run: + return get_current_session().scalar( + select(func.count()).select_from(base_select.where(*where_clauses).subquery()) + ) or 0 + + total = 0 + while True: + with database_session() as session: + ids = session.scalars(base_select.where(*where_clauses).limit(batch_size)).all() + if not ids: + break + cls.cascade_delete([str(i) for i in ids]) + total += len(ids) + return total + + def init_progress(self) -> None: self._progress = { "data_chars": {"label": "Refreshing data catalog"}, diff --git a/testgen/common/models/test_suite.py b/testgen/common/models/test_suite.py index bd396eb1..ac29ddce 100644 --- a/testgen/common/models/test_suite.py +++ b/testgen/common/models/test_suite.py @@ -4,14 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Boolean, Column, Enum, ForeignKey, Integer, String, asc, func, select, text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from testgen.common.models import get_current_session from testgen.common.models.custom_types import NullIfEmptyString, YNString -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.utils import is_uuid4 @@ -94,13 +93,11 @@ def get_regular(cls, identifier: str | UUID) -> "TestSuite | None": @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, identifier: int) -> TestSuiteMinimal | None: result = cls._get_columns(identifier, 
cls._minimal_columns) return TestSuiteMinimal(**result) if result else None @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TestSuiteMinimal]: diff --git a/testgen/common/models/user.py b/testgen/common/models/user.py index b4e1d575..09022ff1 100644 --- a/testgen/common/models/user.py +++ b/testgen/common/models/user.py @@ -1,9 +1,9 @@ -from datetime import UTC, datetime -from typing import Self +from datetime import UTC, datetime, timedelta +from enum import StrEnum +from typing import Any, Self from uuid import UUID, uuid4 -import streamlit as st -from sqlalchemy import Boolean, Column, String, asc, func, select, update +from sqlalchemy import Boolean, Column, String, asc, func, select, text, update from sqlalchemy.dialects import postgresql from testgen.common.models import get_current_session @@ -12,6 +12,27 @@ from testgen.common.models.project_membership import RoleType +class PreferenceKey(StrEnum): + """Keys allowed in the User.preferences JSONB column.""" + + LAST_FEEDBACK_POPUP = "last_feedback_popup" + + +# Feedback popup cadence. The popup recurs every FEEDBACK_POPUP_INTERVAL. A new user's first +# popup is delayed by FEEDBACK_POPUP_INITIAL_DELAY (rather than showing on first login, when they +# have nothing to give feedback on yet) by seeding last_feedback_popup in the past at creation. 
+FEEDBACK_POPUP_INTERVAL = timedelta(days=30) +FEEDBACK_POPUP_INITIAL_DELAY = timedelta(days=1) + + +def initial_feedback_popup_seed() -> str: + return (datetime.now(UTC) - (FEEDBACK_POPUP_INTERVAL - FEEDBACK_POPUP_INITIAL_DELAY)).isoformat() + + +def default_user_preferences() -> dict: + return {PreferenceKey.LAST_FEEDBACK_POPUP: initial_feedback_popup_seed()} + + class User(Entity): __tablename__ = "auth_users" @@ -22,6 +43,9 @@ class User(Entity): password: str = Column(String) is_global_admin: bool = Column(Boolean, nullable=False, default=False) latest_login: datetime = Column(postgresql.TIMESTAMP) + preferences: dict = Column( + postgresql.JSONB, nullable=False, default=default_user_preferences, server_default=text("'{}'") + ) _get_by = "username" _default_order_by = (asc(func.lower(username)),) @@ -41,8 +65,18 @@ def save(self, update_latest_login: bool = False) -> None: self.latest_login = datetime.now(UTC) super().save() + def get_preference(self, key: PreferenceKey, default: Any = None) -> Any: + return self.preferences.get(key, default) + + def set_preference(self, key: PreferenceKey, value: Any) -> None: + self.preferences[key] = value + self.update_preferences() + + def update_preferences(self) -> None: + query = update(User).where(User.id == self.id).values(preferences=self.preferences) + get_current_session().execute(query) + @classmethod - @st.cache_data(show_spinner=False) def get(cls, identifier: str) -> Self | None: query = select(cls).where(func.lower(User.username) == func.lower(identifier)) return get_current_session().scalars(query).first() diff --git a/testgen/common/profile_top_values.py b/testgen/common/profile_top_values.py new file mode 100644 index 00000000..ac7d9298 --- /dev/null +++ b/testgen/common/profile_top_values.py @@ -0,0 +1,51 @@ +"""Parsers for the ``top_freq_values`` and ``top_patterns`` fields written by profiling. + +Both fields are stored as delimited strings on ``profile_results``. 
This module +splits them back into structured rows; format quirks (separators, leading markers, +values containing the separator) are handled here so they only need fixing in one +place. +""" + + +def parse_top_freq_values(raw: str | None) -> list[tuple[str, int]]: + """Parse ``top_freq_values`` text into ``[(value, count), ...]``. + + Stored format: ``| value | count\\n| value | count ...`` — each row begins with + ``| ``, value and count are separated by `` | ``, rows are joined by ``\\n``. + Uses :py:meth:`str.rpartition` so values containing `` | `` parse correctly + (the count is always the rightmost segment). + """ + if not raw: + return [] + body = raw[2:] if raw.startswith("| ") else raw + rows: list[tuple[str, int]] = [] + for part in body.split("\n| "): + if " | " not in part: + continue + value, _, count = part.rpartition(" | ") + try: + rows.append((value.strip(), int(count.strip()))) + except ValueError: + continue + return rows + + +def parse_top_patterns(raw: str | None) -> list[tuple[str, int]]: + """Parse ``top_patterns`` text into ``[(pattern, count), ...]``. + + Stored format: alternating ``count | pattern | count | pattern ...`` (SQL + templates emit segments separated by `` | ``; the odd-indexed segment is the + pattern, the even-indexed is the count). 
+ """ + if not raw: + return [] + parts = [p.strip() for p in raw.split(" | ")] + rows: list[tuple[str, int]] = [] + for index in range(0, len(parts) - 1, 2): + try: + count = int(parts[index]) + except ValueError: + continue + pattern = parts[index + 1] + rows.append((pattern, count)) + return rows diff --git a/testgen/common/source_data_service.py b/testgen/common/source_data_service.py index c59d0b47..52cf2723 100644 --- a/testgen/common/source_data_service.py +++ b/testgen/common/source_data_service.py @@ -127,7 +127,11 @@ def build_hygiene_query(issue_data: dict, limit: int = DEFAULT_LIMIT) -> str | N "TABLE_NAME": issue_data["table_name"], "COLUMN_NAME": issue_data["column_name"], "DETAIL_EXPRESSION": issue_data["detail"], - "PROFILE_RUN_DATE": issue_data["profiling_starttime"], + # Date-only string: Oracle/HANA templates use TO_DATE(..., 'YYYY-MM-DD'), which rejects a time + # component, and the anomaly criteria boundary is date-based (CURRENT_DATE + INTERVAL '30 year'). + "PROFILE_RUN_DATE": parsed_run_date.strftime("%Y-%m-%d") + if (parsed_run_date := parse_fuzzy_date(issue_data["profiling_starttime"])) + else None, "LIMIT": limit, "LIMIT_2": int(limit / 2), "LIMIT_4": int(limit / 4), @@ -265,13 +269,15 @@ def _generate_recency_lookup_query( column_names_str = detail_exp[start_index:] columns = [col.strip() for col in column_names_str.split(",")] - quote = get_flavor_service(sql_flavor).quote_character + flavor_service = get_flavor_service(sql_flavor) + quote = flavor_service.quote_character + table_ref = flavor_service.get_table_ref("{TARGET_SCHEMA}", "{TABLE_NAME}") queries = [ f""" SELECT '{column}' AS column_name, MAX({quote}{column}{quote}) AS max_date_available - FROM {quote}{{TARGET_SCHEMA}}{quote}.{quote}{{TABLE_NAME}}{quote} + FROM {table_ref} """ for column in columns ] diff --git a/testgen/common/time_series_service.py b/testgen/common/time_series_service.py index 7aca697a..aeabb180 100644 --- a/testgen/common/time_series_service.py +++ 
b/testgen/common/time_series_service.py @@ -10,7 +10,7 @@ # This is a heuristic minimum to get a reasonable prediction # Not a hard limit of the model -MIN_TRAIN_VALUES = 20 +MIN_TRAIN_VALUES = 8 class NotEnoughData(ValueError): diff --git a/testgen/mcp/permissions.py b/testgen/mcp/permissions.py index dce78000..0850c753 100644 --- a/testgen/mcp/permissions.py +++ b/testgen/mcp/permissions.py @@ -39,6 +39,14 @@ def has_access(self, project_code: str) -> bool: """For filtering lists — no exception, just a bool.""" return project_code in self.allowed_codes + def has_permission(self, permission: str, project_code: str) -> bool: + """Whether the user has ``permission`` on ``project_code`` (single-check predicate). + + For per-row checks in tight loops, prefer caching the result of + :meth:`codes_allowed_to` once and using a set lookup. + """ + return project_code in self.codes_allowed_to(permission) + def verify_access(self, project_code: str, not_found: "str | MCPPermissionDenied") -> None: """Raise MCPPermissionDenied if user can't access this project. diff --git a/testgen/mcp/prompts/workflows.py b/testgen/mcp/prompts/workflows.py index 4201bf75..55d833f6 100644 --- a/testgen/mcp/prompts/workflows.py +++ b/testgen/mcp/prompts/workflows.py @@ -7,7 +7,7 @@ def health_check() -> str: Please perform a data quality health check: 1. Call `get_data_inventory()` to get a complete overview of all projects, connections, table groups, and test suites. -2. For each project, call `get_recent_test_runs(...)` to get the latest test runs across all suites. +2. For each project, call `list_test_runs(...)` to get the latest test runs across all suites. 3. Summarize the overall health: - Which projects/suites are healthy (all tests passing)? - Which have failures or warnings? @@ -29,7 +29,7 @@ def investigate_failures(test_suite: str | None = None) -> str: Please investigate test failures and identify root causes:{suite_filter} 1. 
Call `get_data_inventory()` to understand the project structure. -2. Call `get_recent_test_runs(...)` to find the latest run per suite{f" for suite `{test_suite}`" if test_suite else ""}. +2. Call `list_test_runs(...)` to find the latest run per suite{f" for suite `{test_suite}`" if test_suite else ""}. 3. Call `get_failure_summary(job_execution_id='...')` to see failures grouped by test type. 4. For each failure category, call `get_test_type(test_type='...')` to understand what the test checks. 5. Call `list_test_results(test_suite_id='...', status='Failed')` to drill into the specific failing tests in the latest run. @@ -112,7 +112,7 @@ def hygiene_triage(table_group_id: str | None = None) -> str: def compare_runs(test_suite: str | None = None) -> str: - """Compare the two most recent test runs to identify regressions and improvements. + """Compare the most recent test run against the previous run to identify regressions and improvements. Args: test_suite: Optional test suite name to focus the comparison on. @@ -120,16 +120,10 @@ def compare_runs(test_suite: str | None = None) -> str: suite_filter = f" for suite `{test_suite}`" if test_suite else "" return f"""\ -Please compare the two most recent test runs{suite_filter} to identify regressions and improvements: +Please compare the most recent test run{suite_filter} against the previous run to identify regressions and improvements: 1. Call `get_data_inventory()` to understand the project structure. -2. Call `list_test_suites(project_code='...')` to find suites{suite_filter} and their latest runs. -3. For the most recent completed run, call `list_test_results(test_suite_id='...')` to get all results. -4. For the previous run, call `list_test_results(job_execution_id='...')` to get all results. -5. Compare the two runs: - - **Regressions:** Tests that passed before but now fail. - - **Improvements:** Tests that failed before but now pass. - - **Persistent failures:** Tests that failed in both runs. 
- **Stable passes:** Tests that passed in both runs. -6. Summarize the trend and highlight any concerning regressions. +2. Call `list_test_suites(project_code='...')` to find suites{suite_filter} and their latest run IDs. +3. Call `compare_test_runs(target_job_execution_id='...')` — with only the target supplied, the tool automatically diffs against the previous completed run of the same suite. +4. Summarize the trend and highlight any concerning regressions, improvements, persistent failures, or newly added/removed tests. """ diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 1358db55..5e91b2f1 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -1,11 +1,14 @@ import logging +from urllib.parse import urlparse from mcp.server.auth.provider import AccessToken from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings from starlette.applications import Starlette +from testgen import settings from testgen.common.auth import decode_jwt_token from testgen.mcp.permissions import set_mcp_token, set_mcp_username @@ -31,7 +34,10 @@ ALWAYS look them up using either the `testgen://test-types` resource or the `get_test_type()` tool. Hygiene issue types similarly have specific meanings. ALWAYS look them up using the -`testgen://hygiene-issue-types` resource. +`testgen://hygiene-issue-types` resource. + +Column profile fields are type-specific (different stats per Alpha / Numeric / Date / Boolean / Other). +ALWAYS look them up using the `testgen://column-profile-fields` resource. INVESTIGATING FAILURES @@ -75,6 +81,43 @@ def _configure_mcp_logging() -> None: logging.getLogger(name).parent = testgen_logger +def _build_transport_security() -> TransportSecuritySettings: + """Build DNS-rebinding allowlist from BASE_URL plus operator extras and loopback.
+ + Without an explicit transport_security, FastMCP installs a loopback-only + allowlist that rejects external Host headers with 421. We pass this settings + object so production deployments accept their own externally-reachable host. + """ + parsed = urlparse(settings.BASE_URL) + base_host = parsed.hostname or "localhost" + netloc = parsed.netloc + scheme = parsed.scheme or "http" + + allowed_hosts: set[str] = { + netloc, + f"{base_host}:*", + "127.0.0.1:*", + "localhost:*", + "[::1]:*", + } + allowed_origins: set[str] = { + f"{scheme}://{netloc}", + "http://127.0.0.1:*", "https://127.0.0.1:*", + "http://localhost:*", "https://localhost:*", + "http://[::1]:*", "https://[::1]:*", + } + for host in settings.MCP_EXTRA_ALLOWED_HOSTS: + host_pattern = host if ":" in host else f"{host}:*" + allowed_hosts.add(host_pattern) + allowed_origins.update({f"http://{host_pattern}", f"https://{host_pattern}"}) + + return TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=sorted(allowed_hosts), + allowed_origins=sorted(allowed_origins), + ) + + def build_mcp_server( api_base_url: str, server_url: str | None = None, @@ -108,24 +151,75 @@ def build_mcp_server( search_hygiene_issues, update_hygiene_issue, ) - from testgen.mcp.tools.profiling import get_table, list_column_profiles, list_profiling_summaries + from testgen.mcp.tools.notifications import ( + create_notification, + delete_notification, + get_notification, + list_notifications, + update_notification, + ) + from testgen.mcp.tools.profile_history import ( + compare_profiling_runs, + get_profiling_trends, + get_schema_history, + ) + from testgen.mcp.tools.profiling import ( + get_column_frequent_values, + get_column_patterns, + get_column_profile_detail, + get_profiling_run, + get_table, + list_column_profiles, + list_profiling_runs, + list_profiling_summaries, + search_columns, + ) + from testgen.mcp.tools.quality_scores import ( + create_scorecard, + delete_scorecard, + get_quality_scores, + 
get_scorecard, + list_scorecards, + update_scorecard, + ) from testgen.mcp.tools.reference import ( + column_profile_fields_resource, get_test_type, glossary_resource, hygiene_issue_types_resource, test_types_resource, ) + from testgen.mcp.tools.schedules import ( + create_profiling_schedule, + create_test_run_schedule, + delete_schedule, + get_schedule, + list_schedules, + update_schedule, + ) from testgen.mcp.tools.source_data import get_source_data, get_source_data_query - from testgen.mcp.tools.test_definitions import get_test, list_test_notes, list_test_types, list_tests + from testgen.mcp.tools.test_definitions import ( + bulk_update_tests, + create_test, + create_test_note, + delete_test_note, + get_test, + list_test_notes, + list_test_types, + list_tests, + update_test, + update_test_note, + validate_custom_test, + ) from testgen.mcp.tools.test_results import ( + compare_test_runs, get_failure_summary, get_failure_trend, - get_test_result_history, - get_test_run_diff, + list_test_result_history, list_test_results, search_test_results, ) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import get_test_run, list_test_runs if server_url is None: server_url = f"{api_base_url}/mcp" @@ -138,6 +232,7 @@ def build_mcp_server( resource_server_url=server_url, ), token_verifier=JWTTokenVerifier(), + transport_security=_build_transport_security(), ) _configure_mcp_logging() @@ -155,13 +250,14 @@ def safe_prompt(fn): safe_tool(list_projects) safe_tool(list_tables) safe_tool(list_test_suites) - safe_tool(get_recent_test_runs) + safe_tool(list_test_runs) + safe_tool(get_test_run) safe_tool(list_test_results) - safe_tool(get_test_result_history) + safe_tool(list_test_result_history) safe_tool(get_failure_summary) safe_tool(search_test_results) safe_tool(get_failure_trend) - safe_tool(get_test_run_diff) + safe_tool(compare_test_runs) safe_tool(get_test_type) safe_tool(get_source_data) safe_tool(get_source_data_query) @@ 
-172,19 +268,53 @@ def safe_prompt(fn): safe_tool(get_table) safe_tool(list_column_profiles) safe_tool(list_profiling_summaries) + safe_tool(list_profiling_runs) + safe_tool(get_profiling_run) + safe_tool(get_column_profile_detail) + safe_tool(get_column_frequent_values) + safe_tool(get_column_patterns) + safe_tool(search_columns) + safe_tool(compare_profiling_runs) + safe_tool(get_profiling_trends) + safe_tool(get_schema_history) safe_tool(run_tests) safe_tool(run_profiling) safe_tool(cancel_test_run) safe_tool(cancel_profiling_run) safe_tool(generate_tests) + safe_tool(create_test) + safe_tool(update_test) + safe_tool(validate_custom_test) + safe_tool(bulk_update_tests) + safe_tool(create_test_note) + safe_tool(update_test_note) + safe_tool(delete_test_note) safe_tool(list_hygiene_issues) safe_tool(get_hygiene_issue) safe_tool(search_hygiene_issues) safe_tool(update_hygiene_issue) + safe_tool(create_profiling_schedule) + safe_tool(create_test_run_schedule) + safe_tool(list_schedules) + safe_tool(get_schedule) + safe_tool(update_schedule) + safe_tool(delete_schedule) + safe_tool(get_quality_scores) + safe_tool(list_scorecards) + safe_tool(get_scorecard) + safe_tool(create_scorecard) + safe_tool(update_scorecard) + safe_tool(delete_scorecard) + safe_tool(list_notifications) + safe_tool(get_notification) + safe_tool(create_notification) + safe_tool(update_notification) + safe_tool(delete_notification) # Resources safe_resource("testgen://test-types", test_types_resource) safe_resource("testgen://hygiene-issue-types", hygiene_issue_types_resource) + safe_resource("testgen://column-profile-fields", column_profile_fields_resource) safe_resource("testgen://glossary", glossary_resource) # Prompts diff --git a/testgen/mcp/services/inventory_service.py b/testgen/mcp/services/inventory_service.py index a20aef31..cee1c7a6 100644 --- a/testgen/mcp/services/inventory_service.py +++ b/testgen/mcp/services/inventory_service.py @@ -5,6 +5,7 @@ from testgen.common.models import 
get_current_session from testgen.common.models.connection import Connection from testgen.common.models.project import Project +from testgen.common.models.scores import ScoreDefinition from testgen.common.models.table_group import TableGroup, TableGroupSummary from testgen.common.models.test_suite import TestSuite from testgen.utils import friendly_score, score @@ -95,10 +96,12 @@ def get_inventory( view_codes_set = set(view_project_codes) profiling_by_tg: dict[UUID, TableGroupSummary] = {} + scorecards_by_project: dict[str, tuple[dict[str, list[tuple[str, str]]], list[tuple[str, str]]]] = {} for code in view_codes_set: summaries, _ = TableGroup.select_summary(code) for summary in summaries: profiling_by_tg[summary.id] = summary + scorecards_by_project[code] = _scorecards_by_table_group(code) # Format as Markdown lines = ["# Data Inventory\n"] @@ -125,6 +128,10 @@ def get_inventory( for group_id, group in conn["groups"].items(): summary = profiling_by_tg.get(group_id) if can_view else None + tg_scorecards: list[tuple[str, str]] = [] + if can_view: + by_tg, _ = scorecards_by_project[project_code] + tg_scorecards = by_tg.get(group["name"], []) if compact_groups or not can_view: line = ( @@ -133,6 +140,8 @@ def get_inventory( ) if summary: line += f", {_profiling_summary_fragment(summary)}" + if tg_scorecards: + line += f", scorecards: {len(tg_scorecards)}" lines.append(line) continue @@ -143,26 +152,65 @@ def get_inventory( if summary: lines.append(f"_{_profiling_summary_fragment(summary)}_\n") + if tg_scorecards: + lines.append("**Scorecards:**") + for sid, name in tg_scorecards: + lines.append(f"- **{name}** (id: `{sid}`)") + lines.append("") + if not group["suites"]: lines.append("_No test suites._\n") continue + lines.append("**Test Suites:**") for suite in group["suites"]: lines.append(f"- **{suite['name']}** (id: `{suite['id']}`)") lines.append("") lines.append("") + if can_view: + _, multi = scorecards_by_project.get(project_code, ({}, [])) + if multi: + 
lines.append("### Scorecards spanning multiple table groups\n") + for sid, name in multi: + lines.append(f"- **{name}** (id: `{sid}`)") + lines.append("") + lines.append( "---\n" "Use `list_tables(table_group_id='...')` to see tables in a group.\n" "Use `list_test_suites(project_code='...')` for suite details and latest run stats.\n" - "Use `list_profiling_summaries(table_group_id='...')` for the quality score rollup and hygiene issue counts." + "Use `list_profiling_summaries(table_group_id='...')` for the quality score rollup and hygiene issue counts.\n" + "Use `get_scorecard(scorecard_id='...')` for the score breakdown and category detail." ) return "\n".join(lines) +def _scorecards_by_table_group( + project_code: str, +) -> tuple[dict[str, list[tuple[str, str]]], list[tuple[str, str]]]: + """Index scorecards in a project by the table groups they target by name. + + Returns (by_tg_name, multi_or_none): + - by_tg_name[tg_name] = list of (scorecard_id_str, scorecard_name) for + scorecards that declare a `table_groups_name = tg_name` filter. + - multi_or_none lists scorecards whose name-filter count is not exactly 1 + (zero filters → project-wide; multiple → spans TGs by name). Such + scorecards appear under every named TG AND in this list. 
+ """ + by_tg: dict[str, list[tuple[str, str]]] = {} + multi_or_none: list[tuple[str, str]] = [] + for sc_id, sc_name, tg_names in ScoreDefinition.list_with_table_group_targets(project_code): + entry = (str(sc_id), sc_name) + for tg_name in tg_names: + by_tg.setdefault(tg_name, []).append(entry) + if len(tg_names) != 1: + multi_or_none.append(entry) + return by_tg, multi_or_none + + def _profiling_summary_fragment(summary: TableGroupSummary) -> str: """Compact one-liner of profiling metadata for a table group.""" if not summary.latest_profile_id: @@ -173,13 +221,13 @@ def _profiling_summary_fragment(summary: TableGroupSummary) -> str: + (summary.latest_hygiene_issues_likely_ct or 0) + (summary.latest_hygiene_issues_possible_ct or 0) ) - combined = friendly_score(score(summary.dq_score_profiling, summary.dq_score_testing)) + total = friendly_score(score(summary.dq_score_profiling, summary.dq_score_testing)) profiled_at = ( summary.latest_profile_start.strftime("%Y-%m-%d") if summary.latest_profile_start else "—" ) return ( - f"Score {combined}, hygiene issues {hygiene_issue_total}, " + f"Score {total}, hygiene issues {hygiene_issue_total}, " f"last profiled {profiled_at}, " f"profiling run `{summary.latest_profile_job_execution_id}`" ) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 55b7ff02..e08ca861 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -1,12 +1,32 @@ -from datetime import date +from datetime import date, datetime from enum import StrEnum from uuid import UUID +from sqlalchemy import select + from testgen.common.date_service import parse_since -from testgen.common.enums import ImpactDimension, QualityDimension -from testgen.common.models.hygiene_issue import Disposition, HygieneIssueType, IssueLikelihood, PiiRisk +from testgen.common.enums import Disposition, ImpactDimension, IssueLikelihood, JobStatus, PiiRisk, QualityDimension +from testgen.common.models import get_current_session +from 
testgen.common.models.data_column import ( + GENERAL_TYPE_TO_CODE, + ColumnOrderBy, + GeneralType, + ProfileMetric, + SuggestedDataType, +) +from testgen.common.models.hygiene_issue import HygieneIssueType +from testgen.common.models.notification_settings import ( + MonitorNotificationTrigger, + NotificationEvent, + NotificationSettings, + ProfilingRunNotificationTrigger, + TestRunNotificationTrigger, +) +from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.scheduler import SCHEDULABLE_JOB_KEYS, JobSchedule +from testgen.common.models.scores import ScoreCategory, ScoreDefinition from testgen.common.models.table_group import TableGroup -from testgen.common.models.test_definition import TestType +from testgen.common.models.test_definition import TestDefinition, TestDefinitionNote, TestType from testgen.common.models.test_result import TestResultStatus from testgen.common.models.test_suite import TestSuite from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError @@ -35,6 +55,8 @@ class DocGroup(StrEnum): INVESTIGATE = "Investigate quality issues" BROWSE_PROFILING = "Browse profiling results" TRIGGER = "Trigger profiling, tests, and test generation" + SCORING = "Track data quality scores" + MANAGE = "Manage TestGen configuration" def parse_uuid(value: str, label: str = "ID") -> UUID: @@ -85,6 +107,220 @@ def parse_quality_dimension(value: str) -> QualityDimension: raise MCPUserError(f"Invalid quality_dimension `{value}`. 
Valid values: {valid}") from err +class ScoreGroupBy(StrEnum): + """User-facing values accepted for the ``group_by`` argument on quality-score rollups.""" + + QUALITY_DIMENSION = "Quality Dimension" + IMPACT_DIMENSION = "Impact Dimension" + SEMANTIC_DATA_TYPE = "Semantic Data Type" + TABLE_GROUP = "Table Group" + DATA_LOCATION = "Data Location" + DATA_SOURCE = "Data Source" + SOURCE_SYSTEM = "Source System" + SOURCE_PROCESS = "Source Process" + BUSINESS_DOMAIN = "Business Domain" + STAKEHOLDER_GROUP = "Stakeholder Group" + TRANSFORM_LEVEL = "Transform Level" + DATA_PRODUCT = "Data Product" + + +# Translates the user-facing label to the internal DB column name used by +# ``ScoreCategory`` and the criteria filter list. +SCORE_GROUP_BY_TO_COLUMN: dict[ScoreGroupBy, str] = { + ScoreGroupBy.QUALITY_DIMENSION: "dq_dimension", + ScoreGroupBy.IMPACT_DIMENSION: "impact_dimension", + ScoreGroupBy.SEMANTIC_DATA_TYPE: "semantic_data_type", + ScoreGroupBy.TABLE_GROUP: "table_groups_name", + ScoreGroupBy.DATA_LOCATION: "data_location", + ScoreGroupBy.DATA_SOURCE: "data_source", + ScoreGroupBy.SOURCE_SYSTEM: "source_system", + ScoreGroupBy.SOURCE_PROCESS: "source_process", + ScoreGroupBy.BUSINESS_DOMAIN: "business_domain", + ScoreGroupBy.STAKEHOLDER_GROUP: "stakeholder_group", + ScoreGroupBy.TRANSFORM_LEVEL: "transform_level", + ScoreGroupBy.DATA_PRODUCT: "data_product", +} + + +class ScoreFilterField(StrEnum): + """User-facing values accepted for ``filters[].field`` on quality-score rollups. + + Same shape as ``ScoreGroupBy`` minus the two dimension values — Quality + Dimension and Impact Dimension are valid as ``group_by``, not as filter + fields. The duplication is deliberate: each argument has its own enum so + the valid-value set for each is read off one StrEnum. 
+ """ + + SEMANTIC_DATA_TYPE = "Semantic Data Type" + TABLE_GROUP = "Table Group" + DATA_LOCATION = "Data Location" + DATA_SOURCE = "Data Source" + SOURCE_SYSTEM = "Source System" + SOURCE_PROCESS = "Source Process" + BUSINESS_DOMAIN = "Business Domain" + STAKEHOLDER_GROUP = "Stakeholder Group" + TRANSFORM_LEVEL = "Transform Level" + DATA_PRODUCT = "Data Product" + + +SCORE_FILTER_FIELD_TO_COLUMN: dict[ScoreFilterField, str] = { + ScoreFilterField.SEMANTIC_DATA_TYPE: "semantic_data_type", + ScoreFilterField.TABLE_GROUP: "table_groups_name", + ScoreFilterField.DATA_LOCATION: "data_location", + ScoreFilterField.DATA_SOURCE: "data_source", + ScoreFilterField.SOURCE_SYSTEM: "source_system", + ScoreFilterField.SOURCE_PROCESS: "source_process", + ScoreFilterField.BUSINESS_DOMAIN: "business_domain", + ScoreFilterField.STAKEHOLDER_GROUP: "stakeholder_group", + ScoreFilterField.TRANSFORM_LEVEL: "transform_level", + ScoreFilterField.DATA_PRODUCT: "data_product", +} + + +class ScoreCategoryArg(StrEnum): + """User-facing values accepted for the ``category`` argument on scorecard CRUD. + + Same shape as ``ScoreGroupBy`` — every group-by value is also a valid + breakdown category. Kept as a separate enum (rather than reusing + ``ScoreGroupBy``) so each argument has its own valid-value set per the + per-arg enum convention. 
+ """ + + TABLE_GROUP = "Table Group" + DATA_LOCATION = "Data Location" + DATA_SOURCE = "Data Source" + SOURCE_SYSTEM = "Source System" + SOURCE_PROCESS = "Source Process" + BUSINESS_DOMAIN = "Business Domain" + STAKEHOLDER_GROUP = "Stakeholder Group" + TRANSFORM_LEVEL = "Transform Level" + QUALITY_DIMENSION = "Quality Dimension" + IMPACT_DIMENSION = "Impact Dimension" + DATA_PRODUCT = "Data Product" + + +SCORE_CATEGORY_ARG_TO_COLUMN: dict[ScoreCategoryArg, str] = { + ScoreCategoryArg.TABLE_GROUP: "table_groups_name", + ScoreCategoryArg.DATA_LOCATION: "data_location", + ScoreCategoryArg.DATA_SOURCE: "data_source", + ScoreCategoryArg.SOURCE_SYSTEM: "source_system", + ScoreCategoryArg.SOURCE_PROCESS: "source_process", + ScoreCategoryArg.BUSINESS_DOMAIN: "business_domain", + ScoreCategoryArg.STAKEHOLDER_GROUP: "stakeholder_group", + ScoreCategoryArg.TRANSFORM_LEVEL: "transform_level", + ScoreCategoryArg.QUALITY_DIMENSION: "dq_dimension", + ScoreCategoryArg.IMPACT_DIMENSION: "impact_dimension", + ScoreCategoryArg.DATA_PRODUCT: "data_product", +} + + +class ScoreChainLeafField(StrEnum): + """User-facing values accepted as the leaf ``field`` in a scorecard filter chain.""" + + TABLE = "Table" + COLUMN = "Column" + + +SCORE_CHAIN_LEAF_TO_COLUMN: dict[ScoreChainLeafField, str] = { + ScoreChainLeafField.TABLE: "table_name", + ScoreChainLeafField.COLUMN: "column_name", +} + + +class ScoreType(StrEnum): + """User-facing values accepted for the ``score_type`` argument.""" + + TOTAL = "Total" + CDE = "CDE" + + +def parse_score_group_by(value: str) -> ScoreGroupBy: + try: + return ScoreGroupBy(value) + except ValueError as err: + valid = ", ".join(g.value for g in ScoreGroupBy) + raise MCPUserError(f"Invalid group_by `{value}`. 
Valid values: {valid}") from err + + +def parse_score_filter_field(value: str) -> ScoreFilterField: + try: + return ScoreFilterField(value) + except ValueError as err: + if value in {ScoreGroupBy.QUALITY_DIMENSION.value, ScoreGroupBy.IMPACT_DIMENSION.value}: + raise MCPUserError( + f"`{value}` is not a valid filter field — use group_by='{value}' instead" + ) from err + valid = ", ".join(f.value for f in ScoreFilterField) + raise MCPUserError(f"Invalid filter field `{value}`. Valid values: {valid}") from err + + +def parse_score_type(value: str) -> ScoreType: + try: + return ScoreType(value) + except ValueError as err: + valid = ", ".join(s.value for s in ScoreType) + raise MCPUserError(f"Invalid score_type `{value}`. Valid values: {valid}") from err + + +def parse_category(value: str) -> ScoreCategory: + """Validate a ``category`` argument and return the stored ``ScoreCategory``. + + Accepts the display-form values exposed by ``get_quality_scores``'s + ``group_by`` argument (e.g. ``Quality Dimension``, ``Data Source``). + """ + try: + arg = ScoreCategoryArg(value) + except ValueError as err: + valid = ", ".join(c.value for c in ScoreCategoryArg) + raise MCPUserError(f"Invalid category `{value}`. Valid values: {valid}") from err + return ScoreCategory(SCORE_CATEGORY_ARG_TO_COLUMN[arg]) + + +# Maps user-facing run-status labels to underlying ``JobStatus`` values. Transient states +# (Starting/Canceling) are excluded because they're sub-second and noisy as filters. +# ``Pending`` collapses PENDING+CLAIMED; ``Canceled`` collapses CANCEL_REQUESTED+CANCELED. +_RUN_STATUS_FILTER: dict[str, list[JobStatus]] = { + "Pending": [JobStatus.PENDING, JobStatus.CLAIMED], + "Running": [JobStatus.RUNNING], + "Completed": [JobStatus.COMPLETED], + "Canceled": [JobStatus.CANCEL_REQUESTED, JobStatus.CANCELED], + "Error": [JobStatus.ERROR], +} + + +def parse_run_status_filter(value: str) -> list[JobStatus]: + """Map a user-facing run status label (e.g. 
``Pending``) to the underlying ``JobStatus`` values.""" + statuses = _RUN_STATUS_FILTER.get(value) + if statuses is None: + valid = ", ".join(_RUN_STATUS_FILTER.keys()) + raise MCPUserError(f"Invalid status `{value}`. Valid values: {valid}") + return statuses + + +def format_run_duration(started_at: datetime | None, completed_at: datetime | None) -> str | None: + """Render an elapsed duration as ``Xs`` / ``Xm Ys`` / ``Xh Ym``. Returns ``None`` if either bound is missing.""" + if not started_at or not completed_at: + return None + seconds = int((completed_at - started_at).total_seconds()) + if seconds < 60: + return f"{seconds}s" + if seconds < 3600: + return f"{seconds // 60}m {seconds % 60}s" + return f"{seconds // 3600}h {(seconds % 3600) // 60}m" + + +def next_scheduled_run( + job_key: str, kwargs_filter: dict[str, str | list[str]], project_code: str, +) -> datetime | None: + """Return the next firing of an active ``JobSchedule`` matching ``job_key`` and a kwargs + filter. When multiple schedules match, the soonest next-firing wins. + """ + schedules = JobSchedule.select_active_by_kwargs(project_code, job_key, kwargs_filter) + if not schedules: + return None + return min(s.get_sample_triggering_timestamps(1)[0] for s in schedules) + + def parse_disposition(value: str) -> Disposition: """Validate a user-facing disposition label and return the stored ``Disposition``. @@ -125,6 +361,100 @@ def parse_issue_likelihood_list(values: list[str]) -> list[IssueLikelihood]: return parsed +# Maps the user-facing display label to the stored ``pii_flag`` middle segment +# (``A//``). Mirrors ``_PII_TYPE_MAP`` in ``profiling.py``. +_PII_CATEGORY_TO_CODE: dict[str, str] = { + "ID": "ID", + "Name": "NAME", + "Demographic": "DEMO", + "Contact": "CONTACT", +} + + +def build_ilike_pattern(raw: str) -> str: + """Prepare a free-text input for an ``ILIKE`` clause. 
+ + Escapes literal underscores (which column names commonly contain) so they + match as themselves rather than as the SQL single-character wildcard. When + the input contains an explicit ``%``, honor it as the caller's wildcard; + otherwise wrap the input with ``%...%`` for substring match. + + Pair with ``column.ilike(pattern, escape="\\\\")`` at the call site. + """ + escaped = raw.replace("_", r"\_") + return escaped if "%" in escaped else f"%{escaped}%" + + +def parse_pii_category(value: str) -> str: + """Validate a pii_category value and return the stored ``pii_flag`` middle segment.""" + code = _PII_CATEGORY_TO_CODE.get(value) + if code is None: + valid = ", ".join(_PII_CATEGORY_TO_CODE) + raise MCPUserError(f"Invalid pii_category `{value}`. Valid values: {valid}") + return code + + +def parse_general_type(value: str) -> str: + """Validate a user-facing ``general_type`` word and return the stored single-letter code. + + Accepts ``Alpha`` / ``Numeric`` / ``Datetime`` / ``Boolean`` / ``Time`` / ``Other``; + returns ``A`` / ``N`` / ``D`` / ``B`` / ``T`` / ``X`` respectively (the values stored + on ``data_column_chars.general_type``). + """ + try: + member = GeneralType(value) + except ValueError as err: + valid = ", ".join(t.value for t in GeneralType) + raise MCPUserError(f"Invalid general_type `{value}`. Valid values: {valid}") from err + return GENERAL_TYPE_TO_CODE[member] + + +def parse_suggested_data_type(value: str) -> SuggestedDataType: + try: + return SuggestedDataType(value) + except ValueError as err: + valid = ", ".join(t.value for t in SuggestedDataType) + raise MCPUserError(f"Invalid suggested_data_type `{value}`. Valid values: {valid}") from err + + +def parse_column_order_by(value: str) -> ColumnOrderBy: + try: + return ColumnOrderBy(value) + except ValueError as err: + valid = ", ".join(o.value for o in ColumnOrderBy) + raise MCPUserError(f"Invalid order_by `{value}`. 
Valid values: {valid}") from err + + +def parse_profile_metrics(values: list[str]) -> list[ProfileMetric]: + """Validate a list of profile metric names. Empties out with one error listing all invalids.""" + if not values: + raise MCPUserError("`metrics` cannot be empty — name at least one metric to trend.") + parsed: list[ProfileMetric] = [] + invalid: list[str] = [] + for value in values: + try: + parsed.append(ProfileMetric(value)) + except ValueError: + invalid.append(value) + if invalid: + valid = ", ".join(m.value for m in ProfileMetric) + raise MCPUserError(f"Invalid metrics {invalid}. Valid values: {valid}") + return parsed + + +# ``pii_flag`` encodes risk as a single-character prefix: ``A`` (High), ``B`` (Moderate), ``C`` (Low). +_PII_RISK_LEVEL_TO_CODE: dict[str, str] = {"High": "A", "Moderate": "B", "Low": "C"} + + +def parse_pii_risk_level(value: str) -> str: + """Validate a column-profile pii_risk_level filter and return the stored prefix code.""" + code = _PII_RISK_LEVEL_TO_CODE.get(value) + if code is None: + valid = ", ".join(_PII_RISK_LEVEL_TO_CODE) + raise MCPUserError(f"Invalid pii_risk_level `{value}`. Valid values: {valid}") + return code + + def parse_pii_risk_list(values: list[str]) -> list[PiiRisk]: parsed: list[PiiRisk] = [] invalid: list[str] = [] @@ -203,3 +533,194 @@ def resolve_test_suite(test_suite_id: str) -> TestSuite: if suite is None: raise MCPResourceNotAccessible("Test suite", test_suite_id) return suite + + +def resolve_profiling_run(job_execution_id: str) -> ProfilingRun: + """Resolve a profiling run by id-or-JE-id, scoped to allowed projects. + + Collapses missing-or-inaccessible into a single ``MCPResourceNotAccessible`` + so callers don't leak existence of runs they shouldn't see. 
+ """ + run_uuid = parse_uuid(job_execution_id, "job_execution_id") + run = ProfilingRun.get_by_id_or_job(run_uuid) + perms = get_project_permissions() + if run is None or not perms.has_access(run.project_code): + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + return run + + +def resolve_scorecard(scorecard_id: str) -> ScoreDefinition: + """Resolve a scorecard ID, collapsing missing-or-inaccessible into one error path.""" + parse_uuid(scorecard_id, "scorecard_id") + perms = get_project_permissions() + sd = ScoreDefinition.get(scorecard_id) + if sd is None or not perms.has_access(sd.project_code): + raise MCPResourceNotAccessible("Scorecard", scorecard_id) + return sd + + +def resolve_test_definition(test_definition_id: str) -> TestDefinition: + """Resolve a test definition ID to the live ORM model, collapsing missing-or-inaccessible. + + Filters monitor suites and project access. Returns the ORM ``TestDefinition`` + (not ``TestDefinitionSummary``) so the row can be mutated and saved. + """ + td_uuid = parse_uuid(test_definition_id, "test_definition_id") + perms = get_project_permissions() + query = ( + select(TestDefinition) + .join(TestSuite, TestDefinition.test_suite_id == TestSuite.id) + .where( + TestDefinition.id == td_uuid, + TestSuite.is_monitor.isnot(True), + TestSuite.project_code.in_(perms.allowed_codes), + ) + ) + td = get_current_session().scalars(query).first() + if td is None: + raise MCPResourceNotAccessible("Test definition", test_definition_id) + return td + + +def resolve_test_note(test_note_id: str) -> TestDefinitionNote: + """Resolve a test note ID to the live ORM model, collapsing missing-or-inaccessible. + + Filters monitor suites and project access via the note's parent test definition. 
+ """ + note_uuid = parse_uuid(test_note_id, "test_note_id") + perms = get_project_permissions() + query = ( + select(TestDefinitionNote) + .join(TestDefinition, TestDefinitionNote.test_definition_id == TestDefinition.id) + .join(TestSuite, TestDefinition.test_suite_id == TestSuite.id) + .where( + TestDefinitionNote.id == note_uuid, + TestSuite.is_monitor.isnot(True), + TestSuite.project_code.in_(perms.allowed_codes), + ) + ) + note = get_current_session().scalars(query).first() + if note is None: + raise MCPResourceNotAccessible("Test note", test_note_id) + return note + + +def resolve_schedule(schedule_id: str) -> JobSchedule: + """Resolve a user-managed schedule ID, collapsing missing-or-inaccessible into one error path.""" + sched_uuid = parse_uuid(schedule_id, "schedule_id") + perms = get_project_permissions() + sched = JobSchedule.get( + JobSchedule.id == sched_uuid, + JobSchedule.key.in_(SCHEDULABLE_JOB_KEYS), + JobSchedule.project_code.in_(perms.allowed_codes), + ) + if sched is None: + raise MCPResourceNotAccessible("Schedule", schedule_id) + return sched + + +def resolve_notification(notification_id: str) -> NotificationSettings: + """Resolve a notification ID, collapsing missing-or-inaccessible into one error path. + + Returns the polymorphic ``NotificationSettings`` subclass (TestRun / ProfilingRun / + ScoreDrop / Monitor) so callers can read event-specific typed properties. + """ + notif_uuid = parse_uuid(notification_id, "notification_id") + perms = get_project_permissions() + notif = NotificationSettings.get( + notif_uuid, + NotificationSettings.project_code.in_(perms.allowed_codes), + ) + if notif is None: + raise MCPResourceNotAccessible("Notification", notification_id) + return notif + + +# Notification event-type labels. 
+ +class NotificationEventLabel(StrEnum): + """User-facing values for notification event types.""" + + TEST_RUN = "Test Run" + PROFILING_RUN = "Profiling Run" + SCORE_DROP = "Score Drop" + MONITOR_RUN = "Monitor Alert" + + +NOTIFICATION_EVENT_LABEL_TO_INTERNAL: dict[NotificationEventLabel, NotificationEvent] = { + NotificationEventLabel.TEST_RUN: NotificationEvent.test_run, + NotificationEventLabel.PROFILING_RUN: NotificationEvent.profiling_run, + NotificationEventLabel.SCORE_DROP: NotificationEvent.score_drop, + NotificationEventLabel.MONITOR_RUN: NotificationEvent.monitor_run, +} + +_NOTIFICATION_EVENT_INTERNAL_TO_LABEL: dict[NotificationEvent, NotificationEventLabel] = { + v: k for k, v in NOTIFICATION_EVENT_LABEL_TO_INTERNAL.items() +} + + +def format_notification_event(event: NotificationEvent | str) -> str: + """Map a stored notification event to its user-facing label.""" + return _NOTIFICATION_EVENT_INTERNAL_TO_LABEL[NotificationEvent(event)].value + + +# Notification trigger labels — one StrEnum per event type. Same wording the end user sees in the UI: +# ``ui/views/test_runs.py:249-254``, ``ui/views/profiling_runs.py:265-268``, +# ``ui/views/monitors_dashboard.py:323-326``. 
+ +class TestRunTriggerLabel(StrEnum): + ALWAYS = "Always" + ON_FAILURES = "On test failures" + ON_WARNINGS = "On test failures and warnings" + ON_CHANGES = "On new test failures and warnings" + + +TEST_RUN_TRIGGER_LABEL_TO_INTERNAL: dict[TestRunTriggerLabel, TestRunNotificationTrigger] = { + TestRunTriggerLabel.ALWAYS: TestRunNotificationTrigger.always, + TestRunTriggerLabel.ON_FAILURES: TestRunNotificationTrigger.on_failures, + TestRunTriggerLabel.ON_WARNINGS: TestRunNotificationTrigger.on_warnings, + TestRunTriggerLabel.ON_CHANGES: TestRunNotificationTrigger.on_changes, +} + + +class ProfilingRunTriggerLabel(StrEnum): + ALWAYS = "Always" + ON_CHANGES = "On new hygiene issues" + + +PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL: dict[ProfilingRunTriggerLabel, ProfilingRunNotificationTrigger] = { + ProfilingRunTriggerLabel.ALWAYS: ProfilingRunNotificationTrigger.always, + ProfilingRunTriggerLabel.ON_CHANGES: ProfilingRunNotificationTrigger.on_changes, +} + + +class MonitorTriggerLabel(StrEnum): + ON_ANOMALIES = "On anomalies" + + +MONITOR_TRIGGER_LABEL_TO_INTERNAL: dict[MonitorTriggerLabel, MonitorNotificationTrigger] = { + MonitorTriggerLabel.ON_ANOMALIES: MonitorNotificationTrigger.on_anomalies, +} + +_TEST_RUN_TRIGGER_INTERNAL_TO_LABEL = {v: k for k, v in TEST_RUN_TRIGGER_LABEL_TO_INTERNAL.items()} +_PROFILING_RUN_TRIGGER_INTERNAL_TO_LABEL = {v: k for k, v in PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL.items()} +_MONITOR_TRIGGER_INTERNAL_TO_LABEL = {v: k for k, v in MONITOR_TRIGGER_LABEL_TO_INTERNAL.items()} + + +def format_notification_trigger(event: NotificationEvent | str, settings: dict | None) -> str | None: + """Map a notification's stored trigger value to its user-facing label. + + Returns ``None`` for ``score_drop`` (no trigger — thresholds drive it) or when + ``settings`` carries no ``trigger`` key. 
+ """ + raw = settings.get("trigger") if settings else None + if raw is None: + return None + event_enum = NotificationEvent(event) + if event_enum is NotificationEvent.test_run: + return _TEST_RUN_TRIGGER_INTERNAL_TO_LABEL[TestRunNotificationTrigger(raw)].value + if event_enum is NotificationEvent.profiling_run: + return _PROFILING_RUN_TRIGGER_INTERNAL_TO_LABEL[ProfilingRunNotificationTrigger(raw)].value + if event_enum is NotificationEvent.monitor_run: + return _MONITOR_TRIGGER_INTERNAL_TO_LABEL[MonitorNotificationTrigger(raw)].value + return None diff --git a/testgen/mcp/tools/execution.py b/testgen/mcp/tools/execution.py index b7313535..3ed02e0e 100644 --- a/testgen/mcp/tools/execution.py +++ b/testgen/mcp/tools/execution.py @@ -2,7 +2,7 @@ from sqlalchemy import select -from testgen.api.schemas import JobKey, JobSource +from testgen.common.enums import JobKey, JobSource from testgen.common.models import get_current_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError @@ -17,7 +17,7 @@ @mcp_permission("edit") def run_tests(test_suite_id: str) -> str: """Submit a test run for a test suite. Returns immediately with a job_execution_id; - use ``get_recent_test_runs`` to track status. + use ``list_test_runs`` to track status. Args: test_suite_id: UUID of the test suite to run, e.g. from ``list_test_suites``. @@ -29,7 +29,7 @@ def run_tests(test_suite_id: str) -> str: source=JobSource.mcp, project_code=suite.project_code, ) - return _render_submission("Test run", suite.test_suite, "Test suite", job, "get_recent_test_runs") + return _render_submission("Test run", suite.test_suite, "Test suite", job, "list_test_runs") @with_database_session @@ -86,10 +86,10 @@ def cancel_test_run(job_execution_id: str) -> str: """Request cancellation of a queued or running test run. Args: - job_execution_id: UUID of a test run, e.g. from ``get_recent_test_runs``. 
+ job_execution_id: UUID of a test run, e.g. from ``list_test_runs``. """ job = _resolve_job_execution(job_execution_id, JobKey.run_tests, "Test run") - return _render_cancel(job, "Test run", "get_recent_test_runs") + return _render_cancel(job, "Test run", "list_test_runs") @with_database_session @@ -106,7 +106,8 @@ def cancel_profiling_run(job_execution_id: str) -> str: def _resolve_job_execution(job_execution_id: str, expected_job_key: JobKey, kind: str) -> JobExecution: """Resolve a user-submitted job by ID + expected job_key, collapsing missing-or-inaccessible - into one error path. Filters out source='system' jobs (internal rollups, never user-cancelable). + into one error path. Each MCP tool pins ``expected_job_key`` to a public kind, so the + job_key match alone restricts the lookup to externally-visible jobs. """ job_uuid = parse_uuid(job_execution_id, "job_execution_id") perms = get_project_permissions() @@ -114,7 +115,6 @@ def _resolve_job_execution(job_execution_id: str, expected_job_key: JobKey, kind select(JobExecution).where( JobExecution.id == job_uuid, JobExecution.job_key == expected_job_key, - JobExecution.source != "system", JobExecution.project_code.in_(perms.allowed_codes), ) ).first() diff --git a/testgen/mcp/tools/hygiene_issues.py b/testgen/mcp/tools/hygiene_issues.py index 1a834a4b..19042588 100644 --- a/testgen/mcp/tools/hygiene_issues.py +++ b/testgen/mcp/tools/hygiene_issues.py @@ -4,8 +4,9 @@ from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.functions import func +from testgen.common.enums import Disposition, IssueLikelihood, PiiRisk from testgen.common.models import with_database_session -from testgen.common.models.hygiene_issue import Disposition, HygieneIssue, HygieneIssueType, IssueLikelihood, PiiRisk +from testgen.common.models.hygiene_issue import HygieneIssue, HygieneIssueType from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun from 
testgen.common.models.table_group import TableGroup diff --git a/testgen/mcp/tools/markdown.py b/testgen/mcp/tools/markdown.py index ceac0ded..23e3c6d7 100644 --- a/testgen/mcp/tools/markdown.py +++ b/testgen/mcp/tools/markdown.py @@ -44,7 +44,6 @@ def _format_dt(value: object) -> str | None: return value[:16].replace("T", " ") + " UTC" return None - def _format_part(value: object) -> str: """Format a single value for text() parts — datetime-aware, no escaping.""" if value is None: @@ -52,6 +51,11 @@ def _format_part(value: object) -> str: return dt_str if (dt_str := _format_dt(value)) else str(value) +def _format_boolean(value: object) -> str | None: + if isinstance(value, bool): + return "Yes" if value else "No" + return None + # --------------------------------------------------------------------------- # MdDoc # --------------------------------------------------------------------------- @@ -204,6 +208,8 @@ def _format_field_value(value: object, *, code: bool = False) -> str: return "\u2014" if dt_str := _format_dt(value): return MdDoc.code(dt_str) if code else dt_str + if bool_str := _format_boolean(value): + return MdDoc.code(bool_str) if code else bool_str s = str(value) return MdDoc.code(s) if code else s diff --git a/testgen/mcp/tools/notifications.py b/testgen/mcp/tools/notifications.py new file mode 100644 index 00000000..5b798ca6 --- /dev/null +++ b/testgen/mcp/tools/notifications.py @@ -0,0 +1,995 @@ +from dataclasses import dataclass +from decimal import Decimal +from enum import StrEnum +from uuid import UUID + +from testgen.common.models import with_database_session +from testgen.common.models.notification_settings import ( + MonitorNotificationTrigger, + NotificationEvent, + NotificationSettings, + NotificationSummary, + ProfilingRunNotificationSettings, + ProfilingRunNotificationTrigger, + ScoreDropNotificationSettings, + TestRunNotificationSettings, + TestRunNotificationTrigger, + is_valid_email, +) +from testgen.common.models.scores import 
@with_database_session
@mcp_permission("view")
def list_notifications(
    project_code: str | None = None,
    test_suite_id: str | None = None,
    table_group_id: str | None = None,
    scorecard_id: str | None = None,
    limit: int = 50,
    page: int = 1,
) -> str:
    """List notifications configured across projects, or scoped to a parent entity.

    With no scope argument, returns notifications across every project the caller can view.
    Provide one of ``project_code`` / ``test_suite_id`` / ``table_group_id`` / ``scorecard_id``
    to narrow the listing. Parent-entity scopes filter strictly on that entity — to also
    see project-wide notifications (those not bound to a specific suite, table group, or
    scorecard), use ``project_code``.

    Args:
        project_code: Scope to a specific project.
        test_suite_id: UUID of a test suite, e.g. from ``list_test_suites``. Returns only
            notifications bound to this suite.
        table_group_id: UUID of a table group, e.g. from ``get_data_inventory``. Returns
            only notifications bound to this table group.
        scorecard_id: UUID of a scorecard, e.g. from ``list_scorecards``. Returns only
            notifications bound to this scorecard.
        limit: Maximum number of notifications per page (default 50, max 200).
        page: Page number, starting from 1 (default 1).
    """
    validate_page(page)
    validate_limit(limit, 200)

    # The scope arguments are mutually exclusive; falsy values (None, "") count as "not supplied".
    supplied_scopes = [
        arg_name
        for arg_name, arg_value in (
            ("project_code", project_code),
            ("test_suite_id", test_suite_id),
            ("table_group_id", table_group_id),
            ("scorecard_id", scorecard_id),
        )
        if arg_value
    ]
    if len(supplied_scopes) > 1:
        raise MCPUserError(
            "Pass at most one of `project_code`, `test_suite_id`, `table_group_id`, `scorecard_id`."
        )

    perms = get_project_permissions()
    paging = {"page": page, "limit": limit}

    if test_suite_id:
        suite = resolve_test_suite(test_suite_id)
        rows, total = NotificationSettings.list_for_test_suite(suite.id, **paging)
        scope_label = f"Test Suite `{suite.test_suite}`"
    elif table_group_id:
        tg = resolve_table_group(table_group_id)
        rows, total = NotificationSettings.list_for_table_group(tg.id, **paging)
        scope_label = f"Table Group `{tg.table_groups_name}`"
    elif scorecard_id:
        scorecard = resolve_scorecard(scorecard_id)
        rows, total = NotificationSettings.list_for_score_definition(scorecard.id, **paging)
        scope_label = f"Scorecard `{scorecard.name}`"
    elif project_code:
        # Only the explicit project scope is access-checked here; the entity scopes above
        # go through their resolve_* helpers instead.
        perms.verify_access(project_code, not_found=MCPResourceNotAccessible("Project", project_code))
        rows, total = NotificationSettings.list_for_projects([project_code], **paging)
        scope_label = f"Project `{project_code}`"
    else:
        # Unscoped: list across every project code the permissions object allows.
        rows, total = NotificationSettings.list_for_projects(perms.allowed_codes, **paging)
        scope_label = None

    return _render(rows, total, page=page, limit=limit, scope_label=scope_label)
@with_database_session
@mcp_permission("edit")
def create_notification(
    event_type: str,
    recipients: list[str],
    test_suite_id: str | None = None,
    table_group_id: str | None = None,
    scorecard_id: str | None = None,
    trigger_on: str | None = None,
    total_threshold: float | None = None,
    cde_threshold: float | None = None,
) -> str:
    """Create an email notification for a test-run, profiling-run, or score-drop event.

    Every invalid input is surfaced in a single error so the call can be corrected
    in one round-trip — no partial save occurs.

    Args:
        event_type: The event that triggers the notification. One of
            ``Test Run``, ``Profiling Run``, ``Score Drop``. ``Monitor Alert``
            notifications are configured in the TestGen UI and cannot be created
            here; ``update_notification`` can still modify them once they exist.
        recipients: One or more well-formed email addresses to notify.
        test_suite_id: UUID of the test suite, e.g. from ``list_test_suites``.
            Required when ``event_type`` is ``Test Run``; rejected otherwise.
        table_group_id: UUID of the table group, e.g. from ``get_data_inventory``.
            Required when ``event_type`` is ``Profiling Run``; rejected otherwise.
        scorecard_id: UUID of the scorecard, e.g. from ``list_scorecards``.
            Required when ``event_type`` is ``Score Drop``; rejected otherwise.
        trigger_on: When to fire the notification. Only used for ``Test Run``
            and ``Profiling Run``; rejected for ``Score Drop``.
            For ``Test Run`` (default ``On test failures``): one of ``Always``,
            ``On test failures``, ``On test failures and warnings``,
            ``On new test failures and warnings``.
            For ``Profiling Run`` (default ``On new hygiene issues``): one of
            ``Always``, ``On new hygiene issues``.
        total_threshold: Score-drop trigger for the total score (over 0, up to 100).
            Only used for ``Score Drop``; at least one of ``total_threshold`` or
            ``cde_threshold`` must be supplied.
        cde_threshold: Score-drop trigger for the critical-data-element score
            (over 0, up to 100). Only used for ``Score Drop``.
    """
    event = _parse_event_type(event_type)

    # Each branch validates in a fixed order: scope shape, stray arguments, parent
    # lookup, then recipients / trigger / thresholds. Nothing is created until all pass.
    if event is NotificationEvent.test_run:
        _enforce_scope_shape(
            event_type,
            required=("test_suite_id", test_suite_id),
            forbidden=(("table_group_id", table_group_id), ("scorecard_id", scorecard_id)),
        )
        _reject_threshold_args(event_type, total_threshold, cde_threshold)
        suite = resolve_test_suite(test_suite_id)
        addresses = _validate_recipients(recipients)
        suite_trigger = _parse_test_run_trigger(trigger_on)
        return _render_created(
            TestRunNotificationSettings.create(
                project_code=suite.project_code,
                test_suite_id=suite.id,
                recipients=addresses,
                trigger=suite_trigger,
            )
        )

    if event is NotificationEvent.profiling_run:
        _enforce_scope_shape(
            event_type,
            required=("table_group_id", table_group_id),
            forbidden=(("test_suite_id", test_suite_id), ("scorecard_id", scorecard_id)),
        )
        _reject_threshold_args(event_type, total_threshold, cde_threshold)
        tg = resolve_table_group(table_group_id)
        addresses = _validate_recipients(recipients)
        profiling_trigger = _parse_profiling_run_trigger(trigger_on)
        return _render_created(
            ProfilingRunNotificationSettings.create(
                project_code=tg.project_code,
                table_group_id=tg.id,
                recipients=addresses,
                trigger=profiling_trigger,
            )
        )

    # Remaining case is NotificationEvent.score_drop — _parse_event_type rejected anything else.
    _enforce_scope_shape(
        event_type,
        required=("scorecard_id", scorecard_id),
        forbidden=(("test_suite_id", test_suite_id), ("table_group_id", table_group_id)),
    )
    if trigger_on is not None:
        raise MCPUserError(
            f"`trigger_on` is not supported for event type `{event_type}` — thresholds drive the event."
        )
    scorecard = resolve_scorecard(scorecard_id)
    _validate_score_thresholds(total_threshold, cde_threshold)
    addresses = _validate_recipients(recipients)
    return _render_created(
        ScoreDropNotificationSettings.create(
            project_code=scorecard.project_code,
            score_definition_id=scorecard.id,
            recipients=addresses,
            total_score_threshold=total_threshold,
            cde_score_threshold=cde_threshold,
        )
    )
Use only `{required_name}`.") + + +def _reject_threshold_args( + event_type: str, + total_threshold: float | None, + cde_threshold: float | None, +) -> None: + """Reject ``total_threshold`` / ``cde_threshold`` on non-score events.""" + stray = [ + name + for name, value in ( + ("total_threshold", total_threshold), + ("cde_threshold", cde_threshold), + ) + if value is not None + ] + if stray: + joined = ", ".join(f"`{name}`" for name in stray) + raise MCPUserError( + f"{joined} not supported for event type `{event_type}`. " + "Only `Score Drop` notifications use score thresholds." + ) + + +def _validate_recipients(recipients: list[str]) -> list[str]: + """Return the recipients list after batch-validating every entry. + + Raises ``MCPUserError`` if the list is empty or contains any malformed address — + every bad address is named in the single error message so the caller can fix + them all in one round-trip. + """ + if not recipients: + raise MCPUserError("`recipients` must contain at least one email address.") + invalid = [addr for addr in recipients if not is_valid_email(addr)] + if invalid: + joined = ", ".join(f"`{addr}`" for addr in invalid) + raise MCPUserError(f"Invalid email addresses: {joined}.") + return list(recipients) + + +def _parse_test_run_trigger(value: str | None) -> TestRunNotificationTrigger: + if value is None: + return TestRunNotificationTrigger.on_failures + try: + label = TestRunTriggerLabel(value) + except ValueError as err: + valid = ", ".join(f"`{label.value}`" for label in TestRunTriggerLabel) + raise MCPUserError( + f"Invalid `trigger_on` `{value}` for event type `Test Run`. Valid values: {valid}." 
def _parse_profiling_run_trigger(value: str | None) -> ProfilingRunNotificationTrigger:
    """Translate a ``Profiling Run`` trigger display label into its internal trigger value.

    ``None`` selects the default, ``ProfilingRunNotificationTrigger.on_changes``. A label
    that is not a ``ProfilingRunTriggerLabel`` raises ``MCPUserError`` listing every valid label.
    """
    if value is None:
        return ProfilingRunNotificationTrigger.on_changes
    try:
        parsed = ProfilingRunTriggerLabel(value)
    except ValueError as err:
        choices = ", ".join(f"`{option.value}`" for option in ProfilingRunTriggerLabel)
        raise MCPUserError(
            f"Invalid `trigger_on` `{value}` for event type `Profiling Run`. Valid values: {choices}."
        ) from err
    else:
        return PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL[parsed]
@with_database_session
@mcp_permission("edit")
def update_notification(
    notification_id: str,
    *,
    enabled: bool | None = None,
    recipients: list[str] | None = None,
    trigger_on: str | None = None,
    total_threshold: float | None = None,
    cde_threshold: float | None = None,
    clear_total_threshold: bool = False,
    clear_cde_threshold: bool = False,
    table_name: str | None = None,
    clear_table_name: bool = False,
) -> str:
    """Update fields on an existing email notification. Pass only the fields to change.

    Works on any notification, including ``Monitor Alert`` notifications — those are
    created through monitor setup rather than this tool, but can be updated here.

    Every invalid input surfaces in a single error before any save — no partial save.
    The notification's event type and scope entity are immutable through this tool;
    delete and recreate to change them. (A Monitor Alert's optional table — a finer
    scope within its table group — can still be set or cleared here.)

    Args:
        notification_id: UUID of the notification, e.g. from ``list_notifications``.
        enabled: ``True`` to resume, ``False`` to pause. Omit to leave unchanged.
        recipients: Replace the recipient list with the supplied addresses (one or more
            well-formed emails). Omit to leave unchanged.
        trigger_on: New trigger condition. Only valid for ``Test Run``, ``Profiling Run``,
            and ``Monitor Alert`` notifications; rejected for ``Score Drop``.
            For ``Test Run``: one of ``Always``, ``On test failures``,
            ``On test failures and warnings``, ``On new test failures and warnings``.
            For ``Profiling Run``: one of ``Always``, ``On new hygiene issues``.
            For ``Monitor Alert``: ``On anomalies`` is the only supported value, so
            this field cannot meaningfully be changed on Monitor Alert notifications.
        total_threshold: New total score threshold (over 0, up to 100). Only valid for
            ``Score Drop`` notifications.
        cde_threshold: New critical-data-element score threshold (over 0, up to 100). Only valid
            for ``Score Drop`` notifications.
        clear_total_threshold: ``True`` to clear the overall-score threshold (set to
            NULL). At least one threshold must remain set after the call.
        clear_cde_threshold: ``True`` to clear the CDE-score threshold. At least one
            threshold must remain set after the call.
        table_name: Narrow a Monitor Alert notification's scope to a single table within
            its table group. Only valid for ``Monitor Alert`` notifications.
        clear_table_name: ``True`` to drop an existing table from a Monitor Alert
            notification (notifications then fire for any table in the table group).
    """
    # Reject a call that names no change at all before touching the database.
    value_args = (enabled, recipients, trigger_on, total_threshold, cde_threshold, table_name)
    clear_flags = (clear_total_threshold, clear_cde_threshold, clear_table_name)
    if all(arg is None for arg in value_args) and not any(clear_flags):
        raise MCPUserError("No fields supplied to update.")

    notif = resolve_notification(notification_id)
    event = notif.event
    event_label = format_notification_event(event)

    # Bundles of keyword arguments that several helpers below share.
    threshold_kwargs = {
        "total_threshold": total_threshold,
        "cde_threshold": cde_threshold,
        "clear_total_threshold": clear_total_threshold,
        "clear_cde_threshold": clear_cde_threshold,
    }
    table_kwargs = {"table_name": table_name, "clear_table_name": clear_table_name}

    _reject_event_stray_args(event, event_label, trigger_on=trigger_on, **threshold_kwargs, **table_kwargs)
    _reject_set_and_clear_conflicts(**threshold_kwargs, **table_kwargs)

    clean_recipients: list[str] | None = None if recipients is None else _validate_recipients(recipients)

    # Pick the trigger parser that matches the stored event; other events leave it None.
    parsed_trigger = None
    if trigger_on is not None:
        if event is NotificationEvent.test_run:
            parsed_trigger = _parse_test_run_trigger(trigger_on)
        elif event is NotificationEvent.profiling_run:
            parsed_trigger = _parse_profiling_run_trigger(trigger_on)
        elif event is NotificationEvent.monitor_run:
            parsed_trigger = _parse_monitor_trigger(trigger_on)

    if event is NotificationEvent.score_drop:
        _validate_threshold_range(total_threshold, cde_threshold)
        _validate_score_drop_post_state(notif, **threshold_kwargs)

    pending = _build_pending(
        notif,
        enabled=enabled,
        recipients=clean_recipients,
        trigger=parsed_trigger,
        **threshold_kwargs,
        **table_kwargs,
    )

    doc = MdDoc()
    doc.heading(1, f"{event_label} Notification updated")
    doc.field("Notification ID", notif.id, code=True)

    if not pending:
        doc.text("No fields changed — supplied values matched the current state.")
        return doc.render()

    # Snapshot display values before and after applying the changes, then persist.
    changed_attrs = list(pending)
    old_values = [_snapshot_attr(notif, attr_name) for attr_name in changed_attrs]
    for attr_name in changed_attrs:
        setattr(notif, attr_name, pending[attr_name])
    new_values = [_snapshot_attr(notif, attr_name) for attr_name in changed_attrs]

    notif.save()

    diff_rows = [
        [_DIFF_LABELS[attr_name], old, new]
        for attr_name, old, new in zip(changed_attrs, old_values, new_values)
    ]
    doc.table(["Field", "Before", "After"], diff_rows)
    return doc.render()
+ """ + threshold_strays = [ + name + for name, supplied in ( + ("total_threshold", total_threshold is not None), + ("cde_threshold", cde_threshold is not None), + ("clear_total_threshold", clear_total_threshold), + ("clear_cde_threshold", clear_cde_threshold), + ) + if supplied + ] + table_strays = [ + name + for name, supplied in ( + ("table_name", table_name is not None), + ("clear_table_name", clear_table_name), + ) + if supplied + ] + + messages: list[str] = [] + if event is NotificationEvent.score_drop: + if trigger_on is not None: + messages.append( + f"`trigger_on` is not supported for event type `{event_label}` — thresholds drive the event." + ) + if table_strays: + joined = ", ".join(f"`{name}`" for name in table_strays) + messages.append( + f"{joined} not supported for event type `{event_label}`. " + "Only `Monitor Alert` notifications can be scoped to a table." + ) + else: + if threshold_strays: + joined = ", ".join(f"`{name}`" for name in threshold_strays) + messages.append( + f"{joined} not supported for event type `{event_label}`. " + "Only `Score Drop` notifications use score thresholds." + ) + if event is not NotificationEvent.monitor_run and table_strays: + joined = ", ".join(f"`{name}`" for name in table_strays) + messages.append( + f"{joined} not supported for event type `{event_label}`. " + "Only `Monitor Alert` notifications can be scoped to a table." 
+ ) + + if messages: + raise MCPUserError(" ".join(messages)) + + +def _reject_set_and_clear_conflicts( + *, + total_threshold: float | None, + clear_total_threshold: bool, + cde_threshold: float | None, + clear_cde_threshold: bool, + table_name: str | None, + clear_table_name: bool, +) -> None: + """Reject any (set, clear) pair where the caller supplied both for the same field.""" + conflicts = [ + name + for name, set_supplied, clear_supplied in ( + ("total_threshold", total_threshold is not None, clear_total_threshold), + ("cde_threshold", cde_threshold is not None, clear_cde_threshold), + ("table_name", table_name is not None, clear_table_name), + ) + if set_supplied and clear_supplied + ] + if conflicts: + joined = ", ".join(f"`{name}`" for name in conflicts) + raise MCPUserError(f"{joined} cannot be both set and cleared in the same call.") + + +def _validate_score_drop_post_state( + notif: NotificationSettings, + *, + total_threshold: float | None, + cde_threshold: float | None, + clear_total_threshold: bool, + clear_cde_threshold: bool, +) -> None: + """Pre-empt model.save()'s "at least one threshold" invariant. + + Compute the effective threshold values that would result from applying the + pending change and reject up-front if both would be NULL. + """ + if clear_total_threshold: + effective_total = None + elif total_threshold is not None: + effective_total = total_threshold + else: + effective_total = notif.total_score_threshold + + if clear_cde_threshold: + effective_cde = None + elif cde_threshold is not None: + effective_cde = cde_threshold + else: + effective_cde = notif.cde_score_threshold + + if effective_total is None and effective_cde is None: + raise MCPUserError( + "At least one of `total_threshold` or `cde_threshold` must remain set " + "for a `Score Drop` notification." 
+ ) + + +def _build_pending( + notif: NotificationSettings, + *, + enabled: bool | None, + recipients: list[str] | None, + trigger: object, + total_threshold: float | None, + cde_threshold: float | None, + clear_total_threshold: bool, + clear_cde_threshold: bool, + table_name: str | None, + clear_table_name: bool, +) -> dict[str, object]: + """Return only the changes that actually differ from the current state.""" + pending: dict[str, object] = {} + + if enabled is not None and notif.enabled != enabled: + pending["enabled"] = enabled + + if recipients is not None and list(notif.recipients or []) != recipients: + pending["recipients"] = recipients + + if trigger is not None and notif.trigger != trigger: + pending["trigger"] = trigger + + if clear_total_threshold and notif.total_score_threshold is not None: + pending["total_score_threshold"] = None + elif total_threshold is not None and notif.total_score_threshold != total_threshold: + pending["total_score_threshold"] = total_threshold + + if clear_cde_threshold and notif.cde_score_threshold is not None: + pending["cde_score_threshold"] = None + elif cde_threshold is not None and notif.cde_score_threshold != cde_threshold: + pending["cde_score_threshold"] = cde_threshold + + if clear_table_name and notif.table_name is not None: + pending["table_name"] = None + elif table_name is not None and notif.table_name != table_name: + pending["table_name"] = table_name + + return pending + + +def _snapshot_attr(notif: NotificationSettings, attr: str) -> object: + """Render a single attribute's current value in display form for the diff table.""" + if attr == "enabled": + return "Active" if notif.enabled else "Paused" + if attr == "recipients": + return ", ".join(notif.recipients or []) or None + if attr == "trigger": + return _label_for_trigger(notif.event, notif.trigger) + if attr == "total_score_threshold": + return _format_threshold(notif.total_score_threshold) + if attr == "cde_score_threshold": + return 
_format_threshold(notif.cde_score_threshold) + if attr == "table_name": + return notif.table_name or None + return None + + +def _label_for_trigger(event: NotificationEvent, trigger: object) -> str | None: + """Render the user-facing label for an in-memory trigger enum value.""" + if trigger is None: + return None + if event is NotificationEvent.test_run and isinstance(trigger, TestRunNotificationTrigger): + return format_notification_trigger(event, {"trigger": trigger.value}) + if event is NotificationEvent.profiling_run and isinstance(trigger, ProfilingRunNotificationTrigger): + return format_notification_trigger(event, {"trigger": trigger.value}) + if event is NotificationEvent.monitor_run and isinstance(trigger, MonitorNotificationTrigger): + return format_notification_trigger(event, {"trigger": trigger.value}) + return None + + +def _format_threshold(value: object) -> str | None: + """Render a stored Decimal threshold (or an in-memory float/int) as a display string.""" + if value is None: + return None + if isinstance(value, Decimal): + return str(value) + return str(value) + + +@with_database_session +@mcp_permission("edit") +def delete_notification(notification_id: str) -> str: + """Delete an email notification. + + Works on any notification, including ``Monitor Alert`` notifications — those are + created through monitor setup rather than this tool, but can be deleted here. + + Args: + notification_id: UUID of the notification, e.g. from ``list_notifications``. 
def _render_notification_body(doc: MdDoc, notif: NotificationSettings) -> None:
    """Append the Configuration, Scope and Recipients sections for ``notif`` to ``doc``.

    The Trigger field is written only when ``format_notification_trigger`` returns a
    truthy label. The score threshold fields are written only for score-drop
    notifications, and only for thresholds present in ``notif.settings``.
    """
    event_label = format_notification_event(notif.event)
    status_word = "Active" if notif.enabled else "Paused"

    doc.heading(2, "Configuration")
    doc.field("Notification ID", notif.id, code=True)
    doc.field("Event Type", event_label)
    doc.field("Status", status_word)

    trigger_label = format_notification_trigger(notif.event, notif.settings)
    if trigger_label:
        doc.field("Trigger", trigger_label)

    if notif.event == NotificationEvent.score_drop:
        stored = notif.settings or {}
        threshold_fields = (
            ("Total Score Threshold", "total_threshold"),
            ("CDE Score Threshold", "cde_threshold"),
        )
        for field_label, settings_key in threshold_fields:
            threshold = stored.get(settings_key)
            if threshold is not None:
                doc.field(field_label, threshold)

    doc.heading(2, "Scope")
    doc.field("Project", notif.project_code, code=True)
    _render_scope_fields(doc, notif)

    doc.heading(2, "Recipients")
    if not notif.recipients:
        doc.text("_No recipients configured._")
    else:
        doc.bullets(list(notif.recipients))
def _render_scope_fields(doc: MdDoc, notif: NotificationSettings) -> None:
    """Write one field per scope entity configured for the notification's event.

    The entities come from ``_SCOPE_FIELDS``; an event with no entry writes nothing.
    A monitor notification with a truthy ``table_name`` in its settings also gets a
    ``Table`` field.
    """
    for scope in _SCOPE_FIELDS.get(notif.event, ()):
        scoped_id = getattr(notif, scope.id_attr)
        resolved_name = _resolve_scope_name(scope.kind, scoped_id)
        doc.field(scope.label, _scope_value(resolved_name, scoped_id, scope.all_label))

    if notif.event != NotificationEvent.monitor_run:
        return
    monitored_table = (notif.settings or {}).get("table_name")
    if monitored_table:
        doc.field("Table", monitored_table)
_scope_value(name: str | None, entity_id: UUID | None, project_wide_label: str) -> str: + if entity_id is None: + return project_wide_label + display = name or str(entity_id) + return f"{display} ({MdDoc.code(str(entity_id))})" + + +def _suite_name(suite_id: UUID | None) -> str | None: + if suite_id is None: + return None + suite = TestSuite.get(suite_id) + return suite.test_suite if suite else None + + +def _table_group_name(tg_id: UUID | None) -> str | None: + if tg_id is None: + return None + tg = TableGroup.get(tg_id) + return tg.table_groups_name if tg else None + + +def _scorecard_name(score_id: UUID | None) -> str | None: + if score_id is None: + return None + sd = ScoreDefinition.get(str(score_id)) + return sd.name if sd else None + + +def _render( + rows: list[NotificationSummary], + total: int, + *, + page: int, + limit: int, + scope_label: str | None, +) -> str: + doc = MdDoc() + heading = f"Email Notifications — {scope_label}" if scope_label else "Email Notifications" + doc.heading(1, heading) + + if not rows: + doc.text("_No notifications match the supplied scope._") + return doc.render() + + if info := format_page_info(total, page, limit): + doc.text(info) + + suite_names = _batch_suite_names({r.test_suite_id for r in rows if r.test_suite_id}) + tg_names = _batch_table_group_names({r.table_group_id for r in rows if r.table_group_id}) + score_names = _batch_score_names({r.score_definition_id for r in rows if r.score_definition_id}) + + for r in rows: + status_word = "Active" if r.enabled else "Paused" + event_label = format_notification_event(r.event) + scope_text = _scope_text(r, suite_names, tg_names, score_names) + doc.heading(2, f"[{status_word}] {event_label} Notification — {scope_text}") + doc.field("Notification ID", r.id, code=True) + doc.field("Event Type", event_label) + doc.field("Status", status_word) + doc.field("Project", r.project_code, code=True) + doc.field("Scope", scope_text) + if trigger_label := format_notification_trigger(r.event, 
def _scope_text(
    row: NotificationSummary,
    suite_names: dict[UUID, str],
    tg_names: dict[UUID, str],
    score_names: dict[UUID, str],
) -> str:
    """Build the one-line scope description shown for a notification in the list view.

    Each scope entity configured for the row's event becomes ``"<label>: <name>"``. The
    id itself is used when the name lookup has no entry, and the entity's "all" label is
    used when the row has no id for it. Monitor rows with a truthy ``table_name`` setting
    get a trailing ``Table: …`` part. Parts are joined with `` · ``. An event with no
    configured scope entities yields ``"—"``.
    """
    scope_fields = _SCOPE_FIELDS.get(row.event, ())
    if not scope_fields:
        return "—"

    names_by_kind = {
        _ScopeEntityKind.SUITE: suite_names,
        _ScopeEntityKind.TABLE_GROUP: tg_names,
        _ScopeEntityKind.SCORECARD: score_names,
    }

    def describe(scope: _ScopeField) -> str:
        scoped_id = getattr(row, scope.id_attr)
        if scoped_id is None:
            # No id reads as a bare label, e.g. "All Table Groups".
            return scope.all_label
        return f"{scope.label}: {names_by_kind[scope.kind].get(scoped_id, str(scoped_id))}"

    segments = [describe(scope) for scope in scope_fields]
    if row.event == NotificationEvent.monitor_run:
        monitored_table = (row.settings or {}).get("table_name")
        if monitored_table:
            segments.append(f"Table: {monitored_table}")
    return " · ".join(segments)
a/testgen/mcp/tools/profile_history.py b/testgen/mcp/tools/profile_history.py new file mode 100644 index 00000000..17839798 --- /dev/null +++ b/testgen/mcp/tools/profile_history.py @@ -0,0 +1,875 @@ +"""MCP tools that operate across multiple profiling runs of a table group. + +- ``compare_profiling_runs`` — diff two runs (metric changes for shared columns + hygiene churn). +- ``get_profiling_trends`` — caller-named metric time-series across recent runs. +- ``get_schema_history`` — per-run structural changes (tables/columns added/dropped/re-typed) + with table record-count deltas. + +Structural enumeration intentionally lives only in ``get_schema_history``; the comparison tool +renders a one-line pointer to it rather than duplicating the per-table churn. +""" +from collections import defaultdict +from collections.abc import Iterable +from datetime import datetime +from typing import NamedTuple +from uuid import UUID + +from sqlalchemy import func + +from testgen.common.enums import Disposition, JobStatus +from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.data_column import ProfileMetric +from testgen.common.models.hygiene_issue import HygieneIssue, HygieneIssueType +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profile_result import ProfileResult +from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary +from testgen.mcp.exceptions import MCPUserError +from testgen.mcp.permissions import mcp_permission +from testgen.mcp.tools.common import ( + DocGroup, + parse_profile_metrics, + resolve_profiling_run, + resolve_table_group, + validate_limit, +) +from testgen.mcp.tools.markdown import MdDoc +from testgen.utils import friendly_score + +_DOC_GROUP = DocGroup.BROWSE_PROFILING + + +# --------------------------------------------------------------------------- +# General-type vocabulary +# 
--------------------------------------------------------------------------- + +# Single-letter general_type codes (stored on ProfileResult.general_type and +# DataColumnChars.general_type). Mirrors GENERAL_TYPE_TO_CODE values but locally +# named for readability inside this module's scope/type-restriction tables. +_TYPE_ALPHA = "A" +_TYPE_NUMERIC = "N" +_TYPE_DATE = "D" +_TYPE_BOOLEAN = "B" + +_TYPE_LABELS: dict[str, str] = { + _TYPE_ALPHA: "Alpha", + _TYPE_NUMERIC: "Numeric", + _TYPE_DATE: "Date", + _TYPE_BOOLEAN: "Boolean", + "T": "Time", + "X": "Other", +} + + +# --------------------------------------------------------------------------- +# Metric scope + extraction +# --------------------------------------------------------------------------- + +_SCOPE_TABLE_GROUP = "table_group" +_SCOPE_TABLE = "table" +_SCOPE_COLUMN = "column" + +_METRIC_SCOPE: dict[ProfileMetric, str] = { + ProfileMetric.NULL_RATIO: _SCOPE_COLUMN, + ProfileMetric.DISTINCT_RATIO: _SCOPE_COLUMN, + ProfileMetric.FILLED_RATIO: _SCOPE_COLUMN, + ProfileMetric.MIN_LENGTH: _SCOPE_COLUMN, + ProfileMetric.MAX_LENGTH: _SCOPE_COLUMN, + ProfileMetric.AVG_LENGTH: _SCOPE_COLUMN, + ProfileMetric.MIN: _SCOPE_COLUMN, + ProfileMetric.MAX: _SCOPE_COLUMN, + ProfileMetric.AVG: _SCOPE_COLUMN, + ProfileMetric.STDEV: _SCOPE_COLUMN, + ProfileMetric.MIN_DATE: _SCOPE_COLUMN, + ProfileMetric.MAX_DATE: _SCOPE_COLUMN, + ProfileMetric.TRUE_COUNT: _SCOPE_COLUMN, + ProfileMetric.RECORD_COUNT: _SCOPE_TABLE, + ProfileMetric.PROFILING_SCORE: _SCOPE_TABLE_GROUP, + ProfileMetric.HYGIENE_COUNT: _SCOPE_TABLE_GROUP, +} + +# Type-specific metrics only return a value when the column's general_type matches. 
def _validate_metric_scope(metrics: list[ProfileMetric], table_name: str | None, column_name: str | None) -> None:
    """Reject when any metric needs a deeper scope than the provided arguments offer.

    Args:
        metrics: Parsed metrics requested by the caller.
        table_name: Table the request is scoped to, if any.
        column_name: Column the request is scoped to, if any.

    Raises:
        MCPUserError: When a column-scope metric is requested without both
            ``table_name`` and ``column_name``, or a table-scope metric is
            requested without ``table_name``.
    """
    needs_column = [m for m in metrics if _METRIC_SCOPE[m] == _SCOPE_COLUMN]
    # Column metrics need the full (table, column) address. Checking both here
    # keeps the validator consistent with its own error message instead of
    # relying on the caller to have rejected a column without a table first.
    if needs_column and (table_name is None or column_name is None):
        names = ", ".join(f"`{m.value}`" for m in needs_column)
        raise MCPUserError(f"Metrics {names} require both `table_name` and `column_name`.")

    needs_table = [m for m in metrics if _METRIC_SCOPE[m] == _SCOPE_TABLE]
    if needs_table and table_name is None:
        names = ", ".join(f"`{m.value}`" for m in needs_table)
        raise MCPUserError(f"Metrics {names} require `table_name`.")
+ """ + if pr is None: + return None + required_type = _METRIC_TYPE.get(metric) + if required_type is not None and pr.general_type != required_type: + return None + record_ct = pr.record_ct + if metric is ProfileMetric.NULL_RATIO: + return pr.null_value_ct / record_ct if record_ct and pr.null_value_ct is not None else None + if metric is ProfileMetric.DISTINCT_RATIO: + return pr.distinct_value_ct / record_ct if record_ct and pr.distinct_value_ct is not None else None + if metric is ProfileMetric.FILLED_RATIO: + return pr.filled_value_ct / record_ct if record_ct and pr.filled_value_ct is not None else None + if metric is ProfileMetric.RECORD_COUNT: + return pr.record_ct + if metric is ProfileMetric.MIN_LENGTH: + return pr.min_length + if metric is ProfileMetric.MAX_LENGTH: + return pr.max_length + if metric is ProfileMetric.AVG_LENGTH: + return pr.avg_length + if metric is ProfileMetric.MIN: + return pr.min_value + if metric is ProfileMetric.MAX: + return pr.max_value + if metric is ProfileMetric.AVG: + return pr.avg_value + if metric is ProfileMetric.STDEV: + return pr.stdev_value + if metric is ProfileMetric.MIN_DATE: + return pr.min_date + if metric is ProfileMetric.MAX_DATE: + return pr.max_date + if metric is ProfileMetric.TRUE_COUNT: + return pr.boolean_true_ct + return None + + +def _format_metric_value(metric: ProfileMetric, value: object | None) -> str: + if value is None: + return "—" + if metric is ProfileMetric.PROFILING_SCORE and isinstance(value, int | float): + return friendly_score(value) or "—" + if metric in _PERCENT_METRICS and isinstance(value, int | float): + return f"{float(value) * 100:.1f}%" + if isinstance(value, datetime): + return value.date().isoformat() + if isinstance(value, float): + # 6 significant digits with thousands separators preserves precision for + # ratios in the 0.x range (e.g. 5.94821) while keeping wide values readable + # (e.g. 12,345.6). 
+ return f"{value:,.6g}" + if isinstance(value, int): + return f"{value:,}" + return str(value) + + +def _delta_cell(metric: ProfileMetric, baseline: object | None, target: object | None) -> str: + """Render a baseline → target cell. ``B (=)`` when unchanged after formatting. + + Equality is checked on the formatted strings, not the raw values — two timestamps + that render as the same date display as ``(=)`` rather than a no-op ``→``. + """ + baseline_str = _format_metric_value(metric, baseline) + target_str = _format_metric_value(metric, target) + if baseline_str == target_str: + return f"{target_str} (=)" + return f"{baseline_str} → {target_str}" + + +# --------------------------------------------------------------------------- +# Run-state guard +# --------------------------------------------------------------------------- + + +def _require_completed(run: ProfilingRun, label: str) -> None: + """Raise if the run's job execution isn't completed.""" + je = get_current_session().get(JobExecution, run.job_execution_id) + if je.status != JobStatus.COMPLETED: + status_label = ProfilingRunSummary.STATUS_LABEL.get(je.status, je.status) + raise MCPUserError( + f"{label} run is in `{status_label}` state — comparison requires a completed run." + ) + + +# --------------------------------------------------------------------------- +# Compare profiling runs +# --------------------------------------------------------------------------- + + +# Per-general-type metric tables. Excludes the type-display column header so the +# table is uniformly wide; cross-flavor type-display drift is surfaced via footnote. 
@with_database_session
@mcp_permission("catalog")
def compare_profiling_runs(
    target_job_execution_id: str,
    baseline_job_execution_id: str | None = None,
    table_name: str | None = None,
    column_name: str | None = None,
) -> str:
    """Compare two profiling runs on the same table group and report metric changes for shared columns plus hygiene issue churn.

    When ``baseline_job_execution_id`` is omitted, the baseline defaults to the most recent
    completed profiling run on the same table group submitted before the target run. Both
    runs must be in `Completed` state, and target and baseline must be two different runs.

    Reports only on columns present in both runs. When structural drift exists, the output
    notes that fact in one line; the per-table/column structural diff is not enumerated here.

    Args:
        target_job_execution_id: UUID of the newer profiling run (the "after" snapshot),
            e.g. from `list_profiling_runs`.
        baseline_job_execution_id: Optional UUID of the older profiling run (the "before"
            snapshot). When omitted, defaults to the previous completed run on the same
            table group.
        table_name: Optional — restrict the comparison to one table (case-sensitive).
        column_name: Optional — restrict the comparison to one column (case-sensitive); requires
            `table_name`.
    """
    if column_name is not None and table_name is None:
        raise MCPUserError("`column_name` requires `table_name`.")

    target_run = resolve_profiling_run(target_job_execution_id)
    _require_completed(target_run, "Target")

    if baseline_job_execution_id is None:
        baseline_run = target_run.get_previous()
        if baseline_run is None:
            raise MCPUserError(
                f"Target run `{target_job_execution_id}` has no earlier completed "
                "profiling run on its table group to compare against."
            )
    else:
        baseline_run = resolve_profiling_run(baseline_job_execution_id)
        # Comparing a run with itself would tag every profile row as "target"
        # only, so the report would claim structural changes and list every
        # hygiene issue as newly added. Reject it with a clear message instead.
        if baseline_run.id == target_run.id:
            raise MCPUserError(
                "Target and baseline refer to the same profiling run — "
                "pass two different runs to compare."
            )
        if baseline_run.table_groups_id != target_run.table_groups_id:
            raise MCPUserError(
                "Both runs must belong to the same table group to be comparable. "
                f"Target is in table group `{target_run.table_groups_id}`, "
                f"baseline is in table group `{baseline_run.table_groups_id}`."
            )
    # Checked for both the explicit and the defaulted baseline.
    _require_completed(baseline_run, "Baseline")

    rows = ProfileResult.select_for_runs(
        run_ids=[target_run.id, baseline_run.id],
        table_name=table_name,
        column_name=column_name,
    )
    paired = _pair_results(rows, target_run.id, baseline_run.id)

    # A column seen on only one side means the schema differs between the runs.
    has_structural_changes = any(
        "target" not in sides or "baseline" not in sides for sides in paired.values()
    )
    shared = {key: sides for key, sides in paired.items() if "target" in sides and "baseline" in sides}

    hygiene_diff = _diff_hygiene_issues(
        target_run.id, baseline_run.id, table_name=table_name, column_name=column_name,
    )

    return _render_run_comparison(
        target_run=target_run,
        baseline_run=baseline_run,
        shared=shared,
        has_structural_changes=has_structural_changes,
        hygiene_diff=hygiene_diff,
    )
def _build_metric_rows_for_type(
    general_type: str,
    shared: dict[tuple[str, str, str], dict[str, ProfileResult]],
) -> tuple[list[str], list[list[str]]]:
    """Build (headers, rows) for the metric-change table for one general_type bucket.

    Args:
        general_type: Single-letter general type code selecting the metric list
            from ``_METRIC_TABLE_BY_TYPE``.
        shared: Columns present in both runs, keyed by (schema, table, column),
            each with a ``"baseline"`` and a ``"target"`` profile row.

    Returns:
        The header labels and one row per column whose displayed value changed
        for at least one metric in this bucket, sorted by column key.
    """
    metrics = _METRIC_TABLE_BY_TYPE[general_type]
    headers = ["Table", "Column", *(m.value for m in metrics)]
    rows: list[list[str]] = []
    for (_, table, column), sides in sorted(shared.items()):
        baseline = sides["baseline"]
        target = sides["target"]
        # Bucket by target's type. Columns that switched type between runs render here
        # under the new type; the old/new type is also surfaced as a categorical change.
        if target.general_type != general_type:
            continue

        deltas: list[str] = []
        any_changed = False
        for metric in metrics:
            baseline_value = _column_metric_value(metric, baseline)
            target_value = _column_metric_value(metric, target)
            # Decide "changed" on the formatted text, the same basis the cell
            # uses to print "(=)". Comparing raw values let differences below
            # display precision (e.g. two timestamps on the same date) produce
            # a row in which every cell reads "(=)".
            if _format_metric_value(metric, baseline_value) != _format_metric_value(metric, target_value):
                any_changed = True
            deltas.append(_delta_cell(metric, baseline_value, target_value))
        if any_changed:
            rows.append([table, column, *deltas])
    return headers, rows
@with_database_session
@mcp_permission("catalog")
def get_profiling_trends(
    table_group_id: str,
    metrics: list[str],
    table_name: str | None = None,
    column_name: str | None = None,
    limit: int = 10,
) -> str:
    """Show a time series of caller-named profiling metrics across recent completed runs of a table group.

    Metric scope rules:
    - Column-level metrics (e.g. `Null Ratio`, `Average Length`, `Minimum Value`) require both
      `table_name` and `column_name`.
    - `Row Count` is table-level and requires `table_name`.
    - `Profiling Score` and `Hygiene Issues` are table-group-level and accept any scope.
    - Type-specific metrics return `—` for runs where the column's general type
      didn't match (e.g. `Minimum Value` on a column that was Alpha in an earlier run).

    Args:
        table_group_id: UUID of the table group, e.g. from `get_data_inventory`.
        metrics: One or more metric names. Accepted values: `Null Ratio`, `Distinct Ratio`,
            `Filled Ratio`, `Row Count`, `Profiling Score`, `Hygiene Issues`,
            `Minimum Length`, `Maximum Length`, `Average Length`, `Minimum Value`,
            `Maximum Value`, `Average Value`, `Standard Deviation`, `Minimum Date`,
            `Maximum Date`, `True Count`.
        table_name: Optional — restrict to one table (case-sensitive).
        column_name: Optional — restrict to one column (case-sensitive); requires
            `table_name`.
        limit: Number of most-recent completed runs to include (default 10, max 50).
    """
    validate_limit(limit, 50)
    if table_name is None and column_name is not None:
        raise MCPUserError("`column_name` requires `table_name`.")

    table_group = resolve_table_group(table_group_id)
    requested = parse_profile_metrics(metrics)
    _validate_metric_scope(requested, table_name, column_name)

    recent_runs = ProfilingRun.list_recent_complete(table_group.id, limit=limit)
    if not recent_runs:
        return f"No completed profiling runs found for table group `{table_group_id}`."

    recent_ids = [run.id for run in recent_runs]
    row_scopes = (_SCOPE_COLUMN, _SCOPE_TABLE)
    wants_rows = any(_METRIC_SCOPE[metric] in row_scopes for metric in requested)

    # One representative profile row per run. With a column scope the last row
    # returned for a run is kept; with a table-only scope the first one is kept
    # and only its record count is read for the table-level metric.
    rows_by_run = {}
    if wants_rows:
        results = ProfileResult.select_for_runs(
            run_ids=recent_ids, table_name=table_name, column_name=column_name,
        )
        for result in results:
            if column_name is not None:
                rows_by_run[result.profile_run_id] = result
            else:
                rows_by_run.setdefault(result.profile_run_id, result)

    if ProfileMetric.HYGIENE_COUNT in requested:
        issue_counts = ProfilingRun.count_confirmed_hygiene_issues(recent_ids)
    else:
        issue_counts = {}

    # Find the first and last positions in the run list that have a profile
    # row, scanning from the end and from the start respectively; the renderer
    # uses them to explain leading/trailing `—` cells.
    earliest_seen = None
    latest_seen = None
    if wants_rows:
        earliest_seen = next((run for run in reversed(recent_runs) if run.id in rows_by_run), None)
        latest_seen = next((run for run in recent_runs if run.id in rows_by_run), None)

    return _render_trends(
        tg_name=table_group.table_groups_name,
        runs=recent_runs,
        metrics=requested,
        profile_by_run=rows_by_run,
        hygiene_counts=issue_counts,
        table_name=table_name,
        column_name=column_name,
        first_seen_run=earliest_seen,
        last_seen_run=latest_seen,
        needs_profile_rows=wants_rows,
    )
def _build_run_snapshots(rows: Iterable[ProfileResult]) -> dict[UUID, dict[tuple[str, str], _TableSnapshot]]:
    """Fold profile rows into ``{run_id: {(schema, table): _TableSnapshot}}``.

    Each table snapshot holds one ``_ColumnSnapshot`` per column name plus the
    record count taken from the first row seen for that (run, table).
    """
    columns_by_run: dict[UUID, dict[tuple[str, str], dict[str, _ColumnSnapshot]]] = {}
    counts_by_run: dict[UUID, dict[tuple[str, str], int | None]] = {}

    for profile in rows:
        run_key = profile.profile_run_id
        table_ref = (profile.schema_name, profile.table_name)
        snapshot = _ColumnSnapshot(
            column_type=profile.column_type,
            general_type=profile.general_type,
            db_data_type=profile.db_data_type,
        )
        columns_by_run.setdefault(run_key, {}).setdefault(table_ref, {})[profile.column_name] = snapshot

        # Keep the record count of the first row seen for this table; later
        # rows never overwrite it.
        run_counts = counts_by_run.setdefault(run_key, {})
        if table_ref not in run_counts:
            run_counts[table_ref] = profile.record_ct

    return {
        run_key: {
            table_ref: _TableSnapshot(columns=cols, record_ct=counts_by_run[run_key].get(table_ref))
            for table_ref, cols in tables.items()
        }
        for run_key, tables in columns_by_run.items()
    }
def _render_schema_history(
    tg_name: str,
    runs: list[ProfilingRun],
    snapshots: dict[UUID, dict[tuple[str, str], _TableSnapshot]],
) -> str:
    """Render the schema-history document: one section per run, each diffed against the next entry in ``runs``.

    The final entry of ``runs`` only serves as a baseline and gets no section
    of its own, so the "Runs analyzed" count is one less than ``len(runs)``.
    A run with no snapshot is diffed as an empty schema.
    """
    doc = MdDoc()
    doc.heading(1, f"Schema history for `{tg_name}`")
    doc.field("Runs analyzed", len(runs) - 1)
    window_start = _format_run_label(runs[-1])
    window_end = _format_run_label(runs[0])
    doc.field("Window", f"{window_start} → {window_end}")

    # Pair each run with the entry that follows it in the list.
    for current, previous in zip(runs, runs[1:]):
        changes = _format_schema_delta(
            target_snap=snapshots.get(current.id, {}),
            baseline_snap=snapshots.get(previous.id, {}),
        )
        doc.heading(2, f"Run started {_format_run_label(current)}")
        doc.field("Profiling Run", current.job_execution_id, code=True)
        if not changes:
            doc.text("_No structural change since previous run._")
        else:
            doc.bullets(changes)

    return doc.render()
sorted(baseline_tables - target_tables) + for key in dropped_tables: + col_ct = len(baseline_snap[key].columns) + lines.append(f"Table dropped: {_format_table_key(key)} ({col_ct} columns)") + + for key in sorted(target_tables & baseline_tables): + target_table = target_snap[key] + baseline_table = baseline_snap[key] + column_changes = _format_column_delta(target_table.columns, baseline_table.columns) + record_delta = _format_record_delta(target_table.record_ct, baseline_table.record_ct) + for change in column_changes: + lines.append(f"{_format_table_key(key)}: {change}") + if record_delta is not None: + lines.append(f"{_format_table_key(key)}: Record count {record_delta}") + return lines + + +def _format_column_delta( + target_cols: dict[str, _ColumnSnapshot], + baseline_cols: dict[str, _ColumnSnapshot], +) -> list[str]: + out: list[str] = [] + target_names = set(target_cols) + baseline_names = set(baseline_cols) + for name in sorted(target_names - baseline_names): + snap = target_cols[name] + type_label = snap.column_type or snap.db_data_type + out.append(f"column `{name}` added ({type_label})" if type_label else f"column `{name}` added") + for name in sorted(baseline_names - target_names): + snap = baseline_cols[name] + type_label = snap.column_type or snap.db_data_type + out.append(f"column `{name}` dropped (was {type_label})" if type_label else f"column `{name}` dropped") + for name in sorted(target_names & baseline_names): + target_col = target_cols[name] + baseline_col = baseline_cols[name] + if target_col.column_type != baseline_col.column_type and target_col.column_type and baseline_col.column_type: + out.append( + f"column `{name}` retyped: {baseline_col.column_type} → {target_col.column_type}" + ) + return out + + +def _format_record_delta(target_ct: int | None, baseline_ct: int | None) -> str | None: + if target_ct is None or baseline_ct is None: + return None + if target_ct == baseline_ct: + return None + return f"{baseline_ct:,} → {target_ct:,}" diff 
--git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index 9d293425..79ed638c 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -1,18 +1,44 @@ +import dataclasses from uuid import UUID +from sqlalchemy import func, or_ + from testgen.common.models import with_database_session -from testgen.common.models.data_column import ColumnProfileSummary, DataColumnChars +from testgen.common.models.data_column import ( + SUGGESTED_DATA_TYPE_TO_PREFIX, + ColumnOrderBy, + ColumnProfileDetail, + ColumnProfileSummary, + DataColumnChars, +) from testgen.common.models.data_table import DataTable -from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profile_result import ProfileResult +from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary +from testgen.common.models.scheduler import RUN_PROFILE_JOB_KEY from testgen.common.models.table_group import TableGroup, TableGroupSummary +from testgen.common.pii_masking import PII_REDACTED, mask_profiling_pii +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission from testgen.mcp.tools.common import ( DocGroup, + build_ilike_pattern, format_page_footer, format_page_info, + format_run_duration, + next_scheduled_run, + parse_column_order_by, + parse_general_type, + parse_pii_category, + parse_pii_risk_level, + parse_run_status_filter, + parse_suggested_data_type, parse_uuid, + resolve_profiling_run, resolve_table_group, + validate_limit, + validate_page, ) from testgen.mcp.tools.markdown import MdDoc from testgen.utils import friendly_score @@ -71,28 +97,80 @@ def list_column_profiles( table_name: str | None = None, columns: list[str] | None = None, job_execution_id: str | None = None, + 
null_ratio_above: float | None = None, + null_ratio_below: float | None = None, + distinct_ratio_above: float | None = None, + distinct_ratio_below: float | None = None, + filled_ratio_above: float | None = None, + filled_ratio_below: float | None = None, + score_profiling_above: float | None = None, + score_profiling_below: float | None = None, + score_testing_above: float | None = None, + score_testing_below: float | None = None, + pii: bool | None = None, + cde: bool | None = None, + suggested_data_type: str | None = None, + general_type: str | None = None, + semantic_data_type: str | None = None, + pii_category: str | None = None, + pii_risk_level: str | None = None, + order_by: str | None = None, limit: int = 100, page: int = 1, ) -> str: - """List per-column profile headers (~14 fields each) — the Layer 1 scan of profiling results across columns in a table group. + """List per-column profile headers across a table group, with optional profile-predicate filters. Args: table_group_id: UUID of the table group, e.g. from `get_data_inventory`. table_name: Optional — scope to one table (case-sensitive). columns: Optional — specific column names to include (case-sensitive). job_execution_id: UUID of a profiling run, e.g. from `get_table` or - `list_profiling_summaries`. When omitted, each column uses its own - latest run. - limit: Page size (default 100). + `list_profiling_summaries`. When omitted, each column uses its own latest run. + null_ratio_above: Match columns whose null fraction exceeds this value + (e.g. `0.2` for above 20% null). + null_ratio_below: Match columns whose null fraction is below this value. + distinct_ratio_above: Match columns whose distinct-value fraction exceeds this + value (e.g. `0.95` for near-unique columns). + distinct_ratio_below: Match columns whose distinct-value fraction is below this + value (e.g. `0.001` for low cardinality). + filled_ratio_above: Match columns whose dummy/placeholder-value fraction exceeds + this value. 
+ filled_ratio_below: Match columns whose dummy/placeholder-value fraction is below + this value. + score_profiling_above: Match columns whose Profiling Score is above this value (0-100 scale). + score_profiling_below: Match columns whose Profiling Score is below this value (0-100 scale). + score_testing_above: Match columns whose Testing Score is above this value (0-100 scale). + score_testing_below: Match columns whose Testing Score is below this value (0-100 scale). + pii: When `true`, match columns flagged as PII; when `false`, exclude PII columns. + cde: When `true`, match columns flagged as a Critical Data Element (directly + or inherited from the table); when `false`, exclude CDE columns. + suggested_data_type: Match columns where profiling suggests a more suitable data + type. Pass `Any` for any mismatch, or a concrete type (`Smallint`, `Integer`, + `Bigint`, `Decimal`, `Numeric`, `Varchar`, `Date`, `Timestamp`, `Boolean`) to + filter mismatches whose suggestion starts with that type. Columns where the + suggestion matches the column's stored type are always excluded. + general_type: Broad type classification — + `Alpha`, `Numeric`, `Datetime`, `Boolean`, `Time`, or `Other`. + semantic_data_type: Substring match (case-insensitive) on Semantic Data Type. + Bare tokens auto-wrap with `%`; an explicit `%` is honored as a wildcard. + See `testgen://column-profile-fields` for the canonical value list. + pii_category: PII category — `ID`, `Name`, `Demographic`, or `Contact`. + pii_risk_level: PII risk level — `High`, `Moderate`, or `Low`. + order_by: Sort key — `Null Ratio`, `Distinct Ratio`, `Filled Ratio`, + `Profiling Score`, `Testing Score`, or `Hygiene Count`. Defaults to + table/column position. + limit: Page size (default 100, max 500). page: Page number starting at 1 (default 1). 
""" + validate_page(page) + validate_limit(limit, 500) + tg = resolve_table_group(table_group_id) profiling_run_id: UUID | None = None if job_execution_id: - run_uuid = parse_uuid(job_execution_id, "job_execution_id") - profiling_run = ProfilingRun.get_by_id_or_job(run_uuid) - if profiling_run is None or profiling_run.table_groups_id != tg.id: + profiling_run = resolve_profiling_run(job_execution_id) + if profiling_run.table_groups_id != tg.id: raise MCPResourceNotAccessible("Profiling run", job_execution_id) profiling_run_id = profiling_run.id @@ -102,10 +180,98 @@ def list_column_profiles( if columns: clauses.append(DataColumnChars.column_name.in_(columns)) + if null_ratio_above is not None: + clauses.append(ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) > null_ratio_above) + if null_ratio_below is not None: + clauses.append(ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) < null_ratio_below) + if distinct_ratio_above is not None: + clauses.append( + ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) > distinct_ratio_above + ) + if distinct_ratio_below is not None: + clauses.append( + ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) < distinct_ratio_below + ) + if filled_ratio_above is not None: + clauses.append( + ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) > filled_ratio_above + ) + if filled_ratio_below is not None: + clauses.append( + ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) < filled_ratio_below + ) + + if score_profiling_above is not None: + clauses.append(DataColumnChars.dq_score_profiling > score_profiling_above / 100) + if score_profiling_below is not None: + clauses.append(DataColumnChars.dq_score_profiling < score_profiling_below / 100) + if score_testing_above is not None: + clauses.append(DataColumnChars.dq_score_testing > score_testing_above / 100) + if 
score_testing_below is not None: + clauses.append(DataColumnChars.dq_score_testing < score_testing_below / 100) + + if pii is True: + clauses.append(DataColumnChars.pii_flag.isnot(None)) + elif pii is False: + clauses.append(DataColumnChars.pii_flag.is_(None)) + + if cde is True: + # A column is a CDE when either it or its parent table is flagged. + clauses.append( + or_( + DataColumnChars.critical_data_element.is_(True), + DataTable.critical_data_element.is_(True), + ) + ) + elif cde is False: + clauses.append( + DataColumnChars.critical_data_element.isnot(True), + ) + clauses.append( + DataTable.critical_data_element.isnot(True), + ) + + if suggested_data_type is not None: + prefix = SUGGESTED_DATA_TYPE_TO_PREFIX[parse_suggested_data_type(suggested_data_type)] + if prefix is None: + clauses.append(ProfileResult.datatype_suggestion.isnot(None)) + else: + clauses.append(ProfileResult.datatype_suggestion.ilike(f"{prefix}%")) + + if general_type is not None: + clauses.append(DataColumnChars.general_type == parse_general_type(general_type)) + if semantic_data_type is not None: + if not semantic_data_type.strip(): + raise MCPUserError("`semantic_data_type` cannot be empty.") + clauses.append( + DataColumnChars.functional_data_type.ilike( + build_ilike_pattern(semantic_data_type), escape="\\" + ) + ) + if pii_category is not None: + category = parse_pii_category(pii_category) + # ``pii_flag`` stores ``//``; match on the middle segment. + clauses.append(DataColumnChars.pii_flag.like(f"%/{category}/%")) + if pii_risk_level is not None: + risk_code = parse_pii_risk_level(pii_risk_level) + # ``MANUAL`` is user-set PII, weighted equivalent to ``A`` (High) by ``dq_score_weight_defaults``. 
+ if risk_code == "A": + clauses.append( + or_( + DataColumnChars.pii_flag.like("A/%"), + DataColumnChars.pii_flag == "MANUAL", + ) + ) + else: + clauses.append(DataColumnChars.pii_flag.like(f"{risk_code}/%")) + + order_value: ColumnOrderBy | None = parse_column_order_by(order_by) if order_by else None + data, total = DataColumnChars.list_for_table_group( *clauses, table_groups_id=tg.id, profiling_run_id=profiling_run_id, + order_by=order_value, page=page, limit=limit, ) @@ -209,9 +375,9 @@ def list_profiling_summaries( def _format_pii(value: str | None) -> str | None: """Render a `pii_flag` value as a human label. Mirrors `PiiDisplay` in metadata_tags.js.""" if not value: - return None + return "No" if value == "MANUAL": - return "PII" + return "Yes" risk, _, rest = value.partition("/") type_code, _, detail = rest.partition("/") risk_label = _PII_RISK_MAP.get(risk, "Moderate") @@ -221,7 +387,7 @@ def _format_pii(value: str | None) -> str | None: caption += f" - {type_label}" if detail and detail != type_label: caption += f" / {detail}" - return f"PII ({caption})" + return f"Yes ({caption})" def _render_column_profile_row(c: ColumnProfileSummary) -> list: @@ -243,6 +409,193 @@ def _render_column_profile_row(c: ColumnProfileSummary) -> list: ] +@with_database_session +@mcp_permission("catalog") +def list_profiling_runs( + table_group_id: str, + status: str | None = None, + limit: int = 10, + page: int = 1, +) -> str: + """List profiling run history for a table group, including queued, in-progress, and failed runs. + Ordered by submission time descending. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + status: Optional run status filter. One of: Pending, Running, Completed, Canceled, Error. + limit: Page size (default 10, max 100). + page: Page number starting at 1 (default 1). 
+ """ + validate_limit(limit, 100) + validate_page(page) + + statuses = parse_run_status_filter(status) if status else None + tg = resolve_table_group(table_group_id) + + summaries, total = ProfilingRun.select_summary( + project_code=tg.project_code, + table_group_id=tg.id, + statuses=statuses, + page=page, + page_size=limit, + ) + + # Queued/claimed JEs that don't yet have a profiling_runs row are invisible to TG-scoped + # joined-run queries. Surface them as a separate "Pending" section on page 1. + pending_jes: list[JobExecution] = [] + if page == 1: + pending_jes = JobExecution.select_active_by_kwargs( + project_code=tg.project_code, + job_key=RUN_PROFILE_JOB_KEY, + kwargs_match={"table_group_id": str(tg.id)}, + statuses=statuses, + ) + + doc = MdDoc() + scope = f" — status `{status}`" if status else "" + doc.heading(1, f"Profiling runs for `{tg.table_groups_name}`{scope}") + + next_run = next_scheduled_run( + RUN_PROFILE_JOB_KEY, {"table_group_id": str(tg.id)}, tg.project_code + ) + if next_run: + doc.field("Next scheduled run", next_run) + + if pending_jes: + doc.heading(2, f"Pending ({len(pending_jes)})") + for je in pending_jes: + _render_pending_profiling_je(doc, je, label=tg.table_groups_name) + + page_info = format_page_info(total, page, limit) + if page_info: + doc.text(page_info) + + if not summaries: + if page > 1: + doc.text(f"_No profiling runs on page {page} (total: {total})._") + elif not pending_jes: + doc.text("_No profiling runs found._") + return doc.render() + + for run in summaries: + _render_profiling_run_section(doc, run) + + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + + return doc.render() + + +@with_database_session +@mcp_permission("catalog") +def get_profiling_run(job_execution_id: str) -> str: + """Get a single profiling run with status, timing, totals, and per-table breakdown. Returns the + run regardless of state — including queued and in-progress runs without complete results yet. 
+ The per-table breakdown is only available after the run completes. + + Args: + job_execution_id: UUID of a profiling run, e.g. from `list_profiling_runs` or + `list_profiling_summaries`. + """ + parse_uuid(job_execution_id, "job_execution_id") + perms = get_project_permissions() + + summaries, _ = ProfilingRun.select_summary(job_execution_id=job_execution_id, page_size=1) + summary = summaries[0] if summaries else None + if summary is None or summary.project_code not in perms.allowed_codes: + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + + doc = MdDoc() + tg_label = summary.table_groups_name or "—" + doc.heading(1, f"Profiling run: {tg_label}") + doc.field("Job ID", summary.job_execution_id, code=True) + if summary.table_groups_name: + doc.field("Table group", summary.table_groups_name) + if summary.table_group_schema: + doc.field("Schema", summary.table_group_schema) + doc.field("Status", summary.status_label) + doc.field("Submitted", summary.created_at) + doc.field("Started", summary.started_at or "—") + doc.field("Ended", summary.completed_at or "In progress") + duration = format_run_duration(summary.started_at, summary.completed_at) + if duration: + doc.field("Duration", duration) + + has_totals = summary.table_ct or summary.column_ct or summary.record_ct or summary.anomaly_ct + if has_totals: + doc.field("Tables profiled", summary.table_ct or 0) + doc.field("Columns profiled", summary.column_ct or 0) + if summary.record_ct is not None: + doc.field("Records", summary.record_ct) + doc.field( + "Hygiene issues (confirmed)", + f"{(summary.anomalies_definite_ct or 0) + (summary.anomalies_likely_ct or 0) + (summary.anomalies_possible_ct or 0)} total " + f"— {summary.anomalies_definite_ct or 0} definite, " + f"{summary.anomalies_likely_ct or 0} likely, " + f"{summary.anomalies_possible_ct or 0} possible", + ) + if summary.dq_score_profiling is not None: + doc.field("Profiling Score", friendly_score(summary.dq_score_profiling)) + + if 
summary.profiling_run_id: + breakdown = ProfilingRun.select_table_breakdown(summary.profiling_run_id) + if breakdown: + doc.heading(2, "Per-table breakdown") + doc.table( + ["Schema", "Table", "Records", "Columns", "Hygiene issues"], + [ + [r.schema_name, r.table_name, r.record_ct, r.column_ct, r.anomaly_ct] + for r in breakdown + ], + code=[0, 1], + ) + + if summary.error_message: + doc.heading(2, "Error") + doc.text(summary.error_message) + + return doc.render() + + +def _render_pending_profiling_je(doc: MdDoc, je: JobExecution, label: str) -> None: + status_label = ProfilingRunSummary.STATUS_LABEL.get(je.status, je.status) + doc.heading(3, f"{label} — {status_label}") + doc.field("Job ID", je.id, code=True) + if je.job_schedule_id is not None: + doc.field("Schedule", je.job_schedule_id, code=True) + doc.field("Submitted", je.created_at) + doc.field("Started", je.started_at or "—") + doc.field("Ended", je.completed_at or "In progress") + + +def _render_profiling_run_section(doc: MdDoc, run: ProfilingRunSummary) -> None: + title = run.table_groups_name or run.profiling_run_id or run.job_execution_id + doc.heading(2, f"{title} — {run.status_label}") + doc.field("Job ID", run.job_execution_id, code=True) + if run.job_schedule_id is not None: + doc.field("Schedule", run.job_schedule_id, code=True) + doc.field("Submitted", run.created_at) + doc.field("Started", run.started_at or "—") + doc.field("Ended", run.completed_at or "In progress") + duration = format_run_duration(run.started_at, run.completed_at) + if duration: + doc.field("Duration", duration) + + if run.table_ct or run.column_ct: + doc.field("Tables profiled", run.table_ct or 0) + doc.field("Columns profiled", run.column_ct or 0) + if run.anomaly_ct is not None and ( + run.anomalies_definite_ct or run.anomalies_likely_ct or run.anomalies_possible_ct + ): + doc.field( + "Hygiene issues (confirmed)", + f"{(run.anomalies_definite_ct or 0) + (run.anomalies_likely_ct or 0) + (run.anomalies_possible_ct or 0)} 
total", + ) + if run.dq_score_profiling is not None: + doc.field("Profiling Score", friendly_score(run.dq_score_profiling)) + + def _render_table_group_summary(doc: MdDoc, s: TableGroupSummary) -> None: doc.heading(2, s.table_groups_name) if s.connection_name: @@ -269,3 +622,517 @@ def _render_table_group_summary(doc: MdDoc, s: TableGroupSummary) -> None: doc.field("Profiling Run", s.latest_profile_job_execution_id, code=True) if s.monitor_lookback_end: doc.field("Last monitored", s.monitor_lookback_end) + + +# --------------------------------------------------------------------------- +# get_column_profile_detail +# --------------------------------------------------------------------------- + +# Friendly labels for `std_pattern_match` — mirrors `standardPatternLabels` in +# `ui/components/frontend/js/data_profiling/column_distribution.js`. +_STD_PATTERN_LABELS = { + "STREET_ADDR": "Street Address", + "STATE_USA": "State (USA)", + "PHONE_USA": "Phone (USA)", + "EMAIL": "Email", + "ZIP_USA": "Zip Code (USA)", + "FILE_NAME": "Filename", + "CREDIT_CARD": "Credit Card", + "DELIMITED_DATA": "Delimited Data", + "SSN": "SSN (USA)", +} + + +def _format_std_pattern(value: str | None) -> str | None: + if not value: + return None + return _STD_PATTERN_LABELS.get(value, value.replace("_", " ").title()) + + +# --------------------------------------------------------------------------- +# Shared helpers for single-column tools (frequent values, patterns) +# --------------------------------------------------------------------------- + + +def _load_profile_for_column( + tg: TableGroup, + table_name: str, + column_name: str, + job_execution_id: str | None, +) -> tuple[ProfileResult, ProfilingRun, str | None]: + """Resolve and load the profile-results row for one column. + + Returns a triple of ``(profile, profiling_run, pii_flag)`` where ``pii_flag`` is + pulled from ``data_column_chars`` (the source of truth for column-level PII state). 
+ """ + profiling_run: ProfilingRun | None = None + if job_execution_id: + profiling_run = resolve_profiling_run(job_execution_id) + if profiling_run.table_groups_id != tg.id: + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + profile = ProfileResult.get_for_column( + table_groups_id=tg.id, + table_name=table_name, + column_name=column_name, + profiling_run_id=profiling_run.id if profiling_run else None, + ) + if profile is None: + raise MCPResourceNotAccessible("Column profile", f"{table_name}.{column_name}") + if profiling_run is None: + profiling_run = ProfilingRun.get(profile.profile_run_id) + if profiling_run is None: + raise MCPResourceNotAccessible("Profiling run", str(profile.profile_run_id)) + column_rows = list(DataColumnChars.select_where( + DataColumnChars.table_groups_id == tg.id, + DataColumnChars.table_name == table_name, + DataColumnChars.column_name == column_name, + )) + pii_flag = column_rows[0].pii_flag if column_rows else None + return profile, profiling_run, pii_flag + + +def _is_pii_redacted_for_caller(tg: TableGroup, pii_flag: str | None) -> bool: + """Decide whether to redact PII values for this caller + column.""" + if not pii_flag: + return False + return not get_project_permissions().has_permission("view_pii", tg.project_code) + + +@with_database_session +@mcp_permission("catalog") +def get_column_profile_detail( + table_group_id: str, + table_name: str, + column_name: str, + job_execution_id: str | None = None, +) -> str: + """Get the type-specific value distribution and statistics for one column from its profiling run. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + table_name: Table name exactly as stored in TestGen (case-sensitive). + column_name: Column name exactly as stored in TestGen (case-sensitive). + job_execution_id: UUID of a profiling run, e.g. from `list_profiling_summaries`. + When omitted, uses the column's latest complete run. 
+ """ + tg = resolve_table_group(table_group_id) + + profiling_run_id: UUID | None = None + if job_execution_id: + profiling_run = resolve_profiling_run(job_execution_id) + if profiling_run.table_groups_id != tg.id: + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + profiling_run_id = profiling_run.id + + detail = DataColumnChars.get_column_detail( + table_groups_id=tg.id, + table_name=table_name, + column_name=column_name, + profiling_run_id=profiling_run_id, + ) + if detail is None: + raise MCPResourceNotAccessible("Column", column_name) + + if detail.profile_run_id is None: + if job_execution_id: + raise MCPUserError( + f"Profiling run `{job_execution_id}` did not include column `{column_name}`." + ) + raise MCPUserError( + f"Column `{column_name}` has not been profiled yet. " + "Run profiling for the table group first." + ) + + if detail.profile_run_status in ("Running", "Error", "Cancelled"): + _raise_run_not_ready(detail) + + payload = dataclasses.asdict(detail) + if detail.pii_flag and not get_project_permissions().has_permission("view_pii", tg.project_code): + mask_profiling_pii(payload, {detail.column_name}) + + return _render_column_profile_detail(payload) + + +def _raise_run_not_ready(detail: ColumnProfileDetail) -> None: + """Reject when the resolved profiling run is in `Running` or `Error` state. + + Surface the run id, status, started/ended timestamps, and `log_message` (Error only) + in the raised error so the LLM knows what to suggest next. + """ + je = detail.profile_run_je_id + status = detail.profile_run_status + started = detail.profile_run_started_at + ended = detail.profile_run_ended_at + started_label = started.strftime("%Y-%m-%d %H:%M UTC") if started else "—" + ended_label = ended.strftime("%Y-%m-%d %H:%M UTC") if ended else "—" + lines = [ + f"Profiling run `{je}` is in `{status}` state — no profile detail available.", + f"Started: {started_label}. 
Ended: {ended_label}.", + ] + if status == "Error" and detail.profile_run_log_message: + lines.append(f"Error: {detail.profile_run_log_message}") + raise MCPUserError("\n".join(lines)) + + +def _render_column_profile_detail(p: dict) -> str: + """Render a column profile detail payload as grouped Markdown sections.""" + doc = MdDoc() + fq_name = f"{p['schema_name']}.{p['table_name']}" if p["schema_name"] else p["table_name"] + doc.heading(1, f"Column Profile: `{p['column_name']}` in `{fq_name}`") + + general_type = p.get("general_type") + + # Run identity + L1 header fields + doc.field("Profiling Run", p["profile_run_je_id"], code=True) + doc.field("Profiled at", p["profile_run_started_at"]) + doc.field("General Type", _format_general_type(general_type)) + doc.field("Data Type", p["db_data_type"]) + doc.field("Semantic Data Type", p["functional_data_type"]) + if p.get("datatype_suggestion"): + doc.field("Suggested Data Type", p["datatype_suggestion"]) + doc.field("PII", _format_pii(p.get("pii_flag"))) + doc.field("Critical Data Element", p.get("critical_data_element") or False) + doc.field("Profiling Score", friendly_score(p.get("dq_score_profiling"))) + doc.field("Testing Score", friendly_score(p.get("dq_score_testing"))) + + if not p.get("query_error"): + doc.field("Hygiene Issues (confirmed)", p.get("hygiene_issue_count", 0)) + + # Type-specific dispatch (T and unknown fall through to common-counts only) + if general_type == "A": + _render_alpha_block(doc, p) + elif general_type == "N": + _render_numeric_block(doc, p) + elif general_type == "D": + _render_date_block(doc, p) + elif general_type == "B": + _render_boolean_block(doc, p) + else: + _render_unknown_block(doc, p) + else: + doc.heading(2, "Profiling Error") + doc.text(p["query_error"]) + + return doc.render() + + +_FIELD_GENERAL_TYPE_LABELS = { + "A": "Alpha", + "B": "Boolean", + "D": "Date", + "N": "Numeric", + "T": "Time", + "X": "Other", +} + + +def _format_general_type(value: str) -> str: + return 
_FIELD_GENERAL_TYPE_LABELS.get(value or "X") + + +def _render_counts(doc: MdDoc, p: dict) -> None: + doc.heading(2, "Counts") + doc.field("Row Count", p.get("record_ct")) + doc.field("Value Count", p.get("value_ct")) + doc.field("Distinct Values", p.get("distinct_value_ct")) + doc.field("Null", p.get("null_value_ct")) + doc.field("Dummy Values", p.get("filled_value_ct")) + doc.field("Zero Values", p.get("zero_value_ct")) + + +def _render_alpha_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + doc.field("Zero Length", p.get("zero_length_ct")) + + doc.heading(2, "Length") + doc.field("Minimum Length", p.get("min_length")) + doc.field("Maximum Length", p.get("max_length")) + doc.field("Average Length", p.get("avg_length")) + + doc.heading(2, "Text Range") + doc.field("Minimum Text", p.get("min_text")) + doc.field("Maximum Text", p.get("max_text")) + + doc.heading(2, "Patterns") + doc.field("Standard Pattern Match", _format_std_pattern(p.get("std_pattern_match"))) + doc.field("Distinct Patterns", p.get("distinct_pattern_ct")) + doc.field("Frequent Patterns", p.get("top_patterns")) + doc.field("Frequent Values", p.get("top_freq_values")) + doc.field("Distinct Standard Values", p.get("distinct_std_value_ct")) + + doc.heading(2, "Case & Composition") + doc.field("Upper Case", p.get("upper_case_ct")) + doc.field("Lower Case", p.get("lower_case_ct")) + doc.field("Mixed Case", p.get("mixed_case_ct")) + doc.field("Non-Alpha", p.get("non_alpha_ct")) + doc.field("Includes Digits", p.get("includes_digit_ct")) + doc.field("Numeric Values", p.get("numeric_ct")) + doc.field("Date Values", p.get("date_ct")) + doc.field("Quoted Values", p.get("quoted_value_ct")) + doc.field("Leading Spaces", p.get("lead_space_ct")) + doc.field("Embedded Spaces", p.get("embedded_space_ct")) + doc.field("Average Embedded Spaces", p.get("avg_embedded_spaces")) + + +def _render_numeric_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + doc.heading(2, "Distribution") + 
doc.field("Minimum Value", p.get("min_value")) + doc.field("Minimum Value > 0", p.get("min_value_over_0")) + doc.field("Maximum Value", p.get("max_value")) + doc.field("Average Value", p.get("avg_value")) + doc.field("Standard Deviation", p.get("stdev_value")) + + doc.heading(2, "Percentiles") + doc.field("25th Percentile", p.get("percentile_25")) + doc.field("Median Value", p.get("percentile_50")) + doc.field("75th Percentile", p.get("percentile_75")) + + +def _render_date_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + doc.heading(2, "Date Range") + doc.field("Minimum Date", p.get("min_date")) + doc.field("Maximum Date", p.get("max_date")) + + doc.heading(2, "Age Buckets") + doc.field("Before 1 Year", p.get("before_1yr_date_ct")) + doc.field("Before 5 Years", p.get("before_5yr_date_ct")) + doc.field("Before 20 Years", p.get("before_20yr_date_ct")) + doc.field("Within 1 Year", p.get("within_1yr_date_ct")) + doc.field("Within 1 Month", p.get("within_1mo_date_ct")) + doc.field("Future Dates", p.get("future_date_ct")) + + +def _render_boolean_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + doc.heading(2, "Boolean Distribution") + true_ct = p.get("boolean_true_ct") or 0 + value_ct = p.get("value_ct") or 0 + false_ct = max(value_ct - true_ct, 0) + doc.field("True Count", true_ct) + doc.field("False Count", false_ct) + + +def _render_unknown_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + +# --------------------------------------------------------------------------- +# Single-column tools — frequent values and patterns +# --------------------------------------------------------------------------- + + +@with_database_session +@mcp_permission("catalog") +def get_column_frequent_values( + table_group_id: str, + table_name: str, + column_name: str, + job_execution_id: str | None = None, +) -> str: + """Get the top frequent values for one column from its profile run, with row counts and percentages. 
+ + Profiling captures the top 10 values; when the column has more distinct values, a + trailing `Other Values (N)` row aggregates the remainder. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + table_name: Table name exactly as stored in TestGen (case-sensitive). + column_name: Column name exactly as stored in TestGen (case-sensitive). + job_execution_id: UUID of a profiling run. When omitted, uses the column's + latest profile run. + """ + tg = resolve_table_group(table_group_id) + profile, profiling_run, pii_flag = _load_profile_for_column(tg, table_name, column_name, job_execution_id) + + doc = MdDoc() + doc.heading(1, f"Frequent values: {table_name}.{column_name}") + doc.field("Table group", tg.id, code=True) + doc.field("Profiling Run", profiling_run.job_execution_id, code=True) + doc.field("Row Count", profile.record_ct) + doc.field("Distinct values", profile.distinct_value_ct) + if pii_flag: + doc.field("PII", _format_pii(pii_flag)) + + rows = parse_top_freq_values(profile.top_freq_values) + if not rows: + doc.text( + f"_Frequency data not available — high cardinality " + f"(distinct count: {profile.distinct_value_ct})._" + ) + return doc.render() + + redact = _is_pii_redacted_for_caller(tg, pii_flag) + record_ct = profile.record_ct or 0 + display_rows: list[list[object]] = [] + for value, count in rows: + pct = (count / record_ct * 100) if record_ct else None + display_value = PII_REDACTED if redact else value + display_rows.append([display_value, count, f"{pct:.2f}%" if pct is not None else None]) + + doc.heading(2, "Top values") + doc.table(["Value", "Count", "% of records"], display_rows) + return doc.render() + + +@with_database_session +@mcp_permission("catalog") +def get_column_patterns( + table_group_id: str, + table_name: str, + column_name: str, + job_execution_id: str | None = None, +) -> str: + """Get the top character patterns for one string column from its profile run. 
+ + Patterns use shorthand: `A` = uppercase letter, `a` = lowercase letter, `N` = digit; + every other character (whitespace, punctuation, symbols) appears literally. Examples: + `Aaaaaaaa` (capitalized word), `NNNN-NN-NN` (ISO-like date), `aaa@aaa.aaa` (email-shaped). + Profiling captures the top 5 patterns. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + table_name: Table name exactly as stored in TestGen (case-sensitive). + column_name: Column name exactly as stored in TestGen (case-sensitive). + job_execution_id: UUID of a profiling run. When omitted, uses the column's + latest profile run. + """ + tg = resolve_table_group(table_group_id) + profile, profiling_run, _ = _load_profile_for_column(tg, table_name, column_name, job_execution_id) + + doc = MdDoc() + doc.heading(1, f"Character patterns: {table_name}.{column_name}") + doc.field("Table group", tg.id, code=True) + doc.field("Profiling Run", profiling_run.job_execution_id, code=True) + doc.field("Row Count", profile.record_ct) + doc.field("Distinct values", profile.distinct_value_ct) + + if profile.general_type and profile.general_type != "A": + doc.text("_Pattern data not available — column is not a string type._") + return doc.render() + + rows = parse_top_patterns(profile.top_patterns) + if not rows: + doc.text( + f"_Pattern data not available — high cardinality " + f"(distinct count: {profile.distinct_value_ct})._" + ) + return doc.render() + + record_ct = profile.record_ct or 0 + display_rows: list[list[object]] = [] + for pattern, count in rows: + pct = (count / record_ct * 100) if record_ct else None + display_rows.append([pattern, count, f"{pct:.2f}%" if pct is not None else None]) + + doc.heading(2, "Top patterns") + doc.table(["Pattern", "Count", "% of records"], display_rows, code=[0]) + return doc.render() + + +# --------------------------------------------------------------------------- +# Cross-scope column-name search +# 
@with_database_session
@mcp_permission("catalog")
def search_columns(
    pattern: str,
    project_code: str | None = None,
    table_group_id: str | None = None,
    limit: int = 100,
    page: int = 1,
) -> str:
    """Search columns by name across one or many projects (bare tokens auto-wrap as `%token%`; explicit `%` honored as a wildcard).

    Args:
        pattern: Column-name search pattern. Case-insensitive.
        project_code: Optional — scope to one project. Mutually exclusive with
            `table_group_id`.
        table_group_id: Optional — scope to one table group. Mutually exclusive
            with `project_code`.
        limit: Page size (default 100, max 500).
        page: Page number starting at 1 (default 1).
    """
    # Argument validation, in the order callers see the errors.
    validate_page(page)
    validate_limit(limit, 500)

    if not (pattern and pattern.strip()):
        raise MCPUserError("`pattern` is required and cannot be empty.")
    like_pattern = build_ilike_pattern(pattern)

    if project_code is not None and table_group_id is not None:
        raise MCPUserError("Pass either `project_code` or `table_group_id`, not both.")

    permissions = get_project_permissions()
    unscoped = project_code is None and table_group_id is None

    # Exactly one scoping clause is produced, whichever branch is taken.
    if table_group_id is not None:
        table_group = resolve_table_group(table_group_id)
        scope_clause = DataColumnChars.table_groups_id == table_group.id
        scope_label = f"table group `{table_group_id}`"
    elif project_code is not None:
        permissions.verify_access(
            project_code,
            not_found=MCPResourceNotAccessible("Project", project_code),
        )
        scope_clause = TableGroup.project_code == project_code
        scope_label = f"project `{project_code}`"
    else:
        # The @mcp_permission decorator guarantees ``allowed_codes`` is non-empty by
        # the time the body runs (it raises MCPPermissionDenied otherwise).
        scope_clause = TableGroup.project_code.in_(list(permissions.allowed_codes))
        scope_label = "all accessible projects"

    hits, total_count = DataColumnChars.search_by_name(
        scope_clause,
        pattern=like_pattern,
        page=page,
        limit=limit,
    )

    # Nothing to show: distinguish "ran past the last page" from "no match at all".
    if not hits and page > 1:
        return f"No columns matching `{pattern}` on page {page} (total: {total_count})."
    if not hits:
        return f"No columns matching `{pattern}` in {scope_label}."

    doc = MdDoc()
    doc.heading(1, f"Columns matching `{pattern}` in {scope_label}")

    if page_info := format_page_info(total_count, page, limit):
        doc.text(page_info)

    # Per-project match summary when no scope was provided.
    if unscoped:
        per_project = DataColumnChars.summarize_matches_by_project(
            scope_clause,
            pattern=like_pattern,
        )
        if per_project:
            summary_table: list[list[object]] = []
            for code_, count in per_project:
                summary_table.append([code_, count])
            doc.heading(2, "Matches by project")
            doc.table(["Project", "Matches"], summary_table, code=[0])

    column_table: list[list[object]] = []
    for hit in hits:
        column_table.append(
            [hit.project_code, hit.table_groups_name, hit.schema_name, hit.table_name, hit.column_name]
        )
    doc.heading(2, "Columns")
    doc.table(
        ["Project", "Table group", "Schema", "Table", "Column"],
        column_table,
        code=[0, 1],
    )

    if footer := format_page_footer(total_count, page, limit):
        doc.text(footer)
    return doc.render()
MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import get_project_permissions, mcp_permission +from testgen.mcp.tools.common import ( + SCORE_CHAIN_LEAF_TO_COLUMN, + SCORE_FILTER_FIELD_TO_COLUMN, + SCORE_GROUP_BY_TO_COLUMN, + DocGroup, + ScoreChainLeafField, + ScoreFilterField, + ScoreGroupBy, + ScoreType, + format_page_footer, + format_page_info, + parse_category, + parse_score_group_by, + parse_score_type, + resolve_scorecard, + resolve_table_group, + validate_limit, + validate_page, +) +from testgen.mcp.tools.markdown import MdDoc +from testgen.utils import friendly_score, friendly_score_impact + +_DOC_GROUP = DocGroup.SCORING + +_DEFAULT_LIMIT = 20 +_MAX_LIMIT = 100 + +_VALUE_MAX_LEN = 256 +_VALUE_FORBIDDEN_CHARS = frozenset("'\";\\\x00") + +# Defensive Python-side cap on grouped output. The category-scores SQL doesn't +# LIMIT, and most valid group_by values produce small bounded result sets +# (≤ ~15 dimensions/domains), but pathological metadata could blow this up. +_ROW_CAP = 100 + +_TOTAL_LABEL = "Total Score" +_CDE_LABEL = "CDE Score" + +_COLUMN_TO_LABEL: dict[str, str] = { + column: group_by.value for group_by, column in SCORE_GROUP_BY_TO_COLUMN.items() +} +# Chain-only fields (mode 2): not exposed as standalone filter fields but valid +# as the leaves of a `table_groups_name → table_name → column_name` chain. 
+_COLUMN_TO_LABEL["table_name"] = "Table" +_COLUMN_TO_LABEL["column_name"] = "Column" + + +_CHAIN_ROOT_FIELD = ScoreFilterField.TABLE_GROUP.value # "Table Group" +_CHAIN_LEAF_FIELDS = tuple(f.value for f in ScoreChainLeafField) # ("Table", "Column") + + +@with_database_session +@mcp_permission("view") +def get_quality_scores( + *, + project_code: str | None = None, + table_group_id: str | None = None, + group_by: str | None = None, + score_type: str | None = None, + filters: list[dict] | None = None, + include_issue_ct: bool = False, + include_impact: bool = False, +) -> str: + """Quality-score rollup with optional grouping and filtering. + + Returns overall Total, CDE, Profiling, and Testing scores by default, + plus an optional breakdown table when ``group_by`` is set. Scope is + project-wide unless ``project_code`` or ``table_group_id`` narrows it. + + **Filters.** Each filter is + ``{"field": "...", "value": "...", "others"?: [...]}``. Same-field values + OR together; different fields AND together. Valid flat fields: + ``"Table Group"``, ``"Data Location"``, ``"Data Source"``, + ``"Source System"``, ``"Source Process"``, ``"Business Domain"``, + ``"Stakeholder Group"``, ``"Transform Level"``, ``"Semantic Data Type"``, + ``"Data Product"``. To target specific tables or columns, chain a + ``"Table Group"`` filter via ``others`` into ``"Table"`` (optionally + then ``"Column"``); sibling chains OR. ``"Impact Dimension"`` and + ``"Quality Dimension"`` are valid as ``group_by`` only, not as filter + fields. Filter values must not contain quotes, semicolons, or + backslashes. ``table_group_id`` cannot be combined with chained + filters — put ``"Table Group"`` in the chain root instead. + + Args: + project_code: Scope to a project. Mutually exclusive with + ``table_group_id``. Omit both to roll across every visible + project. + table_group_id: Scope to a table group, e.g. from + ``get_data_inventory``. 
+ group_by: Break overall scores out by one of: ``"Impact Dimension"``, + ``"Quality Dimension"``, ``"Semantic Data Type"``, + ``"Table Group"``, ``"Data Location"``, ``"Data Source"``, + ``"Source System"``, ``"Source Process"``, ``"Business Domain"``, + ``"Stakeholder Group"``, ``"Transform Level"``, + ``"Data Product"``. + score_type: Narrow returned scores. Omit to show all four (Total, + CDE, Profiling, Testing); pass ``"Total"`` for Total + Profiling + + Testing, or ``"CDE"`` for CDE alone. + filters: List of filter entries. See **Filters** above for shape. + include_issue_ct: Include the count of contributing issues + (hygiene + test failures). + include_impact: Include the per-category percentage impact on the + overall score. Only affects grouped output. + """ + perms = get_project_permissions() + + if project_code is not None and table_group_id is not None: + raise MCPUserError( + "Pass either `project_code` or `table_group_id`, not both." + ) + + parsed_score_type: ScoreType | None = ( + parse_score_type(score_type) if score_type is not None else None + ) + parsed_group_by: ScoreGroupBy | None = ( + parse_score_group_by(group_by) if group_by is not None else None + ) + + user_filters, group_by_field = _validate_filters(filters, allow_empty=True) + + if table_group_id is not None and not group_by_field: + raise MCPUserError( + "`table_group_id` cannot be combined with chained filters — " + "put `Table Group` in the chain root instead." 
def _build_definition(
    *,
    project_code: str,
    table_group_name: str | None,
    group_by: ScoreGroupBy | None,
    score_type: ScoreType | None,
    user_filters: list[dict],
    group_by_field: bool,
) -> ScoreDefinition:
    """Assemble the in-memory ``ScoreDefinition`` that backs ``get_quality_scores``.

    The definition is populated and returned; this function does not save it.

    Args:
        project_code: Project the definition is scoped to.
        table_group_name: When given, appended as a ``table_groups_name`` filter.
        group_by: Optional grouping; mapped to a ``ScoreCategory`` through
            ``SCORE_GROUP_BY_TO_COLUMN``.
        score_type: ``None`` enables both Total and CDE; a specific value
            enables only that one.
        user_filters: Column-form filters as returned by ``_validate_filters``.
            The list is copied, never mutated.
        group_by_field: Forwarded to ``ScoreDefinitionCriteria.from_filters``.
    """
    definition = ScoreDefinition()
    definition.project_code = project_code
    definition.name = "__mcp_get_quality_scores__"

    # `as_score_card` derives `cde_only_categories = cde_score and not
    # total_score`, so the combination of these two flags decides whether the
    # category SQL filters by `critical_data_element = true`.
    every_score = score_type is None
    definition.total_score = every_score or score_type is ScoreType.TOTAL
    definition.cde_score = every_score or score_type is ScoreType.CDE

    if group_by is None:
        definition.category = None
    else:
        definition.category = ScoreCategory(SCORE_GROUP_BY_TO_COLUMN[group_by])

    criteria_filters: list[dict] = [*user_filters]
    if table_group_name is not None:
        criteria_filters.append({"field": "table_groups_name", "value": table_group_name})
    elif not criteria_filters:
        # `as_score_card` short-circuits when criteria has no filters
        # (scores.py:292). Mirror the score-explorer UI's pattern: a scorecard
        # always carries at least one filter, typically `table_groups_name`.
        # For the unfiltered project-wide case, enumerate every table group in
        # the project so the criteria still narrows by project_code (added by
        # `_get_raw_query_filters`) and covers all table groups.
        project_table_groups = TableGroup.select_minimal_where(
            TableGroup.project_code == project_code,
        )
        for table_group in project_table_groups:
            criteria_filters.append(
                {"field": "table_groups_name", "value": table_group.table_groups_name}
            )

    definition.criteria = ScoreDefinitionCriteria.from_filters(
        criteria_filters, group_by_field=group_by_field,
    )
    return definition
show_total: + doc.field(_TOTAL_LABEL, friendly_score(card.get("score"))) + if show_cde: + doc.field(_CDE_LABEL, friendly_score(card.get("cde_score"))) + if show_total: + doc.field("Profiling Score", friendly_score(card.get("profiling_score"))) + doc.field("Testing Score", friendly_score(card.get("testing_score"))) + + if include_issue_ct and group_by is None: + doc.field("Issue Count", definition.get_overall_issue_ct()) + + if group_by is None: + return + + group_by_column = SCORE_GROUP_BY_TO_COLUMN[group_by] + + # Per-category data — score, impact, issue_ct — comes from + # get_score_card_breakdown. One call per enabled score type, since each + # filters different rows (Total includes all data points; CDE filters + # to critical_data_element=true). + total_rows: dict[str, dict] = {} + cde_rows: dict[str, dict] = {} + if show_total: + for r in definition.get_score_card_breakdown("score", group_by_column): + label = r.get(group_by_column) + if label is not None: + total_rows[label] = r + if show_cde: + for r in definition.get_score_card_breakdown("cde_score", group_by_column): + label = r.get(group_by_column) + if label is not None: + cde_rows[label] = r + + all_labels = set(total_rows) | set(cde_rows) + if not all_labels: + if user_filters: + doc.text("_Filter matched no data._") + else: + doc.text("_No category data._") + return + + # Worst score first. Sort by primary column (Total if shown, else CDE). 
@with_database_session
@mcp_permission("view")
def list_scorecards(
    project_code: str,
    page: int = 1,
    limit: int = _DEFAULT_LIMIT,
) -> str:
    """List the scorecards defined in a project.

    Args:
        project_code: Project to list scorecards for.
        page: Page number, starting at 1.
        limit: Page size (max 100).
    """
    validate_page(page)
    validate_limit(limit, _MAX_LIMIT)

    # Access check comes before any scorecard data is read.
    get_project_permissions().verify_access(
        project_code,
        not_found=MCPResourceNotAccessible("Project", project_code),
    )

    scorecards, total_count = ScoreDefinition.list_for_project(
        project_code, page=page, limit=limit,
    )

    doc = MdDoc()
    doc.heading(1, f"Scorecards in Project `{project_code}`")

    if page_info := format_page_info(total_count, page, limit):
        doc.text(page_info)

    if not scorecards:
        # Past-the-end page vs. a project with no scorecards at all.
        empty_note = (
            f"_No scorecards on page {page} (total: {total_count})._"
            if page > 1
            else "_No scorecards configured._"
        )
        doc.text(empty_note)
        return doc.render()

    for scorecard in scorecards:
        doc.heading(2, f"{scorecard.name} (id: `{scorecard.id}`)")
        cached = scorecard.as_cached_score_card()

        # Field order is Total, CDE, Profiling, Testing; Profiling and Testing
        # are only shown alongside the Total score.
        wants_total = scorecard.total_score
        if wants_total:
            doc.field(_TOTAL_LABEL, friendly_score(cached.get("score")))
        if scorecard.cde_score:
            doc.field(_CDE_LABEL, friendly_score(cached.get("cde_score")))
        if wants_total:
            doc.field("Profiling Score", friendly_score(cached.get("profiling_score")))
            doc.field("Testing Score", friendly_score(cached.get("testing_score")))

        category = scorecard.category
        if category is not None:
            doc.field("Category", _column_label(category.value))
        doc.field("Filters", _format_criteria_summary(scorecard.criteria))

    if footer := format_page_footer(total_count, page, limit):
        doc.text(footer)

    return doc.render()
def _render_breakdown(doc: MdDoc, definition: ScoreDefinition) -> None:
    """Render the cached per-category breakdown table for a scorecard.

    Total and CDE rows are merged by label so the same category value shows
    on one line with both score types. Rows are sorted by the impact of the
    primary score type (Total when enabled, otherwise CDE), highest first,
    with the label as a tie-break so the output — and the subset kept when
    the table is capped — is deterministic. At most ``_ROW_CAP`` rows are
    rendered; a truncation note is added when rows were dropped.

    Args:
        doc: Markdown document being built; appended to in place.
        definition: Scorecard whose ``category`` is set (the caller checks
            this before calling).
    """
    category_column = definition.category.value
    category_label = _column_label(category_column)
    doc.heading(2, f"Breakdown by {category_label}")

    show_total = definition.total_score
    show_cde = definition.cde_score

    def _cached_rows(score_type: str) -> dict[str, dict]:
        # Load the cached breakdown rows for one score type, keyed by display label.
        rows_by_label: dict[str, dict] = {}
        for item in ScoreDefinitionBreakdownItem.filter(
            definition_id=definition.id,
            category=category_column,
            score_type=score_type,
        ):
            row = item.to_dict()
            label = _row_label(row, category_column)
            if label is not None:
                rows_by_label[label] = row
        return rows_by_label

    total_rows = _cached_rows("score") if show_total else {}
    cde_rows = _cached_rows("cde_score") if show_cde else {}

    all_labels = set(total_rows) | set(cde_rows)
    if not all_labels:
        doc.text("_No breakdown data._")
        return

    primary = total_rows if show_total else cde_rows

    def _sort_key(label: str) -> tuple:
        impact = (primary.get(label) or {}).get("impact")
        # Negate so the highest impact sorts first. The label tie-break keeps
        # equal-impact rows in a stable order; without it they would follow
        # set-iteration order, which is not stable across processes.
        return (-(impact if impact is not None else 0.0), str(label))

    # Highest impact first — same ordering as the cached rows from the model.
    sorted_labels = sorted(all_labels, key=_sort_key)
    row_count = len(sorted_labels)
    capped = sorted_labels[:_ROW_CAP]

    both_shown = show_total and show_cde
    total_issue_header = "Issue Count (Total)" if both_shown else "Issue Count"
    cde_issue_header = "Issue Count (CDE)" if both_shown else "Issue Count"

    headers: list[str] = [category_label]
    if show_total:
        headers.extend([_TOTAL_LABEL, "Impact on Total Score", total_issue_header])
    if show_cde:
        headers.extend([_CDE_LABEL, "Impact on CDE Score", cde_issue_header])

    md_rows: list[list[object]] = []
    for label in capped:
        cells: list[object] = [label]
        # A label may exist for only one score type; the missing side yields
        # None cells.
        total_row = total_rows.get(label) or {}
        cde_row = cde_rows.get(label) or {}
        if show_total:
            cells.append(friendly_score(total_row.get("score")))
            cells.append(_format_impact(total_row.get("impact")))
            cells.append(total_row.get("issue_ct"))
        if show_cde:
            cells.append(friendly_score(cde_row.get("score")))
            cells.append(_format_impact(cde_row.get("impact")))
            cells.append(cde_row.get("issue_ct"))
        md_rows.append(cells)
    doc.table(headers, md_rows)

    if row_count > _ROW_CAP:
        doc.text(f"_Showing top {_ROW_CAP} of {row_count} rows by highest impact._")
To target specific tables or columns, chain a + ``"Table Group"`` filter via ``others`` into ``"Table"`` (optionally + then ``"Column"``); sibling chains OR. + + Args: + project_code: Project that will own the scorecard. + name: Scorecard name. Must be non-empty. + filters: List of filter entries. See **Filters** above for shape. + category: Category for per-bucket breakdown. One of + ``"Quality Dimension"``, ``"Impact Dimension"``, + ``"Data Source"``, ``"Business Domain"``, ``"Stakeholder Group"``, + ``"Table Group"``, ``"Transform Level"``, ``"Data Location"``, + ``"Source System"``, ``"Source Process"``, ``"Data Product"``. + show_total_score: Whether the scorecard exposes the Total Score. + show_cde_score: Whether the scorecard exposes the CDE Score. + """ + perms = get_project_permissions() + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + + if not name.strip(): + raise MCPUserError("`name` must be non-empty.") + + parsed_filters, group_by_field = _validate_filters(filters) + category_value = parse_category(category) if category is not None else None + + definition = ScoreDefinition() + definition.project_code = project_code + definition.name = name + definition.total_score = show_total_score + definition.cde_score = show_cde_score + definition.category = category_value + definition.criteria = ScoreDefinitionCriteria.from_filters( + parsed_filters, + group_by_field=group_by_field, + ) + + save_and_refresh_score_definition(definition, is_new=True) + + doc = MdDoc() + doc.heading(1, f"Scorecard `{definition.name}` created") + doc.field("ID", definition.id, code=True) + doc.field("Project", definition.project_code, code=True) + doc.field(_TOTAL_LABEL, "Yes" if show_total_score else "No") + doc.field(_CDE_LABEL, "Yes" if show_cde_score else "No") + if category_value is not None: + doc.field("Category", _column_label(category_value.value)) + doc.field("Filters", 
@with_database_session
@mcp_permission("edit")
def update_scorecard(
    scorecard_id: str,
    *,
    name: str | None = None,
    show_total_score: bool | None = None,
    show_cde_score: bool | None = None,
    category: str | None = None,
    filters: list[dict] | None = None,
) -> str:
    """Update fields on an existing scorecard. Pass only the fields to change.

    **Filters.** When supplied, ``filters`` replaces the scorecard's filters
    wholesale and at least one entry is required. Each entry is
    ``{"field": "...", "value": "...", "others"?: [...]}``. Same-field values
    OR together; different fields AND together. Valid flat fields:
    ``"Table Group"``, ``"Data Location"``, ``"Data Source"``,
    ``"Source System"``, ``"Source Process"``, ``"Business Domain"``,
    ``"Stakeholder Group"``, ``"Transform Level"``, ``"Semantic Data Type"``,
    ``"Data Product"``. To target specific tables or columns, chain a
    ``"Table Group"`` filter via ``others`` into ``"Table"`` (optionally
    then ``"Column"``); sibling chains OR.

    Args:
        scorecard_id: UUID returned by ``list_scorecards`` or
            ``get_data_inventory``.
        name: New scorecard name. Must be non-empty when supplied.
        show_total_score: Whether the scorecard exposes the Total Score.
        show_cde_score: Whether the scorecard exposes the CDE Score.
        category: Category for per-bucket breakdown. One of
            ``"Quality Dimension"``, ``"Impact Dimension"``,
            ``"Data Source"``, ``"Business Domain"``, ``"Stakeholder Group"``,
            ``"Table Group"``, ``"Transform Level"``, ``"Data Location"``,
            ``"Source System"``, ``"Source Process"``, ``"Data Product"``.
            Pass ``""`` to clear an existing category.
        filters: List of filter entries. See **Filters** above for shape.
    """
    definition = resolve_scorecard(scorecard_id)

    # Parse everything first (category, then filters) so malformed input is
    # rejected before any attribute on the definition is touched.
    wants_clear = category == ""
    parsed_category: ScoreCategory | None = None
    if not (category is None or wants_clear):
        parsed_category = parse_category(category)

    validated = _validate_filters(filters) if filters is not None else None

    # Insertion order here is also the row order of the diff table below.
    changes: dict = {}
    if name is not None:
        if not name.strip():
            raise MCPUserError("`name` must be non-empty.")
        changes["name"] = name
    if show_total_score is not None:
        changes["total_score"] = show_total_score
    if show_cde_score is not None:
        changes["cde_score"] = show_cde_score
    if parsed_category is not None:
        changes["category"] = parsed_category
    elif wants_clear:
        changes["category"] = None
    if validated is not None:
        column_filters, flat_mode = validated
        changes["criteria"] = ScoreDefinitionCriteria.from_filters(
            column_filters,
            group_by_field=flat_mode,
        )

    if not changes:
        raise MCPUserError("No fields supplied to update.")

    # Snapshot display values on both sides of the mutation for the diff.
    previous = _snapshot_for_diff(definition, changes)
    for attribute in changes:
        setattr(definition, attribute, changes[attribute])
    current = _snapshot_for_diff(definition, changes)

    save_and_refresh_score_definition(definition, is_new=False)

    diff_rows: list[list[object]] = []
    for attribute in changes:
        diff_rows.append([_DIFF_LABELS[attribute], previous[attribute], current[attribute]])

    doc = MdDoc()
    doc.heading(1, f"Scorecard `{definition.name}` updated")
    doc.field("ID", definition.id, code=True)
    doc.field("Project", definition.project_code, code=True)
    doc.table(["Field", "Before", "After"], diff_rows, code=[0])
    return doc.render()
def _validate_filters(
    raw_filters: list[dict] | None, *, allow_empty: bool = False,
) -> tuple[list[dict], bool]:
    """Validate user-supplied filter shape and translate to column-form storage.

    Returns ``(parsed_filters, group_by_field)``. Input ``field`` values are
    display-form (e.g. ``"Table Group"``, ``"Data Source"``, ``"Table"``,
    ``"Column"``); the returned dicts use the underlying DB column names
    (e.g. ``"table_groups_name"``, ``"table_name"``).

    Two storage modes (selectable per call, not mutually exclusive across
    callers):

    * Mode 1 (flat, ``group_by_field=True``): every filter is a single
      ``(field, value)`` pair using one of the values from ``ScoreFilterField``.
    * Mode 2 (chained, ``group_by_field=False``): each chained filter roots at
      ``"Table Group"`` and chains only into ``"Table"`` then ``"Column"``. A
      flat ``"Table Group"`` filter is also valid here.

    Errors are collected across every offending entry and reported in one
    ``MCPUserError`` so callers see every problem at once rather than chasing
    one fix at a time. Structurally malformed input — a filter or chain entry
    that is not a dict, a non-string ``field``, an ``others`` that is not a
    list — is reported the same way instead of escaping as ``AttributeError``
    or ``TypeError``.

    When ``allow_empty=True``, ``None`` / ``[]`` short-circuits to
    ``([], True)``. With the default ``allow_empty=False``, empty input raises.

    Raises:
        MCPUserError: On empty input (unless ``allow_empty``) or when any
            entry fails validation.
    """
    if not raw_filters:
        if allow_empty:
            return [], True
        raise MCPUserError("At least one filter is required.")

    errors: list[str] = []
    # (index, filter) pairs that passed the basic shape check. Entries with an
    # unsafe *value* stay in, so their field-level problems are reported too.
    candidates: list[tuple[int, dict]] = []
    for index, filter_ in enumerate(raw_filters):
        if not isinstance(filter_, dict):
            errors.append(
                f"filters[{index}] must be an object with `field` and `value`."
            )
            continue
        if not filter_.get("field") or not filter_.get("value"):
            errors.append(
                f"filters[{index}] must have non-empty `field` and `value`."
            )
            continue
        if not isinstance(filter_["field"], str):
            # A non-string (possibly unhashable) field would break the set
            # membership test and enum lookup below.
            errors.append(f"filters[{index}]: `field` must be a string.")
            continue
        errors.extend(
            f"filters[{index}] {err}"
            for err in _filter_value_errors(filter_["value"], filter_["field"])
        )
        candidates.append((index, filter_))

    # One chained entry anywhere switches the whole call to mode 2.
    has_chain = any(
        isinstance(filter_, dict) and filter_.get("others")
        for filter_ in raw_filters
    )

    if not has_chain:
        valid_mode_1_fields = {f.value for f in ScoreFilterField}
        parsed: list[dict] = []
        for index, filter_ in candidates:
            field = filter_["field"]
            if field not in valid_mode_1_fields:
                valid = ", ".join(sorted(valid_mode_1_fields))
                errors.append(
                    f"filters[{index}]: `{field}` is not a valid scorecard filter "
                    f"field. To target specific tables or columns, chain a "
                    f"`{_CHAIN_ROOT_FIELD}` filter with `others`: "
                    f'[{{"field": "Table", "value": "..."}}]. '
                    f"Valid flat fields: {valid}."
                )
                continue
            parsed.append({
                "field": SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField(field)],
                "value": filter_["value"],
            })
        if errors:
            raise MCPUserError("Invalid filters: " + "; ".join(errors))
        return parsed, True

    parsed_chained: list[dict] = []
    for index, filter_ in candidates:
        field = filter_["field"]
        others = filter_.get("others") or []
        if others and field != _CHAIN_ROOT_FIELD:
            errors.append(
                f"filters[{index}]: chained filters must root at "
                f"`{_CHAIN_ROOT_FIELD}`, got `{field}`."
            )
            continue
        if not others and field != _CHAIN_ROOT_FIELD:
            errors.append(
                f"filters[{index}]: when any filter chains tables/columns, "
                f"all filters must root at `{_CHAIN_ROOT_FIELD}`. Got `{field}`."
            )
            continue
        if not isinstance(others, list):
            # e.g. a single dict passed where a list of chain steps is expected.
            errors.append(
                f"filters[{index}].others must be a list of chain entries."
            )
            continue

        translated_others: list[dict] = []
        chain_errors = False
        for chain_index, chain in enumerate(others):
            if not isinstance(chain, dict):
                errors.append(
                    f"filters[{index}].others[{chain_index}] must be an object "
                    f"with `field` and `value`."
                )
                chain_errors = True
                continue
            if not chain.get("field") or not chain.get("value"):
                errors.append(
                    f"filters[{index}].others[{chain_index}] must have "
                    f"non-empty `field` and `value`."
                )
                chain_errors = True
                continue
            chain_field = chain["field"]
            if chain_field not in _CHAIN_LEAF_FIELDS:
                errors.append(
                    f"filters[{index}].others[{chain_index}]: `{chain_field}` "
                    f"is not a valid chain field. Chains may only descend into "
                    f"{' or '.join(f'`{f}`' for f in _CHAIN_LEAF_FIELDS)}."
                )
                chain_errors = True
                continue
            value_errors = _filter_value_errors(chain["value"], chain_field)
            if value_errors:
                errors.extend(
                    f"filters[{index}].others[{chain_index}] {err}"
                    for err in value_errors
                )
                chain_errors = True
                continue
            translated_others.append({
                "field": SCORE_CHAIN_LEAF_TO_COLUMN[ScoreChainLeafField(chain_field)],
                "value": chain["value"],
            })

        # Ordering rules: `Column` needs a preceding `Table` step and must be
        # the last step. Non-dict entries count as "no field" here.
        chain_field_values = [
            chain.get("field") if isinstance(chain, dict) else None
            for chain in others
        ]
        if chain_field_values == [ScoreChainLeafField.COLUMN.value]:
            errors.append(
                f"filters[{index}]: a `Column` chain requires a `Table` step before it."
            )
            continue
        if ScoreChainLeafField.COLUMN.value in chain_field_values[:-1]:
            errors.append(
                f"filters[{index}]: `Column` must be the final chain step."
            )
            continue

        if chain_errors:
            continue

        parsed_chained.append({
            "field": SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField.TABLE_GROUP],
            "value": filter_["value"],
            "others": translated_others,
        })

    if errors:
        raise MCPUserError("Invalid filters: " + "; ".join(errors))
    return parsed_chained, False
def _format_mode_2_summary(criteria: ScoreDefinitionCriteria) -> str:
    """Render mode-2 (chained) filters with OR semantics and leaf collapse.

    Roots are grouped by ``(field, value)`` in first-seen order. Within a group:

    * no root carries a chain: the group renders as just ``Root = value``;
    * every root carries a chain and all chains use the same field sequence:
      the values at each chain position collapse to ``Label = v`` or
      ``Label in (...)``;
    * otherwise (different chain shapes, or a bare root next to chained ones):
      each root renders on its own and the pieces are OR-joined.

    Repeated values inside an ``in (...)`` list and repeated sub-branches are
    emitted once, in first-seen order.

    Args:
        criteria: Scorecard criteria whose ``filters`` are the chain roots; each
            root links to its next chain step through ``next_filter``.

    Returns:
        The summary string. A single group is returned as-is; several groups
        are OR-joined, with any group containing `` AND `` parenthesized.
    """
    grouped: dict[tuple[str, str], list[ScoreDefinitionFilter]] = defaultdict(list)
    for root in criteria.filters:
        grouped[(root.field, root.value)].append(root)

    branches: list[str] = []
    for (root_field, root_value), siblings in grouped.items():
        root_part = f"{_column_label(root_field)} = {root_value}"

        # Flatten each sibling's linked chain into a list of (field, value) steps.
        chain_paths: list[list[tuple[str, str]]] = []
        for root in siblings:
            path: list[tuple[str, str]] = []
            current = root.next_filter
            while current is not None:
                path.append((current.field, current.value))
                current = current.next_filter
            chain_paths.append(path)

        non_empty_paths = [p for p in chain_paths if p]
        has_empty = any(not p for p in chain_paths)

        if not non_empty_paths:
            branches.append(root_part)
            continue

        same_shape = len({tuple(field for field, _ in p) for p in non_empty_paths}) == 1
        if same_shape and not has_empty:
            leaf_fields = [field for field, _ in non_empty_paths[0]]
            leaf_parts: list[str] = []
            for i, field in enumerate(leaf_fields):
                # De-duplicate while keeping first-seen order, so siblings that
                # share a value at this position don't render it twice
                # (e.g. ``Table in (t1, t1, t2)``).
                unique_values = list(dict.fromkeys(p[i][1] for p in non_empty_paths))
                label = _column_label(field)
                if len(unique_values) == 1:
                    leaf_parts.append(f"{label} = {unique_values[0]}")
                else:
                    leaf_parts.append(f"{label} in ({', '.join(unique_values)})")
            branches.append(f"{root_part} AND {' AND '.join(leaf_parts)}")
        else:
            sub_branches: list[str] = []
            for path in chain_paths:
                if not path:
                    sub_branches.append(root_part)
                else:
                    leaves = [f"{_column_label(field)} = {value}" for field, value in path]
                    sub_branches.append(f"({root_part} AND {' AND '.join(leaves)})")
            # Identical siblings would otherwise repeat the same clause.
            branches.append(" OR ".join(dict.fromkeys(sub_branches)))

    if len(branches) == 1:
        return branches[0]
    return " OR ".join(f"({b})" if " AND " in b else b for b in branches)
"""\ +# TestGen Column Profile Fields Reference + +Column profiling stores ~70 statistics per column. The fields populated +depend on the column's `General Type` (Alpha / Numeric / Date / Boolean / Other). The +`get_column_profile_detail` tool emits only the fields relevant to a column's type — use +this reference to interpret what each field measures. + +## All Column Types + +These fields are populated for every successfully-profiled column. + +### Header +- **Profiling Run** — `job_execution_id` of the profiling run the rest of the fields come from. +- **Profiled at** — Timestamp when the profiling run started (`YYYY-MM-DD HH:MM UTC`). +- **General Type** — Broad category: `Alpha`, `Numeric`, `Date`, `Boolean`, `Time`, or `Other`. +- **Data Type** — Native DB type as reported by the source (e.g. `varchar(50)`, `numeric(18,4)`). +- **Semantic Data Type** — TestGen's functional classification (e.g. `Person Given Name`, `Currency`, `Datetime-Created`). +- **Suggested Data Type** — Suggested narrower DB type given observed values (e.g. `VARCHAR(20)`, `INTEGER`). Omitted when no suggestion applies. +- **PII** — `No` when the column has no PII flag; `Yes` when manually flagged; otherwise `Yes ( Risk[ - ][ / ])` — Risk is `High`, `Moderate`, or `Low`; Category is `ID`, `Name`, `Demographic`, or `Contact`; Detail is a subtype (e.g. `Email`, `Passport`) when present. +- **Critical Data Element** — `Yes` if the column is flagged as critical (directly or via its parent table), `No` otherwise. +- **Profiling Score** — Aggregated profiling-derived quality score, 0-100. +- **Testing Score** — Aggregated testing-derived quality score, 0-100. +- **Hygiene Issues (confirmed)** — Confirmed hygiene issues against this column (count). Omitted when the column has a profiling error. + +### Counts +- **Row Count** — Total rows in the table (count, integer). +- **Value Count** — Non-null values in this column (count, integer). 
+- **Distinct Values** — Distinct non-null values (count, integer). +- **Null** — Null values (count, integer). +- **Dummy Values** — Dummy / placeholder values like `'?'`, `'-'`, `'unknown'` (count, integer). +- **Zero Values** — Exact-zero or `'0'`-string values (count, integer). Populated for numeric and alpha columns. + +## Alpha (text) Columns + +Populated when `General Type == "Alpha"`. + +### Length +- **Minimum Length** — Shortest string length (chars). +- **Maximum Length** — Longest string length (chars). +- **Average Length** — Average string length (chars, float). + +### Text Range +- **Minimum Text** — Lexicographic minimum value (raw string; **PII-redactable**). +- **Maximum Text** — Lexicographic maximum value (raw string; **PII-redactable**). + +### Patterns +- **Standard Pattern Match** — Recognized standard pattern when applicable (`Email`, `Phone (USA)`, + `Street Address`, `State (USA)`, `Zip Code (USA)`, `Filename`, `Credit Card`, `Delimited Data`, `SSN (USA)`). +- **Distinct Patterns** — Distinct character-class patterns observed (count). +- **Frequent Patterns** — Top patterns and counts, pipe-separated. +- **Frequent Values** — Top frequent raw values and counts (raw strings; **PII-redactable**). +- **Distinct Standard Values** — Distinct values after standardization (count). + +### Case & Composition +- **Upper Case / Lower Case / Mixed Case / Non-Alpha** — Case-distribution counts. +- **Includes Digits** — Values containing at least one digit (count). +- **Numeric Values** — Values parseable as numeric (count). +- **Date Values** — Values parseable as a date (count). +- **Quoted Values** — Values wrapped in quotes (count). +- **Leading Spaces** — Values with leading whitespace (count). +- **Embedded Spaces** — Values with internal whitespace (count). +- **Average Embedded Spaces** — Average embedded-space count per value (float). +- **Zero Length** — Empty strings (count). 
+ +## Numeric Columns + +Populated when `General Type == "Numeric"`. + +### Distribution +- **Minimum Value** — Minimum numeric value (raw value; **PII-redactable**). +- **Minimum Value > 0** — Minimum value strictly greater than zero (**PII-redactable**). +- **Maximum Value** — Maximum numeric value (**PII-redactable**). +- **Average Value** — Arithmetic mean. +- **Standard Deviation** — Standard deviation. + +### Percentiles +- **25th Percentile** — 25th percentile (Q1). +- **Median Value** — Median (Q2 / 50th percentile). +- **75th Percentile** — 75th percentile (Q3). + +## Date Columns + +Populated when `General Type == "Date"`. + +### Date Range +- **Minimum Date** — Minimum timestamp (**PII-redactable**). +- **Maximum Date** — Maximum timestamp (**PII-redactable**). + +### Age Buckets +- **Before 1 Year** — Values older than 1 year from profiling date (count). +- **Before 5 Years** — Values older than 5 years (count). +- **Before 20 Years** — Values older than 20 years (count). +- **Within 1 Year** — Values within the past year (count). +- **Within 1 Month** — Values within the past month (count). +- **Future Dates** — Values dated after the profiling date (count). + +## Boolean Columns + +Populated when `General Type == "Boolean"`. + +- **True Count** — Rows where the value is true (count). +- **False Count** — Rows where the value is false (count, derived as `Value Count - True Count`). + +## PII Redaction + +When a column is flagged as PII AND the caller's role lacks permission to view PII on the column's +project, the following raw-value fields render as `[PII Redacted]`: + +- Frequent Values +- Minimum Text +- Maximum Text +- Minimum Value +- Minimum Value > 0 +- Maximum Value +- Minimum Date +- Maximum Date + +Aggregates, counts, `Frequent Patterns`, and `Standard Pattern Match` are never redacted — they're +distribution-level signals that don't expose individual rows. + +## Semantic Data Type — values emitted by profiling, grouped by family. 
+ +**Identifiers**: `ID`, `ID-FK`, `ID-Group`, `ID-Secondary`, `ID-SK`, +`ID-Unique`, `ID-Unique-SK` + +**Dates & schedules**: `Date Stamp`, `DateTime Stamp`, `Schedule Date`, +`Future Date`, `Historical Date`, `Transactional Date`, +`Transactional Date (Mo)`, `Transactional Date (Qtr)`, +`Transactional Date (Wk)` + +**Periods**: `Period`, `Period DOW`, `Period Mon-NN`, `Period Month`, +`Period Quarter`, `Period Week`, `Period Year`, `Period Year-Mon` + +**People**: `Person Full Name`, `Person Given Name`, `Person Last Name` + +**Location & contact**: `Address`, `City`, `State`, `Zip`, `Email`, `Phone` + +**Measurements**: `Measurement`, `Measurement Discrete`, `Measurement Pct`, +`Measurement Spike`, `Measurement Text` + +**Codes, flags, attributes**: `Attribute`, `Boolean`, `Code`, `Constant`, +`Flag`, `Sequence` + +**Entity & system**: `Entity Name`, `Process`, `Process User`, `System User` + +The `semantic_data_type` filter on `list_column_profiles` matches via `ILIKE`, +so partial inputs catch related variants (e.g. `ID` matches `ID`, `ID-FK`, +`ID-Group`, …). 
def _kind_display(key: str) -> str:
    """Return the human-readable name of the job kind behind a schedule.

    Only ``JobKey.run_profile`` is special-cased; every other key is labelled
    as a test run.
    """
    is_profiling = key == JobKey.run_profile
    return "Profiling Run" if is_profiling else "Test Run"
def _validate_cron(cron_expression: str, cron_tz: str) -> str:
    """Check a cron expression / timezone pair and return its readable form.

    Both arguments must be non-empty. The pair is then handed to
    ``get_cron_sample``; a result carrying an ``error`` key is reported as an
    ``MCPUserError``, otherwise the result's ``readable_expr`` is returned.
    """
    required = (
        (cron_expression, "`cron_expression` is required."),
        (cron_tz, "`cron_tz` is required (IANA name, e.g. `UTC`)."),
    )
    for supplied, complaint in required:
        if not supplied:
            raise MCPUserError(complaint)

    result = get_cron_sample(cron_expression, cron_tz, sample_count=1)
    if "error" not in result:
        return result["readable_expr"]
    raise MCPUserError(f"Invalid cron expression or timezone: {result['error']}")
``None`` if the row is malformed.""" + if sched.key == JobKey.run_profile: + return sched.kwargs.get("table_group_id") + return sched.kwargs.get("test_suite_id") + + +def _format_linked(sched: JobSchedule, name: str | None) -> str: + """Combined ``: `name` (ID: `uuid`)`` line used by both detail block and list rows.""" + linked_id = _linked_entity_id(sched) + name_part = f"`{name}`" if name else "—" + id_part = f" (ID: `{linked_id}`)" if linked_id else "" + return f"{name_part}{id_part}" + + +def _next_run(sched: JobSchedule) -> datetime | None: + try: + return sched.get_sample_triggering_timestamps(1)[0] + except Exception: + return None + + +def _render_schedule( + doc: MdDoc, + sched: JobSchedule, + *, + linked_name: str | None, + include_next_runs: int = 1, +) -> None: + doc.field("Schedule ID", sched.id, code=True) + doc.field("Type", _kind_display(sched.key)) + doc.field(_linked_kind_label(sched.key), _format_linked(sched, linked_name)) + doc.field("Cron expression", sched.cron_expr, code=True) + if (readable := describe_cron(sched.cron_expr)) is not None: + doc.field("Cron description", readable) + doc.field("Timezone", sched.cron_tz) + doc.field("Status", "Active" if sched.active else "Paused") + if include_next_runs > 0: + try: + next_times = sched.get_sample_triggering_timestamps(include_next_runs) + except Exception: + next_times = [] + if next_times: + label = "Next run" if include_next_runs == 1 else "Next runs" + doc.field(label, ", ".join(_format_dt(t) for t in next_times)) + + +def _format_dt(value: datetime | None) -> str: + if value is None: + return "—" + return value.strftime("%Y-%m-%d %H:%M %Z") or value.strftime("%Y-%m-%d %H:%M") + + +def _resolve_linked_names(schedules: list[JobSchedule]) -> dict[tuple[str, str], str]: + """Batch-fetch linked-entity names for a list of schedules. Avoids N+1. + + Returns a dict keyed by (kind, id) where kind ∈ {'tg', 'suite'} and id is the UUID string. 
def _linked_name(sched: JobSchedule, names: dict[tuple[str, str], str]) -> str | None:
    """Look up the display name of the entity a schedule is linked to.

    ``names`` is keyed by ``(kind, id)``, where kind is ``"tg"`` for profiling
    schedules and ``"suite"`` for everything else. Returns ``None`` when the
    schedule has no linked id or the id is not present in ``names``.
    """
    entity_id = _linked_entity_id(sched)
    if entity_id is not None:
        if sched.key == JobKey.run_profile:
            lookup_key = ("tg", entity_id)
        else:
            lookup_key = ("suite", entity_id)
        return names.get(lookup_key)
    return None
@with_database_session
@mcp_permission("edit")
def update_schedule(
    schedule_id: str,
    cron_expression: str | None = None,
    cron_tz: str | None = None,
    active: bool | None = None,
) -> str:
    """Update a schedule's cron, timezone, or active state. Atomic — no partial save.

    The job type and linked configuration are immutable — delete and recreate to change them.

    Args:
        schedule_id: UUID of the schedule, e.g. from ``list_schedules``.
        cron_expression: New cron expression. Omit to leave unchanged.
        cron_tz: New IANA timezone. Omit to leave unchanged.
        active: ``True`` to resume, ``False`` to pause. Omit to leave unchanged.
    """
    # Raises MCPUserError when nothing was supplied, or (via _validate_cron)
    # when the resulting cron/timezone pair is invalid. Returns rendered
    # markdown: a before/after table, or a "no fields changed" notice.
    if cron_expression is None and cron_tz is None and active is None:
        raise MCPUserError("No fields supplied to update.")

    sched = resolve_schedule(schedule_id)

    # Validate the *effective* pair (new value where given, stored value
    # otherwise) before touching the row, so a bad input never half-applies.
    new_expr = cron_expression if cron_expression is not None else sched.cron_expr
    new_tz = cron_tz if cron_tz is not None else sched.cron_tz
    if cron_expression is not None or cron_tz is not None:
        _validate_cron(new_expr, new_tz)

    # Record only the fields whose value actually differs from the stored one.
    changes: list[tuple[str, object, object]] = []
    if cron_expression is not None and cron_expression != sched.cron_expr:
        changes.append(("Cron expression", sched.cron_expr, cron_expression))
        sched.cron_expr = cron_expression
    if cron_tz is not None and cron_tz != sched.cron_tz:
        changes.append(("Timezone", sched.cron_tz, cron_tz))
        sched.cron_tz = cron_tz
    if active is not None and active != sched.active:
        before = "Active" if sched.active else "Paused"
        after = "Active" if active else "Paused"
        changes.append(("Status", before, after))
        sched.active = active

    # Nothing was assigned when every supplied value matched the stored one,
    # so there is nothing to persist — skip the redundant write.
    if changes:
        sched.save()

    doc = MdDoc()
    doc.heading(1, "Schedule updated")
    doc.field("Schedule ID", sched.id, code=True)
    if not changes:
        doc.text("No fields changed — supplied values matched the current state.")
        return doc.render()
    doc.table(["Field", "Before", "After"], [list(c) for c in changes])
    return doc.render()
@with_database_session
@mcp_permission("view")
def list_schedules(
    project_code: str,
    schedule_type: str | None = None,
    limit: int = 20,
    page: int = 1,
) -> str:
    """List schedules for a project — profiling and test run schedules.

    Args:
        project_code: Project to scope to, e.g. from ``list_projects``.
        schedule_type: Optional filter — ``profiling_run`` or ``test_run``.
        limit: Max rows per page. Defaults to 20.
        page: 1-indexed page number. Defaults to 1.
    """
    # Paging arguments are checked first, ahead of the permission lookup and query.
    validate_page(page)
    validate_limit(limit, 100)

    if project_code not in get_project_permissions().allowed_codes:
        raise MCPResourceNotAccessible("Project", project_code)

    # No schedule_type means no key restriction; otherwise narrow to the one
    # job key that corresponds to the requested schedule type.
    key_filter: list[JobKey] | None = (
        None
        if schedule_type is None
        else [_SCHEDULE_TYPE_TO_JOB_KEY[_parse_schedule_type(schedule_type)]]
    )

    schedules, total = JobSchedule.list_for_project(
        project_code,
        key_filter=key_filter,
        page=page,
        limit=limit,
    )

    doc = MdDoc()
    doc.heading(1, f"Schedules — `{project_code}`")
    if page_info := format_page_info(total, page, limit):
        doc.text(page_info)
    if not schedules:
        doc.text("_No schedules._")
        return doc.render()

    # One batched lookup for the names of every linked table group / test suite.
    names_by_entity = _resolve_linked_names(schedules)
    table_rows: list[list[object]] = [
        [
            sched.id,
            _kind_display(sched.key),
            f"{_linked_kind_label(sched.key)}: {_format_linked(sched, _linked_name(sched, names_by_entity))}",
            sched.cron_expr,
            sched.cron_tz,
            "Active" if sched.active else "Paused",
            _format_dt(_next_run(sched)),
        ]
        for sched in schedules
    ]
    doc.table(
        ["Schedule ID", "Type", "Details", "Cron", "Timezone", "Status", "Next run"],
        table_rows,
        code=[0, 3],
    )
    if page_footer := format_page_footer(total, page, limit):
        doc.text(page_footer)
    return doc.render()
+ ) + return doc.render() diff --git a/testgen/mcp/tools/source_data.py b/testgen/mcp/tools/source_data.py index 1e75b78c..1b3fb0bb 100644 --- a/testgen/mcp/tools/source_data.py +++ b/testgen/mcp/tools/source_data.py @@ -101,8 +101,7 @@ def get_source_data( validate_limit(limit, 500) context = _resolve_context(test_definition_id, reference_date) - perms = get_project_permissions() - mask_pii = context.get("project_code") not in perms.codes_allowed_to("view_pii") + mask_pii = not get_project_permissions().has_permission("view_pii", context.get("project_code")) result: SourceDataResult = fetch_test_result_source_data(context, limit, mask_pii) diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index 6d28e3a7..67197f20 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -1,9 +1,24 @@ +from datetime import UTC, datetime +from enum import StrEnum +from typing import NoReturn + +from sqlalchemy import update + +from testgen.common.custom_test_validation import validate_custom_query from testgen.common.enums import ImpactDimension, QualityDimension -from testgen.common.models import with_database_session -from testgen.common.models.test_definition import TestDefinition, TestDefinitionNote, TestDefinitionSummary, TestType +from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.connection import Connection +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_definition import ( + InvalidTestDefinitionFields, + TestDefinition, + TestDefinitionNote, + TestDefinitionSummary, + TestType, +) from testgen.common.models.test_result import TestResult from testgen.mcp.exceptions import MCPUserError -from testgen.mcp.permissions import get_project_permissions, mcp_permission +from testgen.mcp.permissions import get_authorized_mcp_user, get_project_permissions, mcp_permission from testgen.mcp.tools.common import ( 
class BulkAction(StrEnum):
    """Closed set of bulk actions; each member's value is its lowercase name."""

    # NOTE(review): presumably the action argument of the bulk-update write
    # tool mentioned in this module's "Write tools" banner — confirm at the caller.
    ENABLE = "enable"
    DISABLE = "disable"
= TestDefinitionNote.get_notes(td.id) flag_str = "Flagged" if td.flagged else "Not Flagged" note_str = f"{len(notes)} Notes" if notes else "No Notes" doc.field("Review", f"{flag_str}, {note_str}") @@ -185,33 +238,6 @@ def get_test(test_definition_id: str) -> str: # Reference match (only fields listed in param_columns) _append_match_section(doc, td) - # Last result - results = TestResult.select_history( - test_definition_id=def_uuid, - project_codes=perms.allowed_codes, - limit=1, - ) - doc.heading(2, "Last Result") - if results: - r = results[0] - doc.field("Date", r.test_time) - doc.field("Status", r.status.value if r.status else None) - if r.message: - doc.field("Message", r.message) - else: - doc.text("_No results recorded for this test definition._") - - # Description - description = td.test_description or td.default_test_description - if description: - doc.heading(2, "Description") - doc.text(description) - if td.usage_notes: - doc.heading(2, "Usage Notes") - doc.text(td.usage_notes) - - return doc.render() - @with_database_session @mcp_permission("view") @@ -242,15 +268,111 @@ def list_test_notes(test_definition_id: str) -> str: doc.text(f"{len(notes)} note(s).") doc.table( - headers=["Date", "Author", "Note", "Updated"], + headers=["Test note ID", "Date", "Author", "Note", "Updated"], rows=[ - [n["created_at"], n["created_by"], n["detail"], n["updated_at"]] + [n["id"], n["created_at"], n["created_by"], n["detail"], n["updated_at"]] for n in notes ], + code=[0], + ) + return doc.render() + + +def _validate_note_body(body: str) -> None: + if not isinstance(body, str) or not body.strip(): + raise MCPUserError("`body` cannot be empty or whitespace-only.") + + +def _note_parent_label(summary: TestDefinitionSummary) -> str: + where = f"`{summary.column_name}` in `{summary.table_name}`" if summary.column_name else f"`{summary.table_name}`" + return f"{summary.display_name} on {where}" + + +@with_database_session +@mcp_permission("edit") +def 
@with_database_session
@mcp_permission("edit")
def update_test_note(test_note_id: str, body: str) -> str:
    """Replace the body of a test note. Only the note's author can update it.

    Args:
        test_note_id: UUID of the test note, e.g. from ``list_test_notes`` or ``create_test_note``.
        body: New note body (free-text). Empty or whitespace-only is rejected.
    """
    _validate_note_body(body)
    note = resolve_test_note(test_note_id)
    current_user = get_authorized_mcp_user().username
    if note.created_by != current_user:
        raise MCPUserError("You can only edit notes you authored.")

    # Snapshot the old text before the write so the report can show both sides.
    previous_text = note.detail
    TestDefinitionNote.update_note(note.id, body)

    # Re-read the parent test definition to build the confirmation label.
    allowed_codes = get_project_permissions().allowed_codes
    parent = TestDefinition.get_for_project(note.test_definition_id, allowed_codes)

    report = MdDoc()
    report.text(f"**Note updated** on {_note_parent_label(parent)}.")
    report.table(
        headers=["Field", "Before", "After"],
        rows=[["Note", previous_text, body]],
    )
    return report.render()
+ + Args: + test_note_id: UUID of the test note, e.g. from ``list_test_notes``. + """ + note = resolve_test_note(test_note_id) + username = get_authorized_mcp_user().username + if note.created_by != username: + raise MCPUserError("You can only delete notes you authored.") + + author = note.created_by + created_at = note.created_at + td_id = note.test_definition_id + TestDefinitionNote.delete_note(note.id) + + perms = get_project_permissions() + summary = TestDefinition.get_for_project(td_id, perms.allowed_codes) + + doc = MdDoc() + doc.text(f"**Note deleted** from {_note_parent_label(summary)}.") + doc.field("Author", author) + doc.field("Date", created_at) + return doc.render() + + def _append_parameters_section(doc: MdDoc, td: TestDefinitionSummary) -> None: """Build the editable parameters table from test type metadata. @@ -349,3 +471,274 @@ def list_test_types( ) return doc.render() + + +# --------------------------------------------------------------------------- +# Write tools (create / update / validate / bulk-update) +# +# All gated on ``edit`` permission. Atomic semantics on ``update_test`` — +# validation aggregates every field error before raising, so the LLM sees the +# full set in one response and the DB is never touched on a partial-error path. +# --------------------------------------------------------------------------- + + +def _raise_validation_errors(err: InvalidTestDefinitionFields, header: str) -> NoReturn: + """Convert aggregated validation errors into a user-facing ``MCPUserError``.""" + bullets = "\n".join(f"- `{field}`: {reason}" for field, reason in err.errors.items()) + raise MCPUserError(f"{header}\n\n{bullets}") from err + + +@with_database_session +@mcp_permission("edit") +def create_test( + test_suite_id: str, + test_type: str, + table_name: str, + fields: dict | None = None, +) -> str: + """Create a test in a test suite. + + Args: + test_suite_id: UUID of the test suite. + test_type: Test type name, e.g. 
``Alpha Truncation`` or ``Custom Test``. + table_name: Target table name. Case-sensitive. + fields: Mapping of field name to value for the test's parameters and metadata + (e.g. ``threshold_value``, ``custom_query``, ``severity``, ``column_name``, + ``test_description``). Use ``list_test_types`` or ``get_test`` on a similar + test to discover what's settable for the chosen test type. + """ + suite = resolve_test_suite(test_suite_id) + tt_code = resolve_test_type(test_type) + tt = TestType.get(tt_code) + if tt is None: # resolve_test_type already raised if the short name is unknown + raise MCPUserError(f"Unknown test type: `{test_type}`.") + + table_group = TableGroup.get(suite.table_groups_id) + if table_group is None: + raise MCPUserError("Test suite is not associated with a table group.") + + td = TestDefinition( + test_suite_id=suite.id, + table_groups_id=table_group.id, + test_type=tt_code, + schema_name=table_group.table_group_schema, + table_name=table_name, + test_active=True, + lock_refresh=False, + last_manual_update=datetime.now(UTC), + ) + + fields = fields or {} + accepted = td.editable_fields(tt) + rejected = sorted(set(fields) - accepted) + if rejected: + bullets = "\n".join(f"- `{key}`: not editable for test type `{tt_code}`" for key in rejected) + raise MCPUserError(f"Test definition creation rejected. No changes saved.\n\n{bullets}") + for key, value in fields.items(): + setattr(td, key, value) + + try: + td.validate(tt) + except InvalidTestDefinitionFields as e: + _raise_validation_errors(e, "Test definition creation rejected. No changes saved.") + + td.save() + + # The joined test-type metadata (param_fields, default_severity, dq_dimension, ...) + # is only present on the Summary dataclass, so re-fetch for rendering. 
+ perms = get_project_permissions() + summary = TestDefinition.get_for_project(td.id, perms.allowed_codes) + + doc = MdDoc() + doc.text(f"**Created** in suite `{suite.test_suite}`.") + _append_td_summary(doc, summary) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def update_test(test_definition_id: str, fields: dict) -> str: + """Update fields on an existing test. Atomic — no partial save. + + Args: + test_definition_id: UUID of the test definition. + fields: Mapping of field name to new value. Use ``get_test`` to see the current + values and which fields are settable for the test's type. + """ + td = resolve_test_definition(test_definition_id) + tt = TestType.get(td.test_type) + if tt is None: + raise MCPUserError(f"Test type `{td.test_type}` not found for this test definition.") + + if not fields: + raise MCPUserError("No fields supplied to update.") + + accepted = td.editable_fields(tt) + rejected = sorted(set(fields) - accepted) + if rejected: + bullets = "\n".join( + f"- `{key}`: not editable for test type `{tt.test_type}`" for key in rejected + ) + raise MCPUserError(f"Update rejected. No changes saved.\n\n{bullets}") + + before: dict = {key: getattr(td, key, None) for key in fields} + for key, value in fields.items(): + setattr(td, key, value) + td.last_manual_update = datetime.now(UTC) + + try: + td.validate(tt) + except InvalidTestDefinitionFields as e: + _raise_validation_errors(e, "Update rejected. 
No changes saved.") + + td.save() + + doc = MdDoc() + doc.heading(1, f"Test definition `{td.id}` updated") + rows = [[key, _format_diff(before[key]), _format_diff(fields[key])] for key in fields] + doc.table(["Field", "Before", "After"], rows, code=[0]) + doc.text(f"{len(fields)} field(s) changed.") + return doc.render() + + +def _format_diff(value: object) -> str | None: + """Render a before/after cell, normalizing empty strings to ``None`` (NullIfEmptyString).""" + if value is None or value == "": + return None + if isinstance(value, bool): + return "Yes" if value else "No" + return str(value) + + +@with_database_session +@mcp_permission("edit") +def validate_custom_test(test_suite_id: str, custom_sql: str) -> str: + """Dry-run a custom test SQL query against the test suite's parent connection. + + The query should return rows matching the test failure criteria — returning no rows + means the test passes; returning any rows means it fails. + + Args: + test_suite_id: UUID of the test suite whose connection the SQL runs against. + custom_sql: SQL query returning failure-criteria rows. 
+ """ + suite = resolve_test_suite(test_suite_id) + connection = Connection.get_by_table_group(suite.table_groups_id) + if connection is None: + raise MCPUserError("No connection configured for this test suite's table group.") + table_group = TableGroup.get(suite.table_groups_id) + if table_group is None: + raise MCPUserError("Test suite is not associated with a table group.") + + can_view_pii = get_project_permissions().has_permission("view_pii", suite.project_code) + + doc = MdDoc() + doc.heading(1, "Custom test dry-run") + + try: + result = validate_custom_query( + connection, table_group.table_group_schema, custom_sql, preview_limit=1, + ) + except Exception as e: # broad catch: the DB error message IS the user-facing signal + doc.text(f"**SQL did not execute.** Query was not committed against `{connection.connection_name}`.") + message = str(e.args[0]) if e.args else str(e) + doc.text("**Error:**") + doc.code_block(message) + return doc.render() + + flavor = connection.sql_flavor_code or connection.sql_flavor or "target database" + doc.text( + f"**SQL ran successfully** against `{connection.connection_name}` ({flavor})." + ) + + if result.row_count == 0: + doc.text("**Would pass:** ✓ — query returned 0 rows matching the failure criteria.") + doc.text( + "_If saved as a CUSTOM test, this would currently pass: the test fails when any " + "rows match the failure criteria, and there are none._" + ) + return doc.render() + + doc.text( + f"**Would fail:** ✗ — query returned {result.row_count} row(s) matching the failure criteria." + ) + if result.preview_rows: + doc.heading(2, "Source data preview (first row)") + first = result.preview_rows[0] + columns = list(first.keys()) + if can_view_pii: + values = [first[c] for c in columns] + else: + values = ["[redacted]"] * len(columns) + doc.table(columns, [values]) + doc.text( + "_If saved as a CUSTOM test, this would currently fail because the SQL returned rows " + "matching the test failure criteria. 
Refine the query if some of those rows are false positives._" + ) + if not can_view_pii: + doc.text( + "_PII redacted: caller does not have permissions to view PII on this project._" + ) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def bulk_update_tests( + test_suite_id: str, + action: str, + table_name: str | None = None, + test_type: str | None = None, +) -> str: + """Enable or disable tests in a suite in bulk. + + Args: + test_suite_id: UUID of the test suite. + action: ``enable`` or ``disable``. + table_name: Optional table-name filter. Case-sensitive. + test_type: Optional test type name (e.g. ``Alpha Truncation``). + """ + try: + bulk_action = BulkAction(action) + except ValueError as err: + valid = ", ".join(f"`{a.value}`" for a in BulkAction) + raise MCPUserError(f"`action` must be one of: {valid}.") from err + suite = resolve_test_suite(test_suite_id) + tt_code = resolve_test_type(test_type) if test_type else None + + target = bulk_action is BulkAction.ENABLE + values: dict = {"test_active": target} + if target: + # Mirrors set_status_attribute: clearing the status when re-enabling so failed + # tests don't carry forward a stale "disabled because of X" marker. 
+ values["test_definition_status"] = None + + where_clauses = [TestDefinition.test_suite_id == suite.id] + if table_name: + where_clauses.append(TestDefinition.table_name == table_name) + if tt_code: + where_clauses.append(TestDefinition.test_type == tt_code) + + stmt = update(TestDefinition).where(*where_clauses).values(**values) + session = get_current_session() + count = session.execute(stmt).rowcount + + verb = "Enabled" if target else "Disabled" + filters = [] + if table_name: + filters.append(f"table_name=`{table_name}`") + if test_type: + filters.append(f"test_type=`{test_type}`") + filter_str = ", ".join(filters) if filters else "no filter" + + doc = MdDoc() + if count == 0: + doc.heading(1, "No tests matched") + doc.text( + f"No tests in suite `{suite.test_suite}` matched the filter ({filter_str}). Nothing changed." + ) + return doc.render() + + doc.heading(1, f"{verb} {count} test(s) in suite `{suite.test_suite}`") + doc.field("Filter", filter_str) + return doc.render() diff --git a/testgen/mcp/tools/test_results.py b/testgen/mcp/tools/test_results.py index ec708a3a..d35d3f0a 100644 --- a/testgen/mcp/tools/test_results.py +++ b/testgen/mcp/tools/test_results.py @@ -1,9 +1,12 @@ from datetime import UTC, datetime, timedelta +from uuid import UUID -from testgen.common.models import with_database_session +from testgen.common.enums import JobStatus +from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.job_execution import JobExecution from testgen.common.models.test_definition import TestType from testgen.common.models.test_result import BucketInterval, TestResult, TestResultStatus -from testgen.common.models.test_run import TestRun +from testgen.common.models.test_run import TestRun, TestRunSummary from testgen.common.models.test_suite import TestSuite from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission @@ 
-42,7 +45,7 @@ def list_test_results( the latest completed run of that suite. Args: - job_execution_id: UUID of a test run, e.g. from ``get_recent_test_runs`` or + job_execution_id: UUID of a test run, e.g. from ``list_test_runs`` or ``list_test_suites``. test_suite_id: UUID of a test suite. Resolves to the latest completed test run for the suite. Mutually exclusive with ``job_execution_id``. @@ -158,7 +161,7 @@ def get_failure_summary( Args: project_code: Scope to a project the caller can view. Ignored if ``job_execution_id`` is set. test_suite_id: UUID of a test suite to scope the aggregation to. - job_execution_id: UUID of a test run, e.g. from ``get_recent_test_runs``, + job_execution_id: UUID of a test run, e.g. from ``list_test_runs``, to scope the summary to a single run. since: Include runs since this point in time — e.g. '7 days', '2 weeks', '2026-04-01'. group_by: Group failures by 'test_type', 'table', or 'column' (default: 'test_type'). @@ -263,7 +266,7 @@ def get_failure_summary( @with_database_session @mcp_permission("view") -def get_test_result_history( +def list_test_result_history( test_definition_id: str, limit: int = 20, page: int = 1, @@ -330,7 +333,7 @@ def search_test_results( """Search test results across multiple runs with flexible filters. To drill into a single run, use ``list_test_results``. For a single test's history, use - ``get_test_result_history``. + ``list_test_result_history``. Args: project_code: Scope to a project the caller can view. @@ -520,66 +523,85 @@ def get_failure_trend( @with_database_session @mcp_permission("view") -def get_test_run_diff(job_execution_id_a: str, job_execution_id_b: str) -> str: +def compare_test_runs( + target_job_execution_id: str, + baseline_job_execution_id: str | None = None, +) -> str: """Compare two test runs and report regressions, improvements, persistent failures, and added/removed tests. 
+ When ``baseline_job_execution_id`` is omitted, the baseline defaults to the immediately + previous completed test run on the same test suite as the target run. + Args: - job_execution_id_a: UUID of the older (baseline) test run, e.g. from ``get_recent_test_runs``. - job_execution_id_b: UUID of the newer test run. + target_job_execution_id: UUID of the newer test run, e.g. from ``list_test_runs``. + baseline_job_execution_id: Optional UUID of the older test run. + When omitted, defaults to the previous completed run on the same test suite. """ - uuid_a = parse_uuid(job_execution_id_a, "job_execution_id_a") - uuid_b = parse_uuid(job_execution_id_b, "job_execution_id_b") - - run_a = TestRun.get_by_id_or_job(uuid_a) - run_b = TestRun.get_by_id_or_job(uuid_b) - - # Permission check first — unify "not found" and "inaccessible" (also covers monitor suites, - # which are hidden from this tool the same way they're hidden from the inventory tools). perms = get_project_permissions() - suite_ids = [r.test_suite_id for r in (run_a, run_b) if r is not None] - suites_by_id: dict = {} - if suite_ids: - suites_by_id = { - s.id: s for s in TestSuite.select_where(TestSuite.id.in_(suite_ids)) - } - - def _accessible(run) -> bool: + + def _resolve_accessible(je_id_str: str, je_uuid: UUID) -> TestRun: + run = TestRun.get_by_id_or_job(je_uuid) if run is None: - return False - suite = suites_by_id.get(run.test_suite_id) - if suite is None or suite.is_monitor: - return False - return perms.has_access(suite.project_code) - - if not _accessible(run_a): - raise MCPResourceNotAccessible("Test run", job_execution_id_a) - if not _accessible(run_b): - raise MCPResourceNotAccessible("Test run", job_execution_id_b) - - # Both runs confirmed accessible — safe to reveal suite IDs in the compatibility message. - if run_a.test_suite_id != run_b.test_suite_id: - raise MCPUserError( - "Both runs must belong to the same test suite to be comparable. 
" - f"Run A is in suite `{run_a.test_suite_id}`, run B is in suite `{run_b.test_suite_id}`. " - "Use `get_recent_test_runs(test_suite=...)` to pick two runs of the same suite." - ) + raise MCPResourceNotAccessible("Test run", je_id_str) + suite = TestSuite.get_regular(run.test_suite_id) + if suite is None or not perms.has_access(suite.project_code): + raise MCPResourceNotAccessible("Test run", je_id_str) + return run + + def _require_completed(run: TestRun, label: str) -> None: + je = get_current_session().get(JobExecution, run.job_execution_id) + if je.status != JobStatus.COMPLETED: + status_label = TestRunSummary.STATUS_LABEL.get(je.status, je.status) + raise MCPUserError( + f"{label} run is in `{status_label}` state — comparison requires a completed run." + ) - diff = TestResult.diff_with_details(run_a.id, run_b.id) + target_uuid = parse_uuid(target_job_execution_id, "target_job_execution_id") + target_run = _resolve_accessible(target_job_execution_id, target_uuid) + _require_completed(target_run, "Target") + + if baseline_job_execution_id is None: + baseline_run = target_run.get_previous() + if baseline_run is None: + raise MCPUserError( + f"Target run `{target_job_execution_id}` has no earlier completed " + "test run on its test suite to compare against." + ) + else: + baseline_uuid = parse_uuid(baseline_job_execution_id, "baseline_job_execution_id") + baseline_run = _resolve_accessible(baseline_job_execution_id, baseline_uuid) + if baseline_run.test_suite_id != target_run.test_suite_id: + raise MCPUserError( + "Both runs must belong to the same test suite to be comparable. " + f"Target is in suite `{target_run.test_suite_id}`, " + f"baseline is in suite `{baseline_run.test_suite_id}`. " + "Use `list_test_runs(test_suite=...)` to pick two runs of the same suite." 
+ ) + _require_completed(baseline_run, "Baseline") + + diff = TestResult.diff_with_details(baseline_run.id, target_run.id) doc = MdDoc() - doc.heading(1, "Test Run Diff") - doc.field("Test Run A", job_execution_id_a, code=True) - doc.field("Test Run B", job_execution_id_b, code=True) + doc.heading(1, "Test Run Comparison") + doc.table( + ["", "Target", "Baseline"], + [ + ["Test Run", + MdDoc.code(str(target_run.job_execution_id)), + MdDoc.code(str(baseline_run.job_execution_id))], + ["Started", target_run.test_starttime, baseline_run.test_starttime], + ], + ) doc.table( headers=["Category", "Count"], rows=[ - ["Regressions (A passed → B failed/warning)", len(diff.regressions)], - ["Improvements (A failed/warning → B passed)", len(diff.improvements)], + ["Regressions (Baseline passed → Target failed/warning)", len(diff.regressions)], + ["Improvements (Baseline failed/warning → Target passed)", len(diff.improvements)], ["Persistent failures", len(diff.persistent_failures)], - ["New tests (only in B)", len(diff.new_tests)], - ["Removed tests (only in A)", len(diff.removed_tests)], - ["Total in A", diff.total_a], - ["Total in B", diff.total_b], + ["New tests (only in Target)", len(diff.new_tests)], + ["Removed tests (only in Baseline)", len(diff.removed_tests)], + ["Total in Target", diff.total_target], + ["Total in Baseline", diff.total_baseline], ], ) @@ -588,17 +610,21 @@ def _section(title: str, rows: list) -> None: return doc.heading(2, title) doc.table( - headers=["Test Type", "Table", "Column", "A → B", "Measure A", "Measure B", "Threshold A", "Threshold B"], + headers=[ + "Test Type", "Table", "Column", "Baseline → Target", + "Measure Baseline", "Measure Target", "Threshold Baseline", "Threshold Target", + ], rows=[ [ row.test_name_short or row.test_type, row.table_name, row.column_names, - f"{row.status_a.value if row.status_a else '—'} → {row.status_b.value if row.status_b else '—'}", - row.measure_a, - row.measure_b, - row.threshold_a, - row.threshold_b, + 
f"{row.status_baseline.value if row.status_baseline else '—'} → " + f"{row.status_target.value if row.status_target else '—'}", + row.measure_baseline, + row.measure_target, + row.threshold_baseline, + row.threshold_target, ] for row in rows ], diff --git a/testgen/mcp/tools/test_runs.py b/testgen/mcp/tools/test_runs.py index 68f9ce7b..415571e6 100644 --- a/testgen/mcp/tools/test_runs.py +++ b/testgen/mcp/tools/test_runs.py @@ -1,8 +1,24 @@ +from datetime import datetime + from testgen.common.models import with_database_session -from testgen.common.models.test_run import TestRun +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.scheduler import RUN_TESTS_JOB_KEY +from testgen.common.models.test_run import TestRun, TestRunSummary from testgen.common.models.test_suite import TestSuite +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission -from testgen.mcp.tools.common import DocGroup, validate_limit +from testgen.mcp.tools.common import ( + DocGroup, + format_page_footer, + format_page_info, + format_run_duration, + next_scheduled_run, + parse_run_status_filter, + parse_uuid, + resolve_table_group, + validate_limit, + validate_page, +) from testgen.mcp.tools.markdown import MdDoc _DOC_GROUP = DocGroup.INVESTIGATE @@ -10,76 +26,294 @@ @with_database_session @mcp_permission("view") -def get_recent_test_runs(project_code: str, test_suite: str | None = None, limit: int = 1) -> str: - """Get the latest test runs for each test suite in a project, optionally filtered by test suite name. +def list_test_runs( + project_code: str | None = None, + test_suite: str | None = None, + table_group_id: str | None = None, + status: str | None = None, + limit: int = 10, + page: int = 1, +) -> str: + """List test runs across a project, including queued and in-progress runs. Ordered by submission + time descending. Excludes monitor suites. 
Args: - project_code: The project code to query. - test_suite: Optional test suite name to filter by. - limit: Maximum runs per test suite (default 1, max 100). + project_code: Project code to query, e.g. from `list_projects`. Required unless + `table_group_id` is provided (which scopes to a single project). + test_suite: Optional test suite name to filter by (case-sensitive). + table_group_id: Optional UUID of a table group, e.g. from `get_data_inventory`. Returns + runs for any suite in the group. + status: Optional run status filter. One of: Pending, Running, Completed, Canceled, Error. + limit: Page size (default 10, max 100). + page: Page number starting at 1 (default 1). """ - if not project_code: - return "Missing required parameter `project_code`." validate_limit(limit, 100) + validate_page(page) - perms = get_project_permissions() - perms.verify_access(project_code, not_found=f"No completed test runs found in project `{project_code}`.") + statuses = parse_run_status_filter(status) if status else None + if not project_code and not table_group_id: + raise MCPUserError("Provide either `project_code` or `table_group_id`.") + + perms = get_project_permissions() test_suite_id = None + table_group = None + + if table_group_id: + table_group = resolve_table_group(table_group_id) + if project_code and project_code != table_group.project_code: + raise MCPUserError( + f"`project_code` `{project_code}` does not match the table group's project." + ) + project_code = table_group.project_code + else: + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + if test_suite: suites = TestSuite.select_minimal_where( TestSuite.project_code == project_code, TestSuite.test_suite == test_suite, + TestSuite.is_monitor.isnot(True), ) if not suites: - return f"Test suite `{test_suite}` not found in project `{project_code}`." 
+ raise MCPResourceNotAccessible("Test suite", test_suite) test_suite_id = str(suites[0].id) - summaries, _ = TestRun.select_summary(project_code=project_code, test_suite_id=test_suite_id, page_size=1000) + summaries, total = TestRun.select_summary( + project_code=project_code, + table_group_id=str(table_group.id) if table_group else None, + test_suite_id=test_suite_id, + statuses=statuses, + page=page, + page_size=limit, + ) + + # Queued/claimed JEs that don't yet have a test_runs row are invisible to suite/TG-scoped + # joined-run queries. Surface them as a separate "Pending" section on page 1. + pending_jes: list[JobExecution] = [] + if page == 1 and (test_suite_id or table_group): + pending_jes = _select_pending_test_jes( + project_code=project_code, + test_suite_id=test_suite_id, + table_group_id=str(table_group.id) if table_group else None, + statuses=statuses, + ) + + scope_descriptor = _scope_descriptor(project_code, test_suite, table_group_id, status) + doc = MdDoc() + doc.heading(1, f"Test runs{scope_descriptor}") + + next_run = _next_test_run( + project_code=project_code, + test_suite_id=test_suite_id, + table_group_id=str(table_group.id) if table_group else None, + ) + if next_run: + doc.field("Next scheduled run", next_run) + + if pending_jes: + doc.heading(2, f"Pending ({len(pending_jes)})") + for je in pending_jes: + _render_pending_je(doc, je, label=test_suite or "Test run") + + page_info = format_page_info(total, page, limit) + if page_info: + doc.text(page_info) if not summaries: - scope = f" for suite `{test_suite}`" if test_suite else "" - return f"No completed test runs found in project `{project_code}`{scope}." 
- - # Take the first `limit` runs per suite (summaries are ordered by test_starttime DESC) - seen: dict[str, int] = {} - runs = [] - for s in summaries: - count = seen.get(s.test_suite, 0) - if count < limit: - runs.append(s) - seen[s.test_suite] = count + 1 + if page > 1: + doc.text(f"_No test runs on page {page} (total: {total})._") + elif not pending_jes: + doc.text("_No test runs found._") + return doc.render() + + for run in summaries: + _render_test_run_section(doc, run) + + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + + return doc.render() + + +@with_database_session +@mcp_permission("view") +def get_test_run(job_execution_id: str) -> str: + """Get a single test run with status, timing, result counts, and testing score. Returns the + run regardless of state — including queued and in-progress runs without complete results yet. + + Args: + job_execution_id: UUID of a test run, e.g. from `list_test_runs`. + """ + parse_uuid(job_execution_id, "job_execution_id") + perms = get_project_permissions() + + summaries, _ = TestRun.select_summary(job_execution_id=job_execution_id, page_size=1) + summary = summaries[0] if summaries else None + if summary is None or summary.project_code not in perms.allowed_codes: + raise MCPResourceNotAccessible("Test run", job_execution_id) doc = MdDoc() + suite_label = summary.test_suite or "—" + doc.heading(1, f"Test run: {suite_label}") + doc.field("Job ID", summary.job_execution_id, code=True) + doc.field("Test suite", suite_label) + if summary.table_groups_name: + doc.field("Table group", summary.table_groups_name) + doc.field("Project", summary.project_code) + doc.field("Status", summary.status_label) + doc.field("Submitted", summary.created_at) + doc.field("Started", summary.started_at or "—") + doc.field("Ended", summary.completed_at or "In progress") + duration = format_run_duration(summary.started_at, summary.completed_at) + if duration: + doc.field("Duration", duration) + + has_results = 
summary.test_ct or summary.passed_ct or summary.failed_ct or summary.warning_ct or summary.error_ct + if has_results: + passed = summary.passed_ct or 0 + failed = summary.failed_ct or 0 + warning = summary.warning_ct or 0 + errors = summary.error_ct or 0 + doc.field( + "Results", + f"{summary.test_ct or 0} tests — {passed} passed, {failed} failed, {warning} warnings, {errors} errors", + ) + if summary.dismissed_ct: + doc.field("Dismissed", summary.dismissed_ct) + if summary.dq_score_testing is not None: + doc.field("Testing Score", f"{summary.dq_score_testing:.1f}") + + if summary.error_message: + doc.heading(2, "Error") + doc.text(summary.error_message) + + return doc.render() + + +def _scope_descriptor( + project_code: str | None, + test_suite: str | None, + table_group_id: str | None, + status: str | None, +) -> str: + parts: list[str] = [] + if project_code: + parts.append(f"project `{project_code}`") if test_suite: - doc.heading(1, f"Recent Test Runs for `{project_code}` / `{test_suite}`") - else: - doc.heading(1, f"Recent Test Runs for `{project_code}`") - doc.text(f"Showing {len(runs)} run(s) ({limit} per suite).") + parts.append(f"suite `{test_suite}`") + if table_group_id: + parts.append(f"table group `{table_group_id}`") + if status: + parts.append(f"status `{status}`") + return f" — {', '.join(parts)}" if parts else "" - current_suite = None - for run in runs: - if run.test_suite != current_suite: - current_suite = run.test_suite - doc.heading(2, current_suite) - passed = run.passed_ct or 0 - failed = run.failed_ct or 0 - warning = run.warning_ct or 0 - errors = run.error_ct or 0 +def _next_test_run( + project_code: str | None, + test_suite_id: str | None, + table_group_id: str | None, +) -> datetime | None: + """Compute the next scheduled test run when scoped to a single suite or table group.""" + if not project_code: + return None + if test_suite_id: + return next_scheduled_run(RUN_TESTS_JOB_KEY, {"test_suite_id": test_suite_id}, project_code) + if 
table_group_id: + suite_ids = [ + str(s.id) + for s in TestSuite.select_minimal_where( + TestSuite.project_code == project_code, + TestSuite.table_groups_id == table_group_id, + TestSuite.is_monitor.isnot(True), + ) + ] + candidates = [ + next_scheduled_run(RUN_TESTS_JOB_KEY, {"test_suite_id": sid}, project_code) + for sid in suite_ids + ] + candidates = [c for c in candidates if c is not None] + return min(candidates) if candidates else None + return None - doc.heading(3, f"{run.created_at} — {run.status_label}") - doc.field("Test Run", run.job_execution_id, code=True) - doc.field("Started", run.created_at) - doc.field("Ended", run.completed_at or "In progress") - doc.field("Results", f"{run.test_ct or 0} tests — {passed} passed, {failed} failed, {warning} warnings, {errors} errors") - if run.dismissed_ct: - doc.field("Dismissed", run.dismissed_ct) +def _select_pending_test_jes( + *, + project_code: str, + test_suite_id: str | None, + table_group_id: str | None, + statuses, +) -> list[JobExecution]: + """Find queued/in-flight test-run JEs for a given suite or table group scope. For a + table-group scope, expands to the non-monitor suites in the group so monitor runs stay + excluded. 
+ """ + if test_suite_id: + suite_ids: str | list[str] = test_suite_id + elif table_group_id: + suite_ids = [ + str(s.id) + for s in TestSuite.select_minimal_where( + TestSuite.project_code == project_code, + TestSuite.table_groups_id == table_group_id, + TestSuite.is_monitor.isnot(True), + ) + ] + if not suite_ids: + return [] + else: + return [] + return JobExecution.select_active_by_kwargs( + project_code=project_code, + job_key=RUN_TESTS_JOB_KEY, + kwargs_match={"test_suite_id": suite_ids}, + statuses=statuses, + ) - if run.dq_score_testing is not None: - doc.field("Testing Score", f"{run.dq_score_testing:.1f}") - doc.text("Use `list_test_results(job_execution_id='...')` for detailed results of a specific run.") +def _render_pending_je(doc: MdDoc, je: JobExecution, label: str) -> None: + status_label = TestRunSummary.STATUS_LABEL.get(je.status, je.status) + doc.heading(3, f"{label} — {status_label}") + doc.field("Job ID", je.id, code=True) + if je.job_schedule_id is not None: + doc.field("Schedule", je.job_schedule_id, code=True) + doc.field("Submitted", je.created_at) + doc.field("Started", je.started_at or "—") + doc.field("Ended", je.completed_at or "In progress") - return doc.render() + +def _render_test_run_section(doc: MdDoc, run: TestRunSummary) -> None: + title = run.test_suite or run.project_code + doc.heading(2, f"{title} — {run.status_label}") + doc.field("Job ID", run.job_execution_id, code=True) + if run.job_schedule_id is not None: + doc.field("Schedule", run.job_schedule_id, code=True) + if run.test_suite: + doc.field("Test suite", run.test_suite) + if run.table_groups_name: + doc.field("Table group", run.table_groups_name) + doc.field("Submitted", run.created_at) + doc.field("Started", run.started_at or "—") + doc.field("Ended", run.completed_at or "In progress") + duration = format_run_duration(run.started_at, run.completed_at) + if duration: + doc.field("Duration", duration) + + passed = run.passed_ct or 0 + failed = run.failed_ct or 0 + 
warning = run.warning_ct or 0 + errors = run.error_ct or 0 + if run.test_ct or passed or failed or warning or errors: + doc.field( + "Results", + f"{run.test_ct or 0} tests — {passed} passed, {failed} failed, {warning} warnings, {errors} errors", + ) + + if run.dismissed_ct: + doc.field("Dismissed", run.dismissed_ct) + if run.dq_score_testing is not None: + doc.field("Testing Score", f"{run.dq_score_testing:.1f}") diff --git a/testgen/scheduler/cli_scheduler.py b/testgen/scheduler/cli_scheduler.py index ddc3c9ec..76375138 100644 --- a/testgen/scheduler/cli_scheduler.py +++ b/testgen/scheduler/cli_scheduler.py @@ -12,8 +12,9 @@ from testgen import settings from testgen.commands.job_registry import JOB_DISPATCH, run_final_callbacks +from testgen.common.enums import JobStatus from testgen.common.models import database_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.scheduler import JobSchedule from testgen.scheduler.base import DelayedPolicy, Job, Scheduler @@ -22,7 +23,6 @@ @dataclass class CliJob(Job): key: str - args: Iterable[Any] kwargs: dict[str, Any] project_code: str | None = field(default=None) job_schedule_id: UUID | None = field(default=None) @@ -48,7 +48,7 @@ def get_jobs(self) -> Iterable[CliJob]: self.reload_timer.start() jobs = {} - for job_model in JobSchedule.select_where(): + for job_model in JobSchedule.select_runnable(): if job_model.key not in JOB_DISPATCH: LOG.error("Job '%s' scheduled but not registered", job_model.key) continue @@ -58,7 +58,6 @@ def get_jobs(self) -> Iterable[CliJob]: cron_tz=job_model.cron_tz, delayed_policy=DelayedPolicy.SKIP, key=job_model.key, - args=job_model.args, kwargs=job_model.kwargs, project_code=job_model.project_code, job_schedule_id=job_model.id, @@ -80,7 +79,7 @@ def start_job(self, job: CliJob, triggering_time: datetime) -> None: JobExecution.submit( job_key=job.key, 
kwargs=job.kwargs, - source="scheduler", + source=JOB_DISPATCH[job.key].scheduler_source, project_code=job.project_code, job_schedule_id=job.job_schedule_id, ) @@ -249,5 +248,3 @@ def run_scheduler(): scheduler = CliScheduler() scheduler.run() - - diff --git a/testgen/server/__init__.py b/testgen/server/__init__.py index 120a7789..d0867e9d 100644 --- a/testgen/server/__init__.py +++ b/testgen/server/__init__.py @@ -17,16 +17,14 @@ if settings.IS_DEBUG: os.environ.setdefault("AUTHLIB_INSECURE_TRANSPORT", "1") -from testgen.api.app import router as api_router -from testgen.api.jobs import router as jobs_router +from testgen.api import router as api_v1_router from testgen.api.oauth.metadata import router as metadata_router from testgen.api.oauth.routes import init_routes from testgen.api.oauth.routes import router as oauth_router from testgen.api.oauth.server import create_authorization_server -from testgen.api.runs import router as runs_router -from testgen.api.test_definitions import router as test_definitions_router from testgen.common import version_service from testgen.common.models import with_database_session +from testgen.server.middleware import BodySizeLimitMiddleware, SecurityHeadersMiddleware LOG = logging.getLogger("testgen") @@ -123,14 +121,26 @@ def favicon(): app.include_router(metadata_router) app.include_router(oauth_router) - app.include_router(api_router) - app.include_router(jobs_router) - app.include_router(runs_router) - app.include_router(test_definitions_router) + app.include_router(api_v1_router) if settings.MCP_ENABLED: app.mount("", mcp_app) + # add_middleware is LIFO — body cap is added first so it runs innermost, + # rejecting oversized requests before security headers wrap the 413 response + app.add_middleware(BodySizeLimitMiddleware, max_bytes=settings.API_MAX_REQUEST_BODY_BYTES) + + hsts = settings.API_HSTS_HEADER or ( + "max-age=63072000; includeSubDomains" if settings.API_TLS_ENABLED else None + ) + app.add_middleware( + 
SecurityHeadersMiddleware, + hsts=hsts, + csp=settings.API_CSP_HEADER, + referrer=settings.API_REFERRER_POLICY, + nosniff=True, + ) + if settings.IS_DEBUG: from starlette.middleware.cors import CORSMiddleware @@ -171,4 +181,11 @@ def run_server() -> None: "enabled" if settings.API_TLS_ENABLED else "disabled", "enabled" if settings.MCP_ENABLED else "disabled", ) - uvicorn.run(app, host=settings.API_HOST, port=settings.API_PORT, log_level="info", **ssl_kwargs) + uvicorn.run( + app, + host=settings.API_HOST, + port=settings.API_PORT, + log_level="info", + timeout_graceful_shutdown=settings.API_GRACEFUL_SHUTDOWN_TIMEOUT, + **ssl_kwargs, + ) diff --git a/testgen/server/middleware.py b/testgen/server/middleware.py new file mode 100644 index 00000000..56661a6b --- /dev/null +++ b/testgen/server/middleware.py @@ -0,0 +1,114 @@ +"""ASGI middlewares for the combined FastAPI + MCP server. + +These are pure-ASGI implementations (not BaseHTTPMiddleware) to avoid buffering +responses, which would break MCP's text/event-stream transport. +""" + +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +_413_BODY = b'{"detail":"Request body too large"}' + + +async def _send_413(send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 413, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(_413_BODY)).encode()), + ], + } + ) + await send({"type": "http.response.body", "body": _413_BODY}) + + +class BodySizeLimitMiddleware: + """Reject requests whose body exceeds *max_bytes* with HTTP 413. + + Checks Content-Length up front when present; otherwise tracks accumulated + body bytes and disconnects when the limit is exceeded mid-stream. Only + inspects http.request messages, so MCP SSE response streams pass through + untouched. 
+ """ + + def __init__(self, app: ASGIApp, max_bytes: int) -> None: + self.app = app + self.max_bytes = max_bytes + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http" or scope.get("method") in ("GET", "HEAD", "OPTIONS"): + await self.app(scope, receive, send) + return + + content_length = next( + (v for k, v in scope.get("headers", []) if k == b"content-length"), None + ) + if content_length is not None: + try: + if int(content_length) > self.max_bytes: + await _send_413(send) + return + except ValueError: + pass + + received = 0 + exceeded = False + + async def limited_receive() -> Message: + nonlocal received, exceeded + if exceeded: + return {"type": "http.disconnect"} + message = await receive() + if message["type"] == "http.request": + received += len(message.get("body", b"")) + if received > self.max_bytes: + exceeded = True + return {"type": "http.disconnect"} + return message + + await self.app(scope, limited_receive, send) + + +class SecurityHeadersMiddleware: + """Inject standard security headers on every HTTP response. + + Headers are added to http.response.start, so they apply uniformly to success + and error responses. Existing headers (case-insensitive match) are preserved, + letting per-route handlers override defaults. 
+ """ + def __init__( + self, + app: ASGIApp, + *, + hsts: str | None, + csp: str, + referrer: str, + nosniff: bool, + ) -> None: + self.app = app + self.headers: list[tuple[bytes, bytes]] = [] + if hsts: + self.headers.append((b"strict-transport-security", hsts.encode())) + if nosniff: + self.headers.append((b"x-content-type-options", b"nosniff")) + if referrer: + self.headers.append((b"referrer-policy", referrer.encode())) + if csp: + self.headers.append((b"content-security-policy", csp.encode())) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + async def send_wrapper(message: Message) -> None: + if message["type"] == "http.response.start": + current = list(message.get("headers", [])) # "headers" is optional in ASGI and may be immutable; copy so the app's own list is never mutated + existing = {k.lower() for k, _ in current} # case-insensitive names already set by the app win over defaults + extra = [h for h in self.headers if h[0] not in existing] + message["headers"] = current + extra + await send(message) + + await self.app(scope, receive, send_wrapper) diff --git a/testgen/settings.py b/testgen/settings.py index 94c4e002..69e922dc 100644 --- a/testgen/settings.py +++ b/testgen/settings.py @@ -512,6 +512,14 @@ def _ssl_files_present() -> bool: Disables sending usage data when set to any value except "true" and "yes". Defaults to "yes" """ +DISABLE_FEEDBACK_POPUP: bool = getenv("TG_DISABLE_FEEDBACK_POPUP", "no").lower() in ("yes", "true") +""" +When set to "yes" or "true", suppresses the periodic feedback popup entirely. + +from env variable: `TG_DISABLE_FEEDBACK_POPUP` +defaults to: `no` +""" + JOB_POLL_INTERVAL: int = int(getenv("TG_JOB_POLL_INTERVAL", "5")) """ Seconds between polls for pending job executions. 
@@ -618,3 +626,65 @@ def _default_ui_base_url() -> str: from env variable: `TG_UI_BASE_URL` defaults to: computed from UI_TLS_ENABLED and UI_PORT """ + +MCP_EXTRA_ALLOWED_HOSTS: list[str] = [ + h.strip() for h in (getenv("TG_MCP_EXTRA_ALLOWED_HOSTS", "") or "").split(",") if h.strip() +] +""" +Extra Host header values accepted by MCP DNS rebinding protection (comma-separated). +BASE_URL's hostname and loopback are always allowed; this adds more for multi-domain +deployments or reverse proxies that rewrite Host. Entries without a port (`tg.example.com`) +get an automatic `:*` wildcard; entries with a port are matched literally +(`tg.example.com:8080`) or with explicit wildcard (`tg.example.com:*`). +Only affects MCP routes — the parent FastAPI app does not validate Host headers. + +from env variable: `TG_MCP_EXTRA_ALLOWED_HOSTS` +defaults to: empty (BASE_URL hostname + loopback only) +""" + +API_MAX_REQUEST_BODY_BYTES: int = int( + getenv("TG_API_MAX_REQUEST_BODY_BYTES", str(10 * 1024 * 1024)) +) +""" +Reject HTTP requests larger than this with 413 Payload Too Large. + +from env variable: `TG_API_MAX_REQUEST_BODY_BYTES` +defaults to: 10485760 (10 MiB) +""" + +API_GRACEFUL_SHUTDOWN_TIMEOUT: int = int(getenv("TG_API_GRACEFUL_SHUTDOWN_TIMEOUT", "30")) +""" +Seconds uvicorn waits for in-flight requests on SIGTERM before force-closing. +Long blocking SQL queries that don't honor asyncio cancellation may be cut mid-flight; +align with the target DB's statement_timeout. + +from env variable: `TG_API_GRACEFUL_SHUTDOWN_TIMEOUT` +defaults to: 30 +""" + +API_HSTS_HEADER: str = getenv("TG_API_HSTS_HEADER", "") +""" +Override HSTS (Strict-Transport-Security) header value. When empty, HSTS is emitted +only when API_TLS_ENABLED with value 'max-age=63072000; includeSubDomains'. Setting +this forces emission regardless of TLS (useful when TLS terminates at a reverse proxy). 
+ +from env variable: `TG_API_HSTS_HEADER` +defaults to: empty (auto from API_TLS_ENABLED) +""" + +API_CSP_HEADER: str = getenv("TG_API_CSP_HEADER", "frame-ancestors 'none'") +""" +Content-Security-Policy header value. Default restricts framing only; broader policies +risk breaking Redoc at /api/docs which loads CDN assets. + +from env variable: `TG_API_CSP_HEADER` +defaults to: `frame-ancestors 'none'` +""" + +API_REFERRER_POLICY: str = getenv("TG_API_REFERRER_POLICY", "no-referrer") +""" +Referrer-Policy header value. + +from env variable: `TG_API_REFERRER_POLICY` +defaults to: `no-referrer` +""" diff --git a/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql b/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql index 013343f0..57b83256 100644 --- a/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql +++ b/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql @@ -226,10 +226,12 @@ BEGIN RAISE EXCEPTION 'Invalid expression: dangerous statement detected'; END IF; - -- Remove all allowed tokens from the validation expression, treating 'FLOAT' as a keyword + -- Remove all allowed tokens from the validation expression, treating 'FLOAT' as a keyword. + -- Numeric pattern accepts leading-dot decimals (e.g. ".733") that Oracle emits + -- when converting NUMBER values with |x| < 1 to VARCHAR2. 
invalid_parts := regexp_replace( expression, - E'(\\mGREATEST|LEAST|ABS|FN_NORMAL_CDF|DATEDIFF|DAY|FLOAT|NULLIF)\\M|[0-9]+(\\.[0-9]+)?([eE][+-]?[0-9]+)?|[+\\-*/(),\\\'":]+|\\s+', + E'(\\mGREATEST|LEAST|ABS|FN_NORMAL_CDF|DATEDIFF|DAY|FLOAT|NULLIF)\\M|([0-9]+\\.?[0-9]*|\\.[0-9]+)([eE][+-]?[0-9]+)?|[+\\-*/(),\\\'":]+|\\s+', '', 'gi' ); diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index 1e7217df..66e4db8b 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -50,14 +50,16 @@ CREATE TABLE stg_test_definition_updates ( ); CREATE TABLE projects ( - id UUID DEFAULT gen_random_uuid(), - project_code VARCHAR(30) NOT NULL + id UUID DEFAULT gen_random_uuid(), + project_code VARCHAR(30) NOT NULL CONSTRAINT projects_project_code_pk PRIMARY KEY, - project_name VARCHAR(50), + project_name VARCHAR(50), observability_api_key TEXT, observability_api_url TEXT DEFAULT '', - use_dq_score_weights BOOLEAN DEFAULT TRUE + use_dq_score_weights BOOLEAN DEFAULT TRUE, + data_retention_enabled BOOLEAN NOT NULL DEFAULT TRUE, + data_retention_days INTEGER DEFAULT 180 ); CREATE TABLE connections ( @@ -72,7 +74,7 @@ CREATE TABLE connections ( sql_flavor_code VARCHAR(30), project_host VARCHAR(250), project_port VARCHAR(5), - project_user VARCHAR(50), + project_user VARCHAR(256), project_db VARCHAR(100), connection_name VARCHAR(40), project_pw_encrypted BYTEA, @@ -711,7 +713,8 @@ CREATE TABLE auth_users ( name VARCHAR(256), password VARCHAR(120), is_global_admin BOOLEAN NOT NULL DEFAULT FALSE, - latest_login TIMESTAMP + latest_login TIMESTAMP, + preferences JSONB NOT NULL DEFAULT '{}' ); ALTER TABLE auth_users @@ -966,6 +969,9 @@ CREATE INDEX ix_dsl_tg_tcd CREATE INDEX ix_prun_pc_con ON profiling_runs(project_code, connection_id); +CREATE INDEX ix_prun_pc_starttime + ON profiling_runs(project_code, 
profiling_starttime); + CREATE INDEX ix_prun_tg ON profiling_runs(table_groups_id); @@ -1066,12 +1072,11 @@ CREATE TABLE job_schedules ( id UUID NOT NULL PRIMARY KEY, project_code VARCHAR(30) NOT NULL, key VARCHAR(100) NOT NULL, - args JSONB NOT NULL, kwargs JSONB NOT NULL, cron_expr VARCHAR(50) NOT NULL, cron_tz VARCHAR(30) NOT NULL, active BOOLEAN DEFAULT TRUE, - UNIQUE (project_code, key, args, kwargs, cron_expr, cron_tz) + UNIQUE (project_code, key, kwargs, cron_expr, cron_tz) ); CREATE INDEX job_schedules_idx ON job_schedules (project_code, key); @@ -1079,7 +1084,6 @@ CREATE INDEX job_schedules_idx ON job_schedules (project_code, key); CREATE TABLE job_executions ( id UUID NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY, job_key VARCHAR(100) NOT NULL, - args JSONB NOT NULL DEFAULT '[]'::jsonb, kwargs JSONB NOT NULL DEFAULT '{}'::jsonb, source VARCHAR(20) NOT NULL, status VARCHAR(20) NOT NULL DEFAULT 'pending', @@ -1095,6 +1099,8 @@ CREATE TABLE job_executions ( CREATE INDEX idx_job_executions_poll ON job_executions (status, created_at) WHERE status = 'pending'; CREATE INDEX idx_job_executions_schedule ON job_executions (job_schedule_id); CREATE INDEX idx_job_executions_project ON job_executions (project_code, created_at DESC); +CREATE INDEX idx_job_executions_project_completed + ON job_executions (project_code, completed_at); CREATE TABLE settings ( key VARCHAR(50) NOT NULL PRIMARY KEY, diff --git a/testgen/template/dbsetup/040_populate_new_schema_project.sql b/testgen/template/dbsetup/040_populate_new_schema_project.sql index 36f6a30c..204b84c2 100644 --- a/testgen/template/dbsetup/040_populate_new_schema_project.sql +++ b/testgen/template/dbsetup/040_populate_new_schema_project.sql @@ -7,16 +7,29 @@ SELECT '{PROJECT_CODE}' as project_code, '{OBSERVABILITY_API_KEY}' as observability_api_key, '{OBSERVABILITY_API_URL}' as observability_api_url; +-- Seed the data retention schedule so the default project's cleanup job +-- runs out of the box (matches the column 
defaults: enabled, 180 days). +INSERT INTO job_schedules + (id, project_code, key, kwargs, cron_expr, cron_tz, active) +SELECT gen_random_uuid(), + '{PROJECT_CODE}', + 'run-data-cleanup', + jsonb_build_object('project_code', '{PROJECT_CODE}', 'retention_days', 180), + '0 1 * * *', + 'UTC', + TRUE; + WITH inserted_user AS ( INSERT INTO auth_users - (username, email, name, password, is_global_admin) + (username, email, name, password, is_global_admin, preferences) SELECT '{UI_USER_USERNAME}' as username, '{UI_USER_EMAIL}' as email, '{UI_USER_NAME}' as name, '{UI_USER_ENCRYPTED_PASSWORD}' as password, - true as is_global_admin + true as is_global_admin, + jsonb_build_object('last_feedback_popup', '{LAST_FEEDBACK_POPUP_SEED}') as preferences RETURNING id ) INSERT INTO project_memberships diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml index c35be242..c6574673 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml @@ -101,3 +101,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10047' + test_id: 1015 + test_type: Boolean_Value_Mismatch + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml index a4a44110..8afd6b10 100644 --- 
a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml @@ -112,3 +112,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT 'Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT 'Non-Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) B ORDER BY data_type, count DESC error_type: Profile Anomaly + - id: '10048' + test_id: 1012 + test_type: Char_Column_Date_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT 'Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS A UNION ALL SELECT B.* FROM ( SELECT DISTINCT 'Non-Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS B ORDER BY data_type, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml index e23891b6..59504d2f 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml @@ -112,3 +112,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", 
COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) B ORDER BY data_type, count DESC error_type: Profile Anomaly + - id: '10049' + test_id: 1011 + test_type: Char_Column_Number_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS A UNION ALL SELECT B.* FROM ( SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS B ORDER BY data_type, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml index 87441e8a..f6947302 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml @@ -15,9 +15,9 @@ profile_anomaly_types: AND STRPOS(p.top_patterns, 'N') > 0 AND ( ( (STRPOS(p.top_patterns, 'A') > 0 OR STRPOS(p.top_patterns, 'a') > 0) - AND SPLIT_PART(p.top_patterns, '|', 3)::NUMERIC / SPLIT_PART(p.top_patterns, '|', 1)::NUMERIC < 0.05) + AND NULLIF(SPLIT_PART(p.top_patterns, '|', 3), '')::NUMERIC / NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::NUMERIC < 0.05) OR - SPLIT_PART(p.top_patterns, '|', 3)::NUMERIC / 
SPLIT_PART(p.top_patterns, '|', 1)::NUMERIC < 0.1 + NULLIF(SPLIT_PART(p.top_patterns, '|', 3), '')::NUMERIC / NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::NUMERIC < 0.1 ) detail_expression: |- 'Patterns: ' || p.top_patterns @@ -25,7 +25,7 @@ profile_anomaly_types: suggested_action: |- Review the values for any data that doesn't conform to the most common pattern and correct any data errors. dq_score_prevalence_formula: |- - (p.record_ct - SPLIT_PART(p.top_patterns, '|', 1)::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT + (p.record_ct - NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT dq_score_risk_factor: '0.66' dq_dimension: Validity impact_dimension: Usability @@ -99,7 +99,7 @@ profile_anomaly_types: sql_flavor: postgresql lookup_type: null lookup_query: |- - SELECT A.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 4)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 6)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 8)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 
'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 10)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC; + SELECT A.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 4)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 6)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 8)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS 
count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 10)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC; error_type: Profile Anomaly - id: '1039' test_id: '1007' @@ -141,3 +141,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 4)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 6)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 8)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count 
FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 10)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC error_type: Profile Anomaly + - id: '10050' + test_id: 1007 + test_type: Column_Pattern_Mismatch + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 4)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 6)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 8)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT 
trim(split_part('{DETAIL_EXPRESSION}', '|', 10)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml index 570ed5ad..ef4853e6 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml @@ -96,3 +96,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" LIKE_REGEXPR '^([^,|' || NCHAR(9) || ']{1,20}[,|' || NCHAR(9) || ']){2,}[^,|' || NCHAR(9) || ']{0,20}([,|' || NCHAR(9) || ']{0,1}[^,|' || NCHAR(9) || ']{0,20})*$' AND NOT "{COLUMN_NAME}" LIKE_REGEXPR '[[:space:]](and|but|or|yet)[[:space:]]' GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10051' + test_id: 1025 + test_type: Delimited_Data_Embedded + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE REGEXP_LIKE(CAST("{COLUMN_NAME}" AS VARCHAR), '^([^,|\t]{1,20}[,|\t]){2,}[^,|\t]{0,20}([,|\t]{0,1}[^,|\t]{0,20})*$') AND NOT REGEXP_LIKE(CAST("{COLUMN_NAME}" AS VARCHAR), '\s(and|but|or|yet)\s') GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml index 
176a3565..6c8d7156 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml @@ -141,3 +141,17 @@ profile_anomaly_types: lookup_query: |- SELECT * FROM (SELECT 'Upper Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE UPPER("{COLUMN_NAME}") = "{COLUMN_NAME}" GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) UNION ALL SELECT * FROM (SELECT 'Mixed Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" <> UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" <> LOWER("{COLUMN_NAME}") GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) error_type: Profile Anomaly + - id: '10052' + test_id: 1028 + test_type: Inconsistent_Casing + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + (SELECT 'Upper Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" + WHERE UPPER("{COLUMN_NAME}") = "{COLUMN_NAME}" + GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) + UNION ALL + (SELECT 'Mixed Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" + WHERE "{COLUMN_NAME}" <> UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" <> LOWER("{COLUMN_NAME}") + GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml index ed042ca9..723b3862 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml @@ -9,14 +9,14 @@ profile_anomaly_types: p.distinct_pattern_ct > 1 AND (p.column_name ilike '%zip%' OR p.column_name ILIKE '%postal%') AND SPLIT_PART(p.top_patterns, ' | ', 2) = 'NNN' - AND SPLIT_PART(p.top_patterns, ' | ', 
1)::FLOAT/NULLIF(value_ct, 0)::FLOAT > 0.50 + AND NULLIF(SPLIT_PART(p.top_patterns, ' | ', 1), '')::FLOAT/NULLIF(value_ct, 0)::FLOAT > 0.50 detail_expression: |- 'Pattern: ' || p.top_patterns issue_likelihood: Definite suggested_action: |- Review your source data, ingestion process, and any processing steps that update this column. dq_score_prevalence_formula: |- - (NULLIF(p.record_ct, 0)::INT - SPLIT_PART(p.top_patterns, ' | ', 1)::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT + (NULLIF(p.record_ct, 0)::INT - NULLIF(SPLIT_PART(p.top_patterns, ' | ', 1), '')::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT dq_score_risk_factor: '1' dq_dimension: Validity impact_dimension: Conformance @@ -98,3 +98,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY count DESC, "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10053' + test_id: 1024 + test_type: Invalid_Zip3_USA + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE(CAST("{COLUMN_NAME}" AS VARCHAR), '[0-8]', '9', 'g') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY count DESC, "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml index 2e13f4a8..e9b094ec 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml @@ -94,3 +94,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') NOT IN ('99999', 
'999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10054' + test_id: 1003 + test_type: Invalid_Zip_USA + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE(CAST("{COLUMN_NAME}" AS VARCHAR), '[0-8]', '9', 'g') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml index d63a84e1..05ee4c0c 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml @@ -94,3 +94,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" BETWEEN ' !' AND '!' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10055' + test_id: 1009 + test_type: Leading_Spaces + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" BETWEEN ' !' AND '!' 
THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml index cb4fa797..3633cc99 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml @@ -108,3 +108,10 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT COLUMN_NAME, TABLE_NAME, CASE WHEN DATA_TYPE_NAME LIKE 'TIMESTAMP%%' THEN LOWER(DATA_TYPE_NAME) WHEN DATA_TYPE_NAME = 'DATE' THEN 'date' WHEN DATA_TYPE_NAME IN ('NVARCHAR', 'VARCHAR') THEN LOWER(DATA_TYPE_NAME) || '(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'CHAR' THEN 'char(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' AND SCALE = 0 THEN 'decimal(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' THEN 'decimal(' || LENGTH || ',' || SCALE || ')' WHEN DATA_TYPE_NAME IN ('INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT') THEN LOWER(DATA_TYPE_NAME) ELSE LOWER(DATA_TYPE_NAME) END AS data_type FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = '{TARGET_SCHEMA}' AND COLUMN_NAME = '{COLUMN_NAME}' ORDER BY data_type, TABLE_NAME LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10056' + test_id: 1005 + test_type: Multiple_Types_Major + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: null + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml index 80e9fa80..eb1555ad 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml @@ -108,3 +108,10 @@ profile_anomaly_types: lookup_query: 
|- SELECT DISTINCT COLUMN_NAME, TABLE_NAME, CASE WHEN DATA_TYPE_NAME LIKE 'TIMESTAMP%%' THEN LOWER(DATA_TYPE_NAME) WHEN DATA_TYPE_NAME = 'DATE' THEN 'date' WHEN DATA_TYPE_NAME IN ('NVARCHAR', 'VARCHAR') THEN LOWER(DATA_TYPE_NAME) || '(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'CHAR' THEN 'char(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' AND SCALE = 0 THEN 'decimal(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' THEN 'decimal(' || LENGTH || ',' || SCALE || ')' WHEN DATA_TYPE_NAME IN ('INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT') THEN LOWER(DATA_TYPE_NAME) ELSE LOWER(DATA_TYPE_NAME) END AS data_type FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = '{TARGET_SCHEMA}' AND COLUMN_NAME = '{COLUMN_NAME}' ORDER BY data_type, TABLE_NAME LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10057' + test_id: 1004 + test_type: Multiple_Types_Minor + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: null + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml index 46c6f955..7f3835b4 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml @@ -96,3 +96,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10058' + test_id: 1006 + test_type: No_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml 
b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml index 820e6423..29d1aff7 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml @@ -108,3 +108,13 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" = UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" = LOWER("{COLUMN_NAME}") AND "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10059' + test_id: 1029 + test_type: Non_Alpha_Name_Address + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TABLE_NAME}" + WHERE "{COLUMN_NAME}" = UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" = LOWER("{COLUMN_NAME}") AND "{COLUMN_NAME}" > '' + GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml index 22ed1cd9..78c96148 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml @@ -110,3 +110,13 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" < 'A' AND SUBSTR("{COLUMN_NAME}", 1, 1) NOT IN ('"', ' ') AND SUBSTR("{COLUMN_NAME}", -1, 1) <> '''' GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10060' + test_id: 1030 + test_type: Non_Alpha_Prefixed_Name + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", 
COUNT(*) as record_ct FROM "{TABLE_NAME}" + WHERE "{COLUMN_NAME}" < 'A' AND SUBSTR("{COLUMN_NAME}", 1, 1) NOT IN ('"', ' ') AND SUBSTRING("{COLUMN_NAME}", LENGTH("{COLUMN_NAME}")) <> '''' + GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml index 34821875..e8291e15 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml @@ -161,3 +161,34 @@ profile_anomaly_types: lookup_query: |- SELECT REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", NCHAR(160), '\x160'), NCHAR(8201), '\x8201'), NCHAR(8203), '\x8203'), NCHAR(8204), '\x8204'), NCHAR(8205), '\x8205'), NCHAR(8206), '\x8206'), NCHAR(8207), '\x8207'), NCHAR(8239), '\x8239'), NCHAR(12288), '\x12288'), NCHAR(65279), '\x65279') as "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", NCHAR(160), ''), NCHAR(8201), ''), NCHAR(8203), ''), NCHAR(8204), ''), NCHAR(8205), ''), NCHAR(8206), ''), NCHAR(8207), ''), NCHAR(8239), ''), NCHAR(12288), ''), NCHAR(65279), '') <> "{COLUMN_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10061' + test_id: 1031 + test_type: Non_Printing_Chars + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", + CHR(160), '\x160'), + CHR(8201), '\x8201'), + CHR(8203), '\x8203'), + CHR(8204), '\x8204'), + CHR(8205), '\x8205'), + CHR(8206), '\x8206'), + CHR(8207), '\x8207'), + CHR(8239), '\x8239'), + 
CHR(12288), '\x12288'), + CHR(65279), '\x65279') as "{COLUMN_NAME}", + COUNT(*) as record_ct FROM "{TABLE_NAME}" + WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", + CHR(160), ''), + CHR(8201), ''), + CHR(8203), ''), + CHR(8204), ''), + CHR(8205), ''), + CHR(8206), ''), + CHR(8207), ''), + CHR(8239), ''), + CHR(12288), ''), + CHR(65279), '') <> "{COLUMN_NAME}" + GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml index 4e1c104b..e6bfb600 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml @@ -107,3 +107,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN LOWER("{COLUMN_NAME}") LIKE_REGEXPR '(-{2,}|0{2,}|9{2,}|x{2,}|z{2,})' THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10062' + test_id: 1002 + test_type: Non_Standard_Blanks + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN 
REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '-{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '0{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '9{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'x{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'z{2,}') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" = '' THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml index 46383270..613f5571 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml @@ -96,3 +96,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*) > 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10063' + test_id: 1016 + test_type: Potential_Duplicates + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*)> 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml 
b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml index c33bfae9..492062d9 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml @@ -94,3 +94,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10064' + test_id: 1100 + test_type: Potential_PII + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml index b7ac31bc..a315b5ed 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml @@ -95,3 +95,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" LIKE '"%%"' OR "{COLUMN_NAME}" LIKE '''%%''' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10065' + test_id: 1010 + test_type: Quoted_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" LIKE '"%"' OR "{COLUMN_NAME}" LIKE '''%''' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git 
a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml index d24286ca..40e4c02f 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml @@ -90,3 +90,10 @@ profile_anomaly_types: lookup_query: |- created_in_ui error_type: Profile Anomaly + - id: '10066' + test_id: 1019 + test_type: Recency_One_Year + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: created_in_ui + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml index a94f7474..0eafe386 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml @@ -90,3 +90,10 @@ profile_anomaly_types: lookup_query: |- created_in_ui error_type: Profile Anomaly + - id: '10067' + test_id: 1020 + test_type: Recency_Six_Months + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: created_in_ui + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml index 25c6065a..b8e5ec4d 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml @@ -87,3 +87,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - 
id: '10068' + test_id: 1014 + test_type: Small Divergent Value Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml index b8093ab0..3e6fb266 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml @@ -90,3 +90,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN LOWER("{COLUMN_NAME}") LIKE_REGEXPR '(-{2,}|0{2,}|9{2,}|x{2,}|z{2,})' THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10069' + test_id: 1013 + test_type: Small Missing Value Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '-{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '0{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '9{2,}') OR 
REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'x{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'z{2,}') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" = '' THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml index 0b868784..5249c9d6 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml @@ -109,3 +109,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) B ORDER BY data_type, count DESC error_type: Profile Anomaly + - id: '10070' + test_id: 1023 + test_type: Small_Numeric_Value_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY 
"{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS A UNION ALL SELECT B.* FROM ( SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS B ORDER BY data_type, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml index 870862a4..4e4f43ad 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml @@ -104,3 +104,11 @@ profile_anomaly_types: lookup_query: |- WITH CTE AS ( SELECT DISTINCT UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) as possible_standard_value, COUNT(DISTINCT "{COLUMN_NAME}") AS cnt FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) HAVING COUNT(DISTINCT "{COLUMN_NAME}") > 1 ) SELECT a."{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" a, cte b WHERE UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(a."{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) = b.possible_standard_value GROUP BY a."{COLUMN_NAME}" ORDER BY UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(a."{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) ASC, count DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10071' + test_id: 1017 + test_type: Standardized_Value_Matches + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH CTE AS ( SELECT DISTINCT UPPER(REGEXP_REPLACE("{COLUMN_NAME}", '[ '',.\-]', '', 'g')) as possible_standard_value, COUNT(DISTINCT 
"{COLUMN_NAME}") FROM "{TABLE_NAME}" GROUP BY UPPER(REGEXP_REPLACE("{COLUMN_NAME}", '[ '',.\-]', '', 'g')) HAVING COUNT(DISTINCT "{COLUMN_NAME}") > 1 ) SELECT a."{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" a, cte b WHERE UPPER(REGEXP_REPLACE(a."{COLUMN_NAME}", '[ '',.\-]', '', 'g')) = b.possible_standard_value GROUP BY a."{COLUMN_NAME}" ORDER BY UPPER(REGEXP_REPLACE(a."{COLUMN_NAME}", '[ '',.\-]', '', 'g')) ASC, count DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml index b623888b..812c6b95 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml @@ -95,3 +95,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10072' + test_id: 1001 + test_type: Suggested_Type + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml index d72d9875..319c7e3d 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml @@ -13,7 +13,7 @@ profile_anomaly_types: AND m.max_pattern_ct = 1 AND m.column_ct > 1 AND SPLIT_PART(p.top_patterns, '|', 2) <> SPLIT_PART(m.very_top_pattern, '|', 2) - AND 
SPLIT_PART(p.top_patterns, '|', 1)::NUMERIC / SPLIT_PART(m.very_top_pattern, '|', 1)::NUMERIC < 0.1 + AND NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::NUMERIC / NULLIF(SPLIT_PART(m.very_top_pattern, '|', 1), '')::NUMERIC < 0.1 detail_expression: |- 'Patterns: ' || SPLIT_PART(p.top_patterns, '|', 2) || ', ' || SPLIT_PART(ltrim(m.very_top_pattern, '0'), '|', 2) issue_likelihood: Likely @@ -105,3 +105,10 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT COLUMN_NAME, TABLE_NAME FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = '{TARGET_SCHEMA}' AND COLUMN_NAME = '{COLUMN_NAME}' ORDER BY TABLE_NAME LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10073' + test_id: 1008 + test_type: Table_Pattern_Mismatch + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: null + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml index 9c9dd4f8..f939cf78 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml @@ -95,3 +95,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10074' + test_id: 1022 + test_type: Unexpected Emails + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml index b86117ab..669c5259 100644 --- 
a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml @@ -97,3 +97,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10075' + test_id: 1021 + test_type: Unexpected US States + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml index c5f9c540..f75eb93f 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml @@ -99,3 +99,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", TO_DATE('{PROFILE_RUN_DATE}', 'YYYY-MM-DD') AS profile_run_date, COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" a WHERE ("{COLUMN_NAME}" < TO_DATE('1900-01-01', 'YYYY-MM-DD')) OR ("{COLUMN_NAME}" > ADD_MONTHS(TO_DATE('{PROFILE_RUN_DATE}', 'YYYY-MM-DD'), 360)) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10076' + test_id: 1018 + test_type: Unlikely_Date_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", CAST('{PROFILE_RUN_DATE}' AS DATE) AS profile_run_date, COUNT(*) AS count FROM "{TABLE_NAME}" a WHERE ("{COLUMN_NAME}" < CAST('1900-01-01' AS DATE)) OR ("{COLUMN_NAME}" > CAST('{PROFILE_RUN_DATE}' AS DATE) + INTERVAL '30 
year' ) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml index 72265501..e252f4c7 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml @@ -98,3 +98,11 @@ profile_anomaly_types: lookup_query: |- WITH val_list(token, remaining) AS ( SELECT CASE WHEN LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') > 0 THEN TRIM(SUBSTR(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), 1, LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') - 1)) ELSE TRIM(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2)) END AS token, CASE WHEN LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') > 0 THEN SUBSTR(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') + 1) ELSE '' END AS remaining FROM DUMMY UNION ALL SELECT CASE WHEN LOCATE(remaining, '|') > 0 THEN TRIM(SUBSTR(remaining, 1, LOCATE(remaining, '|') - 1)) ELSE TRIM(remaining) END AS token, CASE WHEN LOCATE(remaining, '|') > 0 THEN SUBSTR(remaining, LOCATE(remaining, '|') + 1) ELSE '' END AS remaining FROM val_list WHERE LENGTH(remaining) > 0 ) SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE LOWER("{COLUMN_NAME}") IN (SELECT token FROM val_list) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10077' + test_id: 1027 + test_type: Variant_Coded_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM 
"{TABLE_NAME}" WHERE LOWER("{COLUMN_NAME}") IN (SELECT TRIM(val) FROM UNNEST(STRING_TO_ARRAY(SUBSTRING('{DETAIL_EXPRESSION}', STRPOS('{DETAIL_EXPRESSION}', ':') + 2), '|')) AS t(val)) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml index 89882477..f38b89f4 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -83,7 +83,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -108,7 +108,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -132,7 +132,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -157,7 +157,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS 
total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -182,7 +182,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -207,7 +207,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -232,7 +232,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -257,7 +257,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -266,6 +266,31 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10001' + test_id: 1500 + test_type: Aggregate_Balance + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) AS total, SUM(MATCH_TOTAL) AS MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} AS total, NULL AS match_total + FROM "{TABLE_NAME}" + WHERE 
{SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total <> match_total OR (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2506' test_type: Aggregate_Balance @@ -698,3 +723,52 @@ test_types: WHERE total <> match_total OR (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) + - id: '10001' + test_type: Aggregate_Balance + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' 
+ END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total <> match_total + OR (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL); diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml index b15b0114..1415731d 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -85,7 +85,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -112,7 +112,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE 
{MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -138,7 +138,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -165,7 +165,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -192,7 +192,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -219,7 +219,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -246,7 +246,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -273,7 +273,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} 
GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -284,6 +284,33 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10002' + test_id: 1504 + test_type: Aggregate_Balance_Percent + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) AS total, SUM(MATCH_TOTAL) AS MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} AS total, NULL AS match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total * (1 + {LOWER_TOLERANCE}/100.0) AND match_total * (1 + {UPPER_TOLERANCE}/100.0)) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2509' test_type: Aggregate_Balance_Percent @@ -716,3 +743,52 @@ test_types: WHERE (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) OR (total NOT BETWEEN match_total * (1 + {LOWER_TOLERANCE}/100.0) AND match_total * (1 + {UPPER_TOLERANCE}/100.0)) + - id: '10002' + test_type: Aggregate_Balance_Percent + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > 
{SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total * (1 + {LOWER_TOLERANCE}/100.0) AND match_total * (1 + {UPPER_TOLERANCE}/100.0)); diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml index 1fe4cdc4..84f20602 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -85,7 +85,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM 
{MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -112,7 +112,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -138,7 +138,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -165,7 +165,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -192,7 +192,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -219,7 +219,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -246,7 +246,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM 
{MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -273,7 +273,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -284,6 +284,33 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10003' + test_id: 1505 + test_type: Aggregate_Balance_Range + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) AS total, SUM(MATCH_TOTAL) AS MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} AS total, NULL AS match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total + {LOWER_TOLERANCE} AND match_total + {UPPER_TOLERANCE}) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2510' test_type: Aggregate_Balance_Range @@ -716,3 +743,52 @@ test_types: WHERE (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) OR (total NOT BETWEEN match_total + {LOWER_TOLERANCE} AND match_total + {UPPER_TOLERANCE}) + - id: '10003' + test_type: Aggregate_Balance_Range + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as 
test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total + {LOWER_TOLERANCE} AND match_total + {UPPER_TOLERANCE}); diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml index 8607dec0..425a72e2 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - 
FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -83,7 +83,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -108,7 +108,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -132,7 +132,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -157,7 +157,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -182,7 +182,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -207,7 +207,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM 
{MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -232,7 +232,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -257,7 +257,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -266,6 +266,31 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10004' + test_id: 1501 + test_type: Aggregate_Minimum + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total < match_total OR (total IS NULL AND match_total IS NOT NULL) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2502' test_type: Aggregate_Minimum @@ -698,3 +723,52 @@ test_types: WHERE total < match_total -- OR (total IS NOT NULL AND match_total IS NULL) -- New categories OR (total IS NULL AND match_total IS NOT NULL) + - id: '10004' + 
test_type: Aggregate_Minimum + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total < match_total + -- OR (total IS NOT NULL AND match_total IS NULL) -- New categories + OR (total IS NULL AND match_total IS NOT NULL); -- Dropped categories diff --git a/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml b/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml index 41ab1ab7..88577f60 100644 --- a/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml @@ -117,6 +117,14 @@ test_types: test_operator: < 
test_condition: |- {THRESHOLD_VALUE} + - id: '10001' + test_type: Alpha_Trunc + sql_flavor: salesforce_data360 + measure: |- + MAX(LENGTH({COLUMN_NAME})) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1364' test_id: '1004' @@ -197,4 +205,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", LENGTH("{COLUMN_NAME}") as current_max_length, {THRESHOLD_VALUE} as previous_max_length FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT MAX(LENGTH("{COLUMN_NAME}")) as max_length FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") a WHERE LENGTH("{COLUMN_NAME}") = a.max_length AND a.max_length < {THRESHOLD_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10005' + test_id: 1004 + test_type: Alpha_Trunc + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", LENGTH("{COLUMN_NAME}") as current_max_length, {THRESHOLD_VALUE} as previous_max_length FROM "{TABLE_NAME}", (SELECT MAX(LENGTH("{COLUMN_NAME}")) as max_length FROM "{TABLE_NAME}") a WHERE LENGTH("{COLUMN_NAME}") = a.max_length AND a.max_length < {THRESHOLD_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml index 49a3c5b9..08f801a7 100644 --- a/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10002' + test_type: Avg_Shift + sql_flavor: salesforce_data360 + measure: |- + ABS( (AVG(CAST({COLUMN_NAME} AS FLOAT)) - {BASELINE_AVG}) / SQRT(((CAST(COUNT({COLUMN_NAME}) AS FLOAT)-1)*POWER(STDDEV(CAST({COLUMN_NAME} AS FLOAT)),2) + (CAST({BASELINE_VALUE_CT} AS FLOAT)-1) * POWER(CAST({BASELINE_SD} AS FLOAT),2)) /NULLIF(CAST(COUNT({COLUMN_NAME}) AS FLOAT) + CAST({BASELINE_VALUE_CT} AS FLOAT), 0) )) + 
test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1365' test_id: '1005' @@ -192,4 +200,12 @@ test_types: lookup_query: |- SELECT AVG(CAST("{COLUMN_NAME}" AS DECIMAL)) AS current_average FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10006' + test_id: 1005 + test_type: Avg_Shift + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT AVG(CAST("{COLUMN_NAME}" AS FLOAT)) AS current_average FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml b/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml index 3257b114..6005806c 100644 --- a/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml +++ b/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml @@ -382,3 +382,42 @@ test_types: FROM ( {CUSTOM_QUERY} ) TEST + - id: '10005' + test_type: CUSTOM + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + CASE + WHEN '{COLUMN_NAME_NO_QUOTES}' = '' OR '{COLUMN_NAME_NO_QUOTES}' IS NULL THEN NULL + ELSE '{COLUMN_NAME_NO_QUOTES}' + END as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + /* TODO: 'custom_query= {CUSTOM_QUERY_ESCAPED}' as input_parameters, */ + 'Skip_Errors={SKIP_ERRORS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' 
+ END AS result_message, + COUNT(*) as result_measure + FROM ( + {CUSTOM_QUERY} + ) TEST; diff --git a/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml index cdc5bfde..f4016136 100644 --- a/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml @@ -54,7 +54,7 @@ test_types: {HAVING_CONDITION} EXCEPT DISTINCT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -76,7 +76,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -98,7 +98,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -119,7 +119,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -141,7 +141,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -163,7 +163,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY 
{MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -185,7 +185,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -207,7 +207,7 @@ test_types: {HAVING_CONDITION} MINUS SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -229,7 +229,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -237,6 +237,28 @@ test_types: ORDER BY {COLUMN_NAME_NO_QUOTES} LIMIT {LIMIT} error_type: Test Results + - id: '10007' + test_id: 1502 + test_type: Combo_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} + {HAVING_CONDITION} + EXCEPT + SELECT {MATCH_GROUPBY_NAMES} + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} + ) test + ORDER BY {COLUMN_NAME_NO_QUOTES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2501' test_type: Combo_Match @@ -626,3 +648,47 @@ test_types: GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) test + - id: '10006' + test_type: Combo_Match + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + 
'{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {COLUMN_NAME_NO_QUOTES} + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} + {HAVING_CONDITION} + EXCEPT + SELECT {MATCH_GROUPBY_NAMES} + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml b/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml index 733ef0b5..9c63f169 100644 --- a/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10003' + test_type: Condition_Flag + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {CUSTOM_QUERY} THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1366' test_id: '1006' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT * FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE {CUSTOM_QUERY} LIMIT {LIMIT} error_type: Test Results + - id: '10008' + test_id: 1006 + test_type: Condition_Flag + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * FROM "{TABLE_NAME}" WHERE {CUSTOM_QUERY} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff 
--git a/testgen/template/dbsetup_test_types/test_types_Constant.yaml b/testgen/template/dbsetup_test_types/test_types_Constant.yaml index 2bb8e6df..fff389cd 100644 --- a/testgen/template/dbsetup_test_types/test_types_Constant.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Constant.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10004' + test_type: Constant + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {COLUMN_NAME} <> {BASELINE_VALUE} THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1367' test_id: '1007' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" <> {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10009' + test_id: 1007 + test_type: Constant + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" <> {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml index fb9fe8bb..df372cd6 100644 --- a/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml @@ -121,6 +121,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10005' + test_type: Daily_Record_Ct + sql_flavor: salesforce_data360 + measure: |- + DATEDIFF('day', CAST(MIN({COLUMN_NAME}) AS DATE), CAST(MAX({COLUMN_NAME}) AS DATE))+1-COUNT(DISTINCT CAST({COLUMN_NAME} AS DATE)) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1368' test_id: '1009' @@ -263,4 +271,12 
@@ test_types: lookup_query: |- WITH Pass0 AS (SELECT 1 C FROM DUMMY UNION ALL SELECT 1 FROM DUMMY), Pass1 AS (SELECT 1 C FROM Pass0 A, Pass0 B), Pass2 AS (SELECT 1 C FROM Pass1 A, Pass1 B), Pass3 AS (SELECT 1 C FROM Pass2 A, Pass2 B), Pass4 AS (SELECT 1 C FROM Pass3 A, Pass3 B), nums AS (SELECT ROW_NUMBER() OVER (ORDER BY C) - 1 AS rn FROM Pass4), bounds AS (SELECT MIN(CAST("{COLUMN_NAME}" AS DATE)) AS min_date, MAX(CAST("{COLUMN_NAME}" AS DATE)) AS max_date FROM "{TARGET_SCHEMA}"."{TABLE_NAME}"), daterange AS (SELECT ADD_DAYS(b.min_date, n.rn) AS all_dates FROM bounds b, nums n WHERE ADD_DAYS(b.min_date, n.rn) <= b.max_date), existing_periods AS (SELECT DISTINCT CAST("{COLUMN_NAME}" AS DATE) AS period, COUNT(1) AS period_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY CAST("{COLUMN_NAME}" AS DATE)) SELECT p.missing_period, p.prior_available_date, e.period_count AS prior_available_date_count, p.next_available_date, f.period_count AS next_available_date_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_date, MIN(c.period) AS next_available_date FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_date = e.period) LEFT JOIN existing_periods f ON (p.next_available_date = f.period) ORDER BY p.missing_period LIMIT {LIMIT} error_type: Test Results + - id: '10010' + test_id: 1009 + test_type: Daily_Record_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH RECURSIVE daterange(all_dates) AS (SELECT CAST(MIN("{COLUMN_NAME}") AS DATE) AS all_dates FROM "{TABLE_NAME}" UNION ALL SELECT CAST((d.all_dates + INTERVAL '1 day') AS DATE) AS all_dates FROM daterange d WHERE d.all_dates < (SELECT CAST(MAX("{COLUMN_NAME}") AS DATE) FROM 
"{TABLE_NAME}") ), existing_periods AS ( SELECT DISTINCT CAST("{COLUMN_NAME}" AS DATE) AS period, COUNT(1) AS period_count FROM "{TABLE_NAME}" GROUP BY CAST("{COLUMN_NAME}" AS DATE) ) SELECT p.missing_period, p.prior_available_date, e.period_count AS prior_available_date_count, p.next_available_date, f.period_count AS next_available_date_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_date, MIN(c.period) AS next_available_date FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_date = e.period) LEFT JOIN existing_periods f ON (p.next_available_date = f.period) ORDER BY p.missing_period LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml b/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml index e717d8fb..770d1175 100644 --- a/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml @@ -118,6 +118,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10006' + test_type: Dec_Trunc + sql_flavor: salesforce_data360 + measure: |- + SUM(ROUND(ABS(MOD({COLUMN_NAME}, 1)), 5))+1 + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1369' test_id: '1011' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT DISTINCT CASE WHEN LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') > 0 THEN LENGTH(SUBSTR(TO_VARCHAR("{COLUMN_NAME}"), LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') + 1)) ELSE 0 END AS decimal_scale, COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY CASE WHEN LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') > 0 THEN 
LENGTH(SUBSTR(TO_VARCHAR("{COLUMN_NAME}"), LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') + 1)) ELSE 0 END LIMIT {LIMIT} error_type: Test Results + - id: '10011' + test_id: 1011 + test_type: Dec_Trunc + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT LENGTH(SPLIT_PART(CAST("{COLUMN_NAME}" AS TEXT), '.', 2)) AS decimal_scale, COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY decimal_scale LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml index 4ddc1dd4..23967398 100644 --- a/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10007' + test_type: Distinct_Date_Ct + sql_flavor: salesforce_data360 + measure: |- + COUNT(DISTINCT {COLUMN_NAME}) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1370' test_id: '1012' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Test Results + - id: '10012' + test_id: 1012 + test_type: Distinct_Date_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml index e7737220..47e609ec 100644 --- 
a/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml @@ -117,6 +117,14 @@ test_types: test_operator: <> test_condition: |- {THRESHOLD_VALUE} + - id: '10008' + test_type: Distinct_Value_Ct + sql_flavor: salesforce_data360 + measure: |- + COUNT(DISTINCT {COLUMN_NAME}) + test_operator: '<>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1371' test_id: '1013' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Test Results + - id: '10013' + test_id: 1013 + test_type: Distinct_Value_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml index 627cd8a3..666bf095 100644 --- a/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml @@ -49,13 +49,25 @@ test_types: lookup_query: |- WITH latest_ver AS ( SELECT {CONCAT_COLUMNS} AS category, - CAST(COUNT(*) AS FLOAT64) / SUM(COUNT(*)) OVER() AS pct_of_total - FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}` v1 + CAST(COUNT(*) AS FLOAT64) / CAST(SUM(COUNT(*)) OVER () AS FLOAT64) AS pct_of_total + FROM `{TARGET_SCHEMA}.{TABLE_NAME}` v1 WHERE {SUBSET_CONDITION} - GROUP BY {CONCAT_COLUMNS} + GROUP BY {COLUMN_NAME_NO_QUOTES} + ), + older_ver AS ( + SELECT {CONCAT_MATCH_GROUPBY} AS category, + CAST(COUNT(*) AS FLOAT64) / CAST(SUM(COUNT(*)) OVER () 
AS FLOAT64) AS pct_of_total + FROM `{MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME}` v2 + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} ) - SELECT * - FROM latest_ver + SELECT COALESCE(l.category, o.category) AS category, + o.pct_of_total AS old_pct, + l.pct_of_total AS new_pct + FROM latest_ver l + FULL JOIN older_ver o + ON l.category = o.category + ORDER BY COALESCE(l.category, o.category) LIMIT {LIMIT}; error_type: Test Results - id: '1336' @@ -74,7 +86,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -102,7 +114,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) as FLOAT) / CAST(SUM(COUNT(*)) OVER () as FLOAT) AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT TOP {LIMIT} COALESCE(l.category, o.category) AS category, @@ -129,7 +141,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -157,7 +169,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -185,7 +197,7 @@ test_types: older_ver AS ( SELECT
{CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -213,7 +225,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -241,7 +253,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS NUMBER) / CAST(SUM(COUNT(*)) OVER () AS NUMBER) AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -269,7 +281,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS DECIMAL) / CAST(SUM(COUNT(*)) OVER () AS DECIMAL) AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -281,6 +293,33 @@ test_types: ORDER BY COALESCE(l.category, o.category) LIMIT {LIMIT} error_type: Test Results + - id: '10014' + test_id: 1503 + test_type: Distribution_Shift + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH latest_ver + AS ( SELECT {CONCAT_COLUMNS} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM "{TABLE_NAME}" v1 + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} ), + older_ver + AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, + 
CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM "{MATCH_TABLE_NAME}" v2 + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} ) + SELECT COALESCE(l.category, o.category) AS category, + o.pct_of_total AS old_pct, + l.pct_of_total AS new_pct + FROM latest_ver l + FULL JOIN older_ver o + ON (l.category = o.category) + ORDER BY COALESCE(l.category, o.category) + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2503' test_type: Distribution_Shift @@ -302,7 +341,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} AS category, CAST(COUNT(*) AS FLOAT64) / CAST(SUM(COUNT(*)) OVER () AS FLOAT64) AS pct_of_total - FROM `{MATCH_SCHEMA_NAME}.{TABLE_NAME}` v2 + FROM `{MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME}` v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), @@ -355,7 +394,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -408,7 +447,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) as FLOAT) / CAST(SUM(COUNT(*)) OVER () as FLOAT) AS pct_of_total - FROM "{MATCH_SCHEMA_NAME}"."{TABLE_NAME}" v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -461,7 +500,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -514,7 +553,7 @@ test_types: older_ver AS ( SELECT 
{CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -567,7 +606,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -620,7 +659,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -673,7 +712,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS NUMBER) / CAST(SUM(COUNT(*)) OVER () AS NUMBER) AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -724,7 +763,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS DECIMAL) / CAST(SUM(COUNT(*)) OVER () AS DECIMAL) AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -756,3 +795,56 @@ test_types: SELECT 0.5 * ABS(SUM(new_pct * LN(new_pct/avg_pct)/LN(2))) + 0.5 * ABS(SUM(old_pct * LN(old_pct/avg_pct)/LN(2))) 
as js_divergence FROM dataset ) rslt + - id: '10007' + test_type: Distribution_Shift + sql_flavor: salesforce_data360 + template: |- + -- Relative Entropy: measured by Jensen-Shannon Divergence + -- Smoothed and normalized version of KL divergence, + -- with scores between 0 (identical) and 1 (maximally different), + -- when using the base-2 logarithm. Formula is: + -- 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m) + -- Log base 2 of x = LN(x)/LN(2) + WITH latest_ver + AS ( SELECT {CONCAT_COLUMNS} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} v1 + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} ), + older_ver + AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} ), + dataset + AS ( SELECT COALESCE(l.category, o.category) AS category, + COALESCE(o.pct_of_total, 0.0000001) AS old_pct, + COALESCE(l.pct_of_total, 0.0000001) AS new_pct, + (COALESCE(o.pct_of_total, 0.0000001) + + COALESCE(l.pct_of_total, 0.0000001))/2.0 AS avg_pct + FROM latest_ver l + FULL JOIN older_ver o + ON (l.category = o.category) ) + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + -- '{GROUPBY_NAMES}' as column_names, + '{THRESHOLD_VALUE}' as threshold_value, + NULL as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN js_divergence > {THRESHOLD_VALUE} THEN 0 ELSE 1 END as result_code, + CONCAT('Divergence Level: ', + CONCAT(CAST(js_divergence AS {VARCHAR_TYPE}), + ', Threshold: {THRESHOLD_VALUE}.')) as result_message, + 
js_divergence as result_measure + FROM ( + SELECT 0.5 * ABS(SUM(new_pct * LN(new_pct/avg_pct)/LN(2))) + + 0.5 * ABS(SUM(old_pct * LN(old_pct/avg_pct)/LN(2))) as js_divergence + FROM dataset ) rslt; diff --git a/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml b/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml index 57c778cc..1ef27125 100644 --- a/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml @@ -165,6 +165,20 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10015' + test_id: 1510 + test_type: Dupe_Rows + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {GROUPBY_NAMES}, COUNT(*) as record_ct + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + HAVING COUNT(*) > 1 + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2511' test_type: Dupe_Rows @@ -499,3 +513,41 @@ test_types: GROUP BY {GROUPBY_NAMES} HAVING COUNT(*) > 1 ) test + - id: '10008' + test_type: Dupe_Rows + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' duplicate row(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' 
+ END AS result_message, + COALESCE(SUM(record_ct), 0) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, COUNT(*) as record_ct + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + HAVING COUNT(*) > 1 + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml b/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml index ab0a8704..6d6573b4 100644 --- a/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10009' + test_type: Email_Format + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NOT REGEXP_LIKE(CAST({COLUMN_NAME} AS VARCHAR), '^[A-Za-z0-9._''%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1372' test_id: '1014' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE NOT "{COLUMN_NAME}" LIKE_REGEXPR '^[A-Za-z0-9._''%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$' GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10016' + test_id: 1014 + test_type: Email_Format + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NOT REGEXP_LIKE(CAST("{COLUMN_NAME}" AS VARCHAR), '^[A-Za-z0-9._''%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml b/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml index ba60e7c5..c8d1e3a2 100644 --- a/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml +++ 
b/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml @@ -492,3 +492,53 @@ test_types: ELSE COALESCE(TO_VARCHAR(interval_minutes), 'Unknown') END AS result_signal FROM test_data; + - id: '10009' + test_type: Freshness_Trend + sql_flavor: salesforce_data360 + template: |- + WITH test_data AS ( + SELECT + MD5({CUSTOM_QUERY}) AS fingerprint, + DATEDIFF('minute', CAST(NULLIF('{BASELINE_SUM}', '') AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) AS interval_minutes + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + ) + SELECT '{TEST_TYPE}' AS test_type, + '{TEST_DEFINITION_ID}' AS test_definition_id, + '{TEST_SUITE_ID}' AS test_suite_id, + '{TEST_RUN_ID}' AS test_run_id, + '{RUN_DATE}' AS test_time, + '{SCHEMA_NAME}' AS schema_name, + '{TABLE_NAME}' AS table_name, + '{COLUMN_NAME_NO_QUOTES}' AS column_names, + '{SKIP_ERRORS}' AS threshold_value, + {SKIP_ERRORS} AS skip_errors, + '{INPUT_PARAMETERS}' AS input_parameters, + fingerprint AS result_measure, + CASE + -- Training mode: tolerances not yet calculated + WHEN {LOWER_TOLERANCE} IS NULL AND {UPPER_TOLERANCE} IS NULL THEN -1 + -- No change and excluded day: suppress + WHEN fingerprint = '{BASELINE_VALUE}' AND {IS_EXCLUDED_DAY} = 1 THEN 1 + -- No change, beyond time range (business time): LATE + WHEN fingerprint = '{BASELINE_VALUE}' + AND (interval_minutes - {EXCLUDED_MINUTES}) > {THRESHOLD_VALUE} THEN 0 + -- Table changed outside time range (business time): UNEXPECTED + WHEN fingerprint <> '{BASELINE_VALUE}' + AND NOT (interval_minutes - {EXCLUDED_MINUTES}) + BETWEEN {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} THEN 0 + ELSE 1 + END AS result_code, + 'Table update detected: ' || CASE WHEN fingerprint <> '{BASELINE_VALUE}' THEN 'Yes' ELSE 'No' END + || CASE + WHEN fingerprint <> '{BASELINE_VALUE}' AND (interval_minutes - {EXCLUDED_MINUTES}) BETWEEN {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} THEN '. On time.' 
+ WHEN fingerprint <> '{BASELINE_VALUE}' AND (interval_minutes - {EXCLUDED_MINUTES}) < {LOWER_TOLERANCE} THEN '. Earlier than expected.' + WHEN fingerprint <> '{BASELINE_VALUE}' AND (interval_minutes - {EXCLUDED_MINUTES}) > {UPPER_TOLERANCE} THEN '. Later than expected.' + WHEN fingerprint = '{BASELINE_VALUE}' AND {IS_EXCLUDED_DAY} = 0 AND (interval_minutes - {EXCLUDED_MINUTES}) > {THRESHOLD_VALUE} THEN '. Late.' + ELSE '' + END AS result_message, + CASE + WHEN fingerprint <> '{BASELINE_VALUE}' THEN '0' + ELSE COALESCE(CAST(interval_minutes AS VARCHAR), 'Unknown') + END AS result_signal + FROM test_data; diff --git a/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml b/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml index 938091da..c2327843 100644 --- a/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10010' + test_type: Future_Date + sql_flavor: salesforce_data360 + measure: |- + SUM(GREATEST(0, SIGN(CAST({COLUMN_NAME} AS DATE) - CAST('{RUN_DATE}' AS DATE)))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1373' test_id: '1015' @@ -193,4 +201,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DATE) > TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10017' + test_id: 1015 + test_type: Future_Date + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE GREATEST(0, SIGN(CAST("{COLUMN_NAME}" AS DATE) - CAST('{TEST_DATE}' AS DATE))) > {THRESHOLD_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git 
a/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml b/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml index 01a42a83..c1ec7d6d 100644 --- a/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10011' + test_type: Future_Date_1Y + sql_flavor: salesforce_data360 + measure: |- + SUM(GREATEST(0, SIGN(CAST({COLUMN_NAME} AS DATE) - (CAST('{RUN_DATE}' AS DATE)+365)))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1374' test_id: '1016' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DATE) > ADD_DAYS(TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS'), 365) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10018' + test_id: 1016 + test_type: Future_Date_1Y + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE GREATEST(0, SIGN(CAST("{COLUMN_NAME}" AS DATE) - (CAST('{TEST_DATE}' AS DATE) + 365))) > {THRESHOLD_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml index eddb6227..00b6b9f6 100644 --- a/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10012' + test_type: Incr_Avg_Shift + sql_flavor: salesforce_data360 + measure: |- + COALESCE(ABS( ({BASELINE_AVG} - (SUM({COLUMN_NAME}) - {BASELINE_SUM}) / 
NULLIF(CAST(COUNT({COLUMN_NAME}) AS FLOAT) - {BASELINE_VALUE_CT}, 0)) / {BASELINE_SD} ), 0) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1375' test_id: '1017' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT AVG(CAST("{COLUMN_NAME}" AS DECIMAL)) AS current_average, SUM(CAST("{COLUMN_NAME}" AS DECIMAL)) AS current_sum, NULLIF(COUNT("{COLUMN_NAME}"), 0) as current_value_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10019' + test_id: 1017 + test_type: Incr_Avg_Shift + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT AVG(CAST("{COLUMN_NAME}" AS FLOAT)) AS current_average, SUM(CAST("{COLUMN_NAME}" AS FLOAT)) AS current_sum, NULLIF(CAST(COUNT("{COLUMN_NAME}" ) AS FLOAT), 0) as current_value_count FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml b/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml index 2cf10836..cdd4bfda 100644 --- a/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml +++ b/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml @@ -115,6 +115,14 @@ test_types: test_operator: <> test_condition: |- {THRESHOLD_VALUE} + - id: '10013' + test_type: LOV_All + sql_flavor: salesforce_data360 + measure: |- + (SELECT ARRAY_JOIN(ARRAY_AGG(sub_val), '|') FROM (SELECT DISTINCT {COLUMN_NAME} AS sub_val FROM "{TABLE_NAME}" WHERE {SUBSET_CONDITION} ORDER BY 1 LIMIT 1000) sub_lov) + test_operator: '<>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1376' test_id: '1018' @@ -203,4 +211,12 @@ test_types: lookup_query: |- SELECT STRING_AGG("{COLUMN_NAME}", '|' ORDER BY "{COLUMN_NAME}") AS lov FROM (SELECT DISTINCT "{COLUMN_NAME}" FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") HAVING STRING_AGG("{COLUMN_NAME}", '|' ORDER BY "{COLUMN_NAME}") <> {THRESHOLD_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10020' + test_id: 1018 + 
test_type: LOV_All + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT ARRAY_JOIN(ARRAY_AGG(sub_val), '|') AS lov FROM (SELECT DISTINCT "{COLUMN_NAME}" AS sub_val FROM "{TABLE_NAME}" ORDER BY 1 LIMIT 1000) sub_lov HAVING ARRAY_JOIN(ARRAY_AGG(sub_val), '|') <> {THRESHOLD_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml b/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml index 768dd65b..38b8040c 100644 --- a/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml @@ -221,6 +221,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10014' + test_type: LOV_Match + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NULLIF({COLUMN_NAME}, '') NOT IN {BASELINE_VALUE} THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1377' test_id: '1019' @@ -298,4 +306,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL AND "{COLUMN_NAME}" NOT IN {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10021' + test_id: 1019 + test_type: LOV_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT NULLIF("{COLUMN_NAME}", '') AS "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NULLIF("{COLUMN_NAME}", '') NOT IN {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml b/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml index 31e17846..7675e9d1 100644 --- a/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml +++ 
b/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml @@ -106,6 +106,14 @@ test_types: test_operator: NOT BETWEEN test_condition: |- {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} + - id: '10015' + test_type: Metric_Trend + sql_flavor: salesforce_data360 + measure: |- + {CUSTOM_QUERY} + test_operator: NOT BETWEEN + test_condition: |- + {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} target_data_lookups: - id: '1484' test_id: '1514' @@ -206,4 +214,15 @@ test_types: {UPPER_TOLERANCE} AS upper_bound FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10022' + test_id: 1514 + test_type: Metric_Trend + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {CUSTOM_QUERY} AS current_count, + {LOWER_TOLERANCE} AS lower_bound, + {UPPER_TOLERANCE} AS upper_bound + FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml b/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml index a2762969..1ade805e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10016' + test_type: Min_Date + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN CAST({COLUMN_NAME} AS DATE) < CAST('{BASELINE_VALUE}' AS DATE) THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1378' test_id: '1020' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" < CAST('{BASELINE_VALUE}' AS {COLUMN_TYPE}) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10023' + test_id: 1020 + test_type: Min_Date + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", 
COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DATE) < CAST('{BASELINE_VALUE}' AS DATE) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml b/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml index 3a852155..90f107f4 100644 --- a/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10017' + test_type: Min_Val + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {COLUMN_NAME} < {BASELINE_VALUE} - 1e-6 THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1379' test_id: '1021' @@ -193,4 +201,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", (ABS("{COLUMN_NAME}") - ABS({BASELINE_VALUE})) AS difference_from_baseline FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" < {BASELINE_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10024' + test_id: 1021 + test_type: Min_Val + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", (ABS("{COLUMN_NAME}") - ABS({BASELINE_VALUE})) AS difference_from_baseline FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" < {BASELINE_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml b/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml index d85d0908..f4dd0a0a 100644 --- a/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10018' + test_type: Missing_Pct + sql_flavor: salesforce_data360 + measure: |- + ABS( 2.0 * ASIN( 
SQRT( CAST({BASELINE_VALUE_CT} AS FLOAT) / CAST({BASELINE_CT} AS FLOAT) ) ) - 2 * ASIN( SQRT( CAST(COUNT( {COLUMN_NAME} ) AS FLOAT) / CAST(NULLIF(COUNT(*), 0) AS FLOAT) )) ) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1380' test_id: '1022' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT * FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL LIMIT {LIMIT} error_type: Test Results + - id: '10025' + test_id: 1022 + test_type: Missing_Pct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL OR CAST("{COLUMN_NAME}" AS VARCHAR(255)) = '' LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml index 8fd1fcdb..35580b34 100644 --- a/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10019' + test_type: Monthly_Rec_Ct + sql_flavor: salesforce_data360 + measure: |- + (MAX(DATEDIFF('month', CAST({COLUMN_NAME} AS DATE), CAST('{RUN_DATE}' AS DATE))) - MIN(DATEDIFF('month', CAST({COLUMN_NAME} AS DATE), CAST('{RUN_DATE}' AS DATE))) + 1) - COUNT(DISTINCT DATEDIFF('month', CAST({COLUMN_NAME} AS DATE), CAST('{RUN_DATE}' AS DATE))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1381' test_id: '1023' @@ -259,4 +267,12 @@ test_types: lookup_query: |- WITH Pass0 AS (SELECT 1 C FROM DUMMY UNION ALL SELECT 1 FROM DUMMY), Pass1 AS (SELECT 1 C FROM Pass0 A, Pass0 B), Pass2 AS (SELECT 1 C FROM Pass1 A, Pass1 B), Pass3 AS (SELECT 1 C FROM Pass2 A, Pass2 B), nums AS (SELECT ROW_NUMBER() OVER (ORDER BY C) - 1 AS rn FROM Pass3), bounds AS (SELECT 
TO_DATE(YEAR(MIN("{COLUMN_NAME}")) || '-' || LPAD(MONTH(MIN("{COLUMN_NAME}")), 2, '0') || '-01', 'YYYY-MM-DD') AS min_month, TO_DATE(YEAR(MAX("{COLUMN_NAME}")) || '-' || LPAD(MONTH(MAX("{COLUMN_NAME}")), 2, '0') || '-01', 'YYYY-MM-DD') AS max_month FROM "{TARGET_SCHEMA}"."{TABLE_NAME}"), daterange AS (SELECT ADD_MONTHS(b.min_month, n.rn) AS all_dates FROM bounds b, nums n WHERE ADD_MONTHS(b.min_month, n.rn) <= b.max_month), existing_periods AS (SELECT DISTINCT TO_DATE(YEAR("{COLUMN_NAME}") || '-' || LPAD(MONTH("{COLUMN_NAME}"), 2, '0') || '-01', 'YYYY-MM-DD') AS period, COUNT(1) AS period_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY YEAR("{COLUMN_NAME}"), MONTH("{COLUMN_NAME}")) SELECT p.missing_period, p.prior_available_month, e.period_count AS prior_available_month_count, p.next_available_month, f.period_count AS next_available_month_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_month, MIN(c.period) AS next_available_month FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_month = e.period) LEFT JOIN existing_periods f ON (p.next_available_month = f.period) ORDER BY p.missing_period LIMIT {LIMIT} error_type: Test Results + - id: '10026' + test_id: 1023 + test_type: Monthly_Rec_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH RECURSIVE daterange(all_dates) AS (SELECT CAST(DATE_TRUNC('month', MIN("{COLUMN_NAME}")) AS DATE) AS all_dates FROM "{TABLE_NAME}" UNION ALL SELECT CAST((d.all_dates + INTERVAL '1 month') AS DATE) AS all_dates FROM daterange d WHERE d.all_dates < (SELECT CAST(DATE_TRUNC('month', MAX("{COLUMN_NAME}")) AS DATE) FROM "{TABLE_NAME}") ), existing_periods AS ( SELECT DISTINCT 
CAST(DATE_TRUNC('month',"{COLUMN_NAME}") AS DATE) AS period, COUNT(1) AS period_count FROM "{TABLE_NAME}" GROUP BY CAST(DATE_TRUNC('month',"{COLUMN_NAME}") AS DATE) ) SELECT p.missing_period, p.prior_available_month, e.period_count AS prior_available_month_count, p.next_available_month, f.period_count AS next_available_month_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_month, MIN(c.period) AS next_available_month FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_month = e.period) LEFT JOIN existing_periods f ON (p.next_available_month = f.period) ORDER BY p.missing_period LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml index 6b26ccb1..5cf1493f 100644 --- a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml @@ -122,6 +122,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10020' + test_type: Outlier_Pct_Above + sql_flavor: salesforce_data360 + measure: |- + CAST(SUM(CASE WHEN CAST({COLUMN_NAME} AS FLOAT) > {BASELINE_AVG}+(2.0*{BASELINE_SD}) THEN 1 ELSE 0 END) AS FLOAT) / CAST(NULLIF(COUNT({COLUMN_NAME}), 0) AS FLOAT) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1382' test_id: '1024' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT ({BASELINE_AVG} + (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DECIMAL) > ({BASELINE_AVG} 
+ (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC error_type: Test Results + - id: '10027' + test_id: 1024 + test_type: Outlier_Pct_Above + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT ({BASELINE_AVG} + (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS FLOAT) > ({BASELINE_AVG} + (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml index a2354e6e..fa88ab8c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml @@ -122,6 +122,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10021' + test_type: Outlier_Pct_Below + sql_flavor: salesforce_data360 + measure: |- + CAST(SUM(CASE WHEN CAST({COLUMN_NAME} AS FLOAT) < {BASELINE_AVG}-(2.0*{BASELINE_SD}) THEN 1 ELSE 0 END) AS FLOAT) / CAST(NULLIF(COUNT({COLUMN_NAME}), 0) AS FLOAT) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1383' test_id: '1025' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT ({BASELINE_AVG} - (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DECIMAL) < ({BASELINE_AVG} - (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC error_type: Test Results + - id: '10028' + test_id: 1025 + test_type: Outlier_Pct_Below + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT ({BASELINE_AVG} - (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS FLOAT) < ({BASELINE_AVG} - 
(2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml index b3d0862f..6998da47 100644 --- a/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10022' + test_type: Pattern_Match + sql_flavor: salesforce_data360 + measure: |- + COUNT(NULLIF({COLUMN_NAME}, '')) - SUM(CASE WHEN REGEXP_LIKE(CAST(NULLIF({COLUMN_NAME}, '') AS VARCHAR), '{BASELINE_VALUE}') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1384' test_id: '1026' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE NOT NULLIF(TO_VARCHAR("{COLUMN_NAME}"), '') LIKE_REGEXPR '{BASELINE_VALUE}' GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10029' + test_id: 1026 + test_type: Pattern_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NOT REGEXP_LIKE(CAST(NULLIF("{COLUMN_NAME}", '') AS VARCHAR), '{BASELINE_VALUE}') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Recency.yaml b/testgen/template/dbsetup_test_types/test_types_Recency.yaml index 088a3a92..34945e9b 100644 --- a/testgen/template/dbsetup_test_types/test_types_Recency.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Recency.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10023' + test_type: Recency + sql_flavor: 
salesforce_data360 + measure: |- + DATEDIFF('day', CAST(MAX({COLUMN_NAME}) AS DATE), CAST('{RUN_DATE}' AS DATE)) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1385' test_id: '1028' @@ -203,4 +211,12 @@ test_types: lookup_query: |- SELECT DISTINCT col AS latest_date_available, TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS') AS test_run_date FROM (SELECT MAX("{COLUMN_NAME}") AS col FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") WHERE <%DATEDIFF_DAY;col;TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS')%> > {THRESHOLD_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10030' + test_id: 1028 + test_type: Recency + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT col AS latest_date_available, CAST('{TEST_DATE}' AS DATE) as test_run_date FROM (SELECT MAX("{COLUMN_NAME}") AS col FROM "{TABLE_NAME}") a WHERE DATEDIFF('day', CAST(col AS DATE), CAST('{TEST_DATE}' AS DATE)) > {THRESHOLD_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Required.yaml b/testgen/template/dbsetup_test_types/test_types_Required.yaml index 625b135f..cb294860 100644 --- a/testgen/template/dbsetup_test_types/test_types_Required.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Required.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10024' + test_type: Required + sql_flavor: salesforce_data360 + measure: |- + COUNT(*) - COUNT({COLUMN_NAME}) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1386' test_id: '1030' @@ -192,4 +200,12 @@ test_types: lookup_query: |- SELECT * FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL LIMIT {LIMIT} error_type: Test Results + - id: '10031' + test_id: 1030 + test_type: Required + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * FROM "{TABLE_NAME}" WHERE 
"{COLUMN_NAME}" IS NULL LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml index b5c4459d..06c3d62e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml @@ -115,6 +115,13 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10025' + test_type: Row_Ct + sql_flavor: salesforce_data360 + measure: COUNT(*) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1387' test_id: '1031' @@ -195,4 +202,12 @@ test_types: lookup_query: |- WITH CTE AS (SELECT COUNT(*) AS current_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") SELECT current_count, ABS(ROUND(100 * (current_count - {THRESHOLD_VALUE}) / {THRESHOLD_VALUE}, 2)) AS row_count_pct_decrease FROM cte WHERE current_count < {THRESHOLD_VALUE} error_type: Test Results + - id: '10032' + test_id: 1031 + test_type: Row_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH cte AS (SELECT COUNT(*) AS current_count FROM "{TABLE_NAME}") SELECT current_count, ABS(ROUND(100 * CAST((current_count - {THRESHOLD_VALUE}) AS NUMERIC) / CAST({THRESHOLD_VALUE} AS NUMERIC),2)) AS row_count_pct_decrease FROM cte WHERE current_count < {THRESHOLD_VALUE}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml b/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml index 05efdf4c..47cbf379 100644 --- a/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10026' + test_type: Row_Ct_Pct + sql_flavor: salesforce_data360 + measure: |- + ABS(ROUND(100.0 * CAST((COUNT(*) - {BASELINE_CT}) AS 
DECIMAL(18,4)) / CAST({BASELINE_CT} AS DECIMAL(18,4)), 2)) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1388' test_id: '1032' @@ -195,4 +203,12 @@ test_types: lookup_query: |- WITH CTE AS (SELECT COUNT(*) AS current_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") SELECT current_count, {BASELINE_CT} AS baseline_count, ABS(ROUND(100 * (current_count - {BASELINE_CT}) / {BASELINE_CT}, 2)) AS row_count_pct_difference FROM cte error_type: Test Results + - id: '10033' + test_id: 1032 + test_type: Row_Ct_Pct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT COUNT(*) AS current_count, {BASELINE_CT} AS baseline_count, ABS(ROUND(100 * CAST((COUNT(*) - {BASELINE_CT}) AS NUMERIC) / CAST({BASELINE_CT} AS NUMERIC), 2)) AS row_count_pct_difference FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml b/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml index e5c908a7..4992ba2c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml @@ -530,3 +530,58 @@ test_types: AS result_message, column_adds + column_drops + column_mods AS result_measure FROM table_changes; + - id: '10010' + test_type: Schema_Drift + sql_flavor: salesforce_data360 + template: |- + WITH prev_test AS ( + SELECT MAX(test_starttime) AS last_run_time + FROM {APP_SCHEMA_NAME}.test_runs + WHERE test_suite_id = '{TEST_SUITE_ID}'::UUID + -- Ignore current run + AND id <> '{TEST_RUN_ID}'::UUID + ), + table_changes AS ( + SELECT + dsl.table_name, + MAX(prev_test.last_run_time) as window_start, + MAX(CASE WHEN dsl.column_id IS NULL AND dsl.change = 'A' THEN dsl.change_date ELSE NULL END) as last_add_date, + MAX(CASE WHEN dsl.column_id IS NULL AND dsl.change = 'D' THEN dsl.change_date ELSE NULL END) as last_drop_date, + COUNT(*) FILTER (WHERE dsl.column_id 
IS NOT NULL AND dsl.change = 'A') AS column_adds, + COUNT(*) FILTER (WHERE dsl.column_id IS NOT NULL AND dsl.change = 'D') AS column_drops, + COUNT(*) FILTER (WHERE dsl.column_id IS NOT NULL AND dsl.change = 'M') AS column_mods + FROM {APP_SCHEMA_NAME}.data_structure_log dsl + CROSS JOIN prev_test + WHERE dsl.table_groups_id = '{TABLE_GROUPS_ID}'::UUID + -- if no previous tests, this comparison yields null and nothing is counted + AND dsl.change_date > prev_test.last_run_time + GROUP BY dsl.table_name + ) + SELECT + '{TEST_TYPE}' AS test_type, + '{TEST_DEFINITION_ID}' AS test_definition_id, + '{TEST_SUITE_ID}' AS test_suite_id, + '{TEST_RUN_ID}' AS test_run_id, + '{RUN_DATE}' AS test_time, + '{SCHEMA_NAME}' AS schema_name, + table_name, + '{INPUT_PARAMETERS}' AS input_parameters, + (CASE + WHEN last_add_date IS NOT NULL AND (last_drop_date IS NULL OR last_add_date > last_drop_date) THEN 'A' + WHEN last_drop_date IS NOT NULL AND (last_add_date IS NULL OR last_drop_date > last_add_date) THEN 'D' + ELSE 'M' + END) + || '|' || column_adds + || '|' || column_drops + || '|' || column_mods + || '|' || window_start::TEXT + AS result_signal, + 0 AS result_code, + CASE WHEN last_add_date IS NOT NULL AND (last_drop_date IS NULL OR last_add_date > last_drop_date) THEN 'Table added. ' ELSE '' END + || CASE WHEN last_drop_date IS NOT NULL AND (last_add_date IS NULL OR last_drop_date > last_add_date) THEN 'Table dropped. ' ELSE '' END + || CASE WHEN column_adds > 0 THEN column_adds || ' columns added. ' ELSE '' END + || CASE WHEN column_drops > 0 THEN column_drops || ' columns dropped. ' ELSE '' END + || CASE WHEN column_mods > 0 THEN column_mods || ' columns modified. 
' ELSE '' END + AS result_message, + column_adds + column_drops + column_mods AS result_measure + FROM table_changes; diff --git a/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml b/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml index 7956ef0a..36b83009 100644 --- a/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml @@ -118,6 +118,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10027' + test_type: Street_Addr_Pattern + sql_flavor: salesforce_data360 + measure: |- + 100.0*CAST(SUM(CASE WHEN REGEXP_LIKE({COLUMN_NAME}, '^[0-9]{1,5}[a-zA-Z]?\s\w{1,5}\.?\s?\w*\s?\w*\s[a-zA-Z]{1,6}\.?\s?[0-9]{0,5}[A-Z]{0,1}$') THEN 1 ELSE 0 END) AS FLOAT) / CAST(NULLIF(COUNT({COLUMN_NAME}), 0) AS FLOAT) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1389' test_id: '1033' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE NOT TO_VARCHAR("{COLUMN_NAME}") LIKE_REGEXPR '^[0-9]{1,5}[a-zA-Z]?[[:space:]][[:alnum:]_]{1,5}\.?[[:space:]]?[[:alnum:]_]*[[:space:]]?[[:alnum:]_]*[[:space:]][a-zA-Z]{1,6}\.?[[:space:]]?[0-9]{0,5}[A-Z]{0,1}$' GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Test Results + - id: '10034' + test_id: 1033 + test_type: Street_Addr_Pattern + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NOT REGEXP_LIKE("{COLUMN_NAME}", '^[0-9]{1,5}[a-zA-Z]?\s\w{1,5}\.?\s?\w*\s?\w*\s[a-zA-Z]{1,6}\.?\s?[0-9]{0,5}[A-Z]{0,1}$') GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml 
b/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml index 76823e83..d81e834e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml @@ -329,3 +329,35 @@ test_types: FROM {QUOTE}{SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} WHERE {SUBSET_CONDITION} ) test + - id: '10011' + test_type: Table_Freshness + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + fingerprint as result_signal, + CASE + WHEN '{LOWER_TOLERANCE}' = 'NULL' OR fingerprint = '{LOWER_TOLERANCE}' THEN 0 + ELSE 1 + END AS result_code, + CASE + WHEN '{LOWER_TOLERANCE}' = 'NULL' OR fingerprint = '{LOWER_TOLERANCE}' THEN 'No table change detected.' + ELSE 'Table change detected.' 
+ END AS result_message, + CASE + WHEN '{LOWER_TOLERANCE}' = 'NULL' OR fingerprint = '{LOWER_TOLERANCE}' THEN 0 + ELSE 1 + END AS result_measure + FROM ( SELECT MD5({CUSTOM_QUERY}) as fingerprint + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml index 61346177..34329e26 100644 --- a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml @@ -60,6 +60,26 @@ test_types: GROUP BY {COLUMN_NAME_NO_QUOTES} LIMIT {LIMIT}; error_type: Test Results + - id: '1396' + test_id: '1508' + test_type: Timeframe_Combo_Gain + sql_flavor: databricks + lookup_type: null + lookup_query: |- + SELECT {COLUMN_NAME_NO_QUOTES} + FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}` + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= (SELECT MAX({WINDOW_DATE_COLUMN}) FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}`) - 2 * {WINDOW_DAYS} + AND {WINDOW_DATE_COLUMN} < (SELECT MAX({WINDOW_DATE_COLUMN}) FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}`) - {WINDOW_DAYS} + GROUP BY {COLUMN_NAME_NO_QUOTES} + EXCEPT + SELECT {COLUMN_NAME_NO_QUOTES} + FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}` + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= (SELECT MAX({WINDOW_DATE_COLUMN}) FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}`) - {WINDOW_DAYS} + GROUP BY {COLUMN_NAME_NO_QUOTES} + LIMIT {LIMIT}; + error_type: Test Results - id: '1263' test_id: '1508' test_type: Timeframe_Combo_Gain @@ -199,6 +219,26 @@ test_types: GROUP BY {COLUMN_NAME_NO_QUOTES} LIMIT {LIMIT} error_type: Test Results + - id: '10035' + test_id: 1508 + test_type: Timeframe_Combo_Gain + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT 
MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + EXCEPT + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2507' test_type: Timeframe_Combo_Gain @@ -602,3 +642,49 @@ test_types: AND {WINDOW_DATE_COLUMN} >= ADD_DAYS((SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{SCHEMA_NAME}"."{TABLE_NAME}"), -{WINDOW_DAYS}) GROUP BY {COLUMN_NAME_NO_QUOTES} ) test + - id: '10012' + test_type: Timeframe_Combo_Gain + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS VARCHAR), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' 
+ END AS result_message, + COUNT(*) as result_measure + FROM ( + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + EXCEPT + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml index e3d2086a..6b10231d 100644 --- a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml @@ -340,6 +340,40 @@ test_types: LIMIT {LIMIT_2} ) error_type: Test Results + - id: '10036' + test_id: 1509 + test_type: Timeframe_Combo_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH prior_diff AS ( + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ), + latest_diff AS ( + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE 
{SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ) + SELECT * FROM (SELECT * FROM prior_diff LIMIT {LIMIT_2}) p + UNION ALL + SELECT * FROM (SELECT * FROM latest_diff LIMIT {LIMIT_2}) l + error_type: Test Results test_templates: - id: '2508' test_type: Timeframe_Combo_Match @@ -881,3 +915,62 @@ test_types: AND {WINDOW_DATE_COLUMN} >= ADD_DAYS((SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{SCHEMA_NAME}"."{TABLE_NAME}"), -{WINDOW_DAYS}) ) ) test + - id: '10013' + test_type: Timeframe_Combo_Match + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS VARCHAR), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' 
+ END AS result_message, + COUNT(*) as result_measure + FROM ( + ( + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ) + UNION ALL + ( + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ) + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_US_State.yaml b/testgen/template/dbsetup_test_types/test_types_US_State.yaml index a14181e8..397611df 100644 --- a/testgen/template/dbsetup_test_types/test_types_US_State.yaml +++ b/testgen/template/dbsetup_test_types/test_types_US_State.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10028' + test_type: US_State + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NULLIF({COLUMN_NAME}, '') NOT IN 
('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1390' test_id: '1036' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL AND "{COLUMN_NAME}" NOT IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10037' + test_id: 1036 + test_type: US_State + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NULLIF("{COLUMN_NAME}", '') NOT IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Unique.yaml b/testgen/template/dbsetup_test_types/test_types_Unique.yaml index abf22dae..e1b5b661 100644 --- a/testgen/template/dbsetup_test_types/test_types_Unique.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Unique.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' 
test_condition: |- {THRESHOLD_VALUE} + - id: '10029' + test_type: Unique + sql_flavor: salesforce_data360 + measure: |- + COUNT(*) - COUNT(DISTINCT {COLUMN_NAME}) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1391' test_id: '1034' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*) > 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Test Results + - id: '10038' + test_id: 1034 + test_type: Unique + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*) > 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml b/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml index 6e8767ae..6cb5908a 100644 --- a/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10030' + test_type: Unique_Pct + sql_flavor: salesforce_data360 + measure: |- + ABS( 2.0 * ASIN( SQRT(CAST({BASELINE_UNIQUE_CT} AS FLOAT) / CAST({BASELINE_VALUE_CT} AS FLOAT) ) ) - 2 * ASIN( SQRT( CAST(COUNT( DISTINCT {COLUMN_NAME} ) AS FLOAT) / CAST(NULLIF(COUNT( {COLUMN_NAME} ), 0) AS FLOAT) )) ) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1392' test_id: '1035' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Test Results + - id: '10039' + test_id: 1035 + test_type: Unique_Pct + sql_flavor: salesforce_data360 + lookup_type: 
null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml index 6110a2f9..cd73a08c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10031' + test_type: Valid_Characters + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE({COLUMN_NAME}, CHR(160), ''), CHR(8203), ''), CHR(65279), ''), CHR(8239), ''), CHR(8201), ''), CHR(12288), ''), CHR(8204), '') <> {COLUMN_NAME} OR {COLUMN_NAME} LIKE ' %' OR {COLUMN_NAME} LIKE '''%''' OR {COLUMN_NAME} LIKE '"%"' THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1397' test_id: '1043' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", NCHAR(160), ''), NCHAR(8203), ''), NCHAR(65279), ''), NCHAR(8239), ''), NCHAR(8201), ''), NCHAR(12288), ''), NCHAR(8204), '') <> "{COLUMN_NAME}" OR "{COLUMN_NAME}" LIKE ' %' OR "{COLUMN_NAME}" LIKE '''%''' OR "{COLUMN_NAME}" LIKE '"%"' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Test Results + - id: '10040' + test_id: 1043 + test_type: Valid_Characters + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", CHR(160), ''), CHR(8203), 
''), CHR(65279), ''), CHR(8239), ''), CHR(8201), ''), CHR(12288), ''), CHR(8204), '') <> "{COLUMN_NAME}" OR "{COLUMN_NAME}" LIKE ' %' OR "{COLUMN_NAME}" LIKE '''%''' OR "{COLUMN_NAME}" LIKE '"%"' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml index a5a8fbcd..fab14ff1 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml @@ -117,5 +117,13 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10032' + test_type: Valid_Month + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NULLIF({COLUMN_NAME}, '') NOT IN ({BASELINE_VALUE}) THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: [] test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml index e5225b67..e380caef 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10033' + test_type: Valid_US_Zip + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN REGEXP_REPLACE({COLUMN_NAME}, '[0-9]', '9', 'g') NOT IN ('99999', '999999999', '99999-9999') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1398' test_id: '1044' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY 
record_ct DESC LIMIT {LIMIT} error_type: Test Results + - id: '10041' + test_id: 1044 + test_type: Valid_US_Zip + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE("{COLUMN_NAME}", '[0-9]', '9', 'g') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml index 5d174ae7..45218af9 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10034' + test_type: Valid_US_Zip3 + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN REGEXP_REPLACE({COLUMN_NAME}, '[0-9]', '9', 'g') <> '999' THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1399' test_id: '1045' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Test Results + - id: '10042' + test_id: 1045 + test_type: Valid_US_Zip3 + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE("{COLUMN_NAME}", '[0-9]', '9', 'g') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml b/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml 
index dda3e907..bb671fd8 100644 --- a/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml @@ -122,6 +122,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10035' + test_type: Variability_Decrease + sql_flavor: salesforce_data360 + measure: |- + 100.0*STDDEV(CAST({COLUMN_NAME} AS FLOAT))/CAST({BASELINE_SD} AS FLOAT) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1395' test_id: '1041' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT STDDEV(CAST("{COLUMN_NAME}" AS DECIMAL)) as current_standard_deviation FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10043' + test_id: 1041 + test_type: Variability_Decrease + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT STDDEV(CAST("{COLUMN_NAME}" AS FLOAT)) as current_standard_deviation FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml b/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml index 73b0b48d..54e11245 100644 --- a/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml @@ -126,6 +126,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10036' + test_type: Variability_Increase + sql_flavor: salesforce_data360 + measure: |- + 100.0*STDDEV(CAST({COLUMN_NAME} AS FLOAT))/CAST({BASELINE_SD} AS FLOAT) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1394' test_id: '1040' @@ -200,4 +208,12 @@ test_types: lookup_query: |- SELECT STDDEV(CAST("{COLUMN_NAME}" AS DECIMAL)) as current_standard_deviation FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10044' + test_id: 1040 + test_type: 
Variability_Increase + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT STDDEV(CAST("{COLUMN_NAME}" AS FLOAT)) as current_standard_deviation FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml b/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml index 521688f6..67c9ff29 100644 --- a/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml @@ -107,6 +107,14 @@ test_types: test_operator: NOT BETWEEN test_condition: |- {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} + - id: '10037' + test_type: Volume_Trend + sql_flavor: salesforce_data360 + measure: |- + {CUSTOM_QUERY} + test_operator: NOT BETWEEN + test_condition: |- + {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} target_data_lookups: - id: '1477' test_id: '1513' @@ -207,4 +215,15 @@ test_types: {UPPER_TOLERANCE} AS upper_bound FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10045' + test_id: 1513 + test_type: Volume_Trend + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {CUSTOM_QUERY} AS current_count, + {LOWER_TOLERANCE} AS lower_bound, + {UPPER_TOLERANCE} AS upper_bound + FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml index 73a115dc..bf0e91df 100644 --- a/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10038' + test_type: Weekly_Rec_Ct + sql_flavor: salesforce_data360 + measure: |- + MAX(DATEDIFF('week', CAST('1800-01-01' AS DATE), CAST({COLUMN_NAME} AS DATE))) - MIN(DATEDIFF('week', CAST('1800-01-01' AS DATE), 
CAST({COLUMN_NAME} AS DATE)))+1 - COUNT(DISTINCT DATEDIFF('week', CAST('1800-01-01' AS DATE), CAST({COLUMN_NAME} AS DATE))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1393' test_id: '1037' @@ -259,4 +267,12 @@ test_types: lookup_query: |- WITH Pass0 AS (SELECT 1 C FROM DUMMY UNION ALL SELECT 1 FROM DUMMY), Pass1 AS (SELECT 1 C FROM Pass0 A, Pass0 B), Pass2 AS (SELECT 1 C FROM Pass1 A, Pass1 B), Pass3 AS (SELECT 1 C FROM Pass2 A, Pass2 B), nums AS (SELECT ROW_NUMBER() OVER (ORDER BY C) - 1 AS rn FROM Pass3), bounds AS (SELECT ADD_DAYS(CAST(MIN("{COLUMN_NAME}") AS DATE), -WEEKDAY(CAST(MIN("{COLUMN_NAME}") AS DATE))) AS min_week, ADD_DAYS(CAST(MAX("{COLUMN_NAME}") AS DATE), -WEEKDAY(CAST(MAX("{COLUMN_NAME}") AS DATE))) AS max_week FROM "{TARGET_SCHEMA}"."{TABLE_NAME}"), daterange AS (SELECT ADD_DAYS(b.min_week, n.rn * 7) AS all_dates FROM bounds b, nums n WHERE ADD_DAYS(b.min_week, n.rn * 7) <= b.max_week), existing_periods AS (SELECT DISTINCT ADD_DAYS(CAST("{COLUMN_NAME}" AS DATE), -WEEKDAY(CAST("{COLUMN_NAME}" AS DATE))) AS period, COUNT(1) AS period_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY ADD_DAYS(CAST("{COLUMN_NAME}" AS DATE), -WEEKDAY(CAST("{COLUMN_NAME}" AS DATE)))) SELECT p.missing_period, p.prior_available_week, e.period_count AS prior_available_week_count, p.next_available_week, f.period_count AS next_available_week_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_week, MIN(c.period) AS next_available_week FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_week = e.period) LEFT JOIN existing_periods f ON (p.next_available_week = f.period) ORDER BY p.missing_period LIMIT {LIMIT} error_type: Test 
Results + - id: '10046' + test_id: 1037 + test_type: Weekly_Rec_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH RECURSIVE daterange(all_dates) AS (SELECT CAST(DATE_TRUNC('week', MIN("{COLUMN_NAME}")) AS DATE) AS all_dates FROM "{TABLE_NAME}" UNION ALL SELECT CAST((d.all_dates + INTERVAL '1 week' ) AS DATE) AS all_dates FROM daterange d WHERE d.all_dates < (SELECT CAST(DATE_TRUNC('week' , MAX("{COLUMN_NAME}")) AS DATE) FROM "{TABLE_NAME}") ), existing_periods AS (SELECT DISTINCT CAST(DATE_TRUNC('week', "{COLUMN_NAME}") AS DATE) AS period, COUNT(1) as period_count FROM "{TABLE_NAME}" GROUP BY CAST(DATE_TRUNC('week', "{COLUMN_NAME}") AS DATE)) SELECT p.missing_period, p.prior_available_week, e.period_count AS prior_available_week_count, p.next_available_week, f.period_count AS next_available_week_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_week, MIN(c.period) AS next_available_week FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_week = e.period) LEFT JOIN existing_periods f ON (p.next_available_week = f.period) ORDER BY p.missing_period LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbupgrade/0188_incremental_upgrade.sql b/testgen/template/dbupgrade/0188_incremental_upgrade.sql new file mode 100644 index 00000000..90292af4 --- /dev/null +++ b/testgen/template/dbupgrade/0188_incremental_upgrade.sql @@ -0,0 +1,26 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Drop the unused `args` column from job_schedules and job_executions. +-- It's vestigial: exec_job dispatches via handler(**je.kwargs); no path reads args. 
+-- The job_schedules UNIQUE constraint includes args, so resolve and drop it dynamically +-- (the auto-generated PG constraint name varies with truncation). + +DO $$ +DECLARE c_name TEXT; +BEGIN + SELECT conname INTO c_name + FROM pg_constraint + WHERE conrelid = 'job_schedules'::regclass + AND contype = 'u' + AND conkey @> ARRAY[(SELECT attnum FROM pg_attribute WHERE attrelid = 'job_schedules'::regclass AND attname = 'args')]; + IF c_name IS NOT NULL THEN + EXECUTE format('ALTER TABLE job_schedules DROP CONSTRAINT %I', c_name); + END IF; +END $$; + +ALTER TABLE job_schedules DROP COLUMN args; + +ALTER TABLE job_schedules + ADD CONSTRAINT job_schedules_uniq UNIQUE (project_code, key, kwargs, cron_expr, cron_tz); + +ALTER TABLE job_executions DROP COLUMN args; diff --git a/testgen/template/dbupgrade/0189_incremental_upgrade.sql b/testgen/template/dbupgrade/0189_incremental_upgrade.sql new file mode 100644 index 00000000..96227490 --- /dev/null +++ b/testgen/template/dbupgrade/0189_incremental_upgrade.sql @@ -0,0 +1,48 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Loosen fn_eval's numeric token pattern to accept leading-dot decimals +-- (e.g. ".733"). Oracle's NUMBER -> VARCHAR2 conversion drops the leading +-- zero for |x| < 1, and the value flows verbatim into test_results.result_measure +-- (VARCHAR), so the DQ scoring prevalence formula like +-- 2.0 * (1.0 - fn_normal_cdf(ABS({RESULT_MEASURE}::FLOAT) / 2.0)) +-- fed ".733..." to fn_eval, which rejected it as "invalid token \".\"". 
+ +CREATE OR REPLACE FUNCTION {SCHEMA_NAME}.fn_eval(expression TEXT) RETURNS FLOAT +AS +$$ +DECLARE + result FLOAT; + invalid_parts TEXT; +BEGIN + -- Check the modified expression for invalid characters, allowing colons + IF expression ~* E'[^0-9+\\-*/(),.\\sA-Z_:e\\\'"]' THEN + RAISE EXCEPTION 'Invalid characters detected in expression: %', expression; + END IF; + + -- Reject dangerous keywords (whole words via \y; inside E'' a bare \b is a backspace, not a word boundary) and SQL comment/terminator markers anywhere + IF expression ~* E'\\y(DROP|ALTER|INSERT|UPDATE|DELETE|TRUNCATE|GRANT|REVOKE|COPY|EXECUTE|CREATE|COMMENT|SECURITY|WITH|SET ROLE|SET SESSION|DO|CALL|pg_read_file|pg_write_file|pg_terminate_backend)\\y|--|/\\*|;' THEN + RAISE EXCEPTION 'Invalid expression: dangerous statement detected'; + END IF; + + -- Remove all allowed tokens from the validation expression, treating 'FLOAT' as a keyword. + -- Numeric pattern accepts leading-dot decimals (e.g. ".733") that Oracle emits + -- when converting NUMBER values with |x| < 1 to VARCHAR2. + invalid_parts := regexp_replace( + expression, + E'(\\mGREATEST|LEAST|ABS|FN_NORMAL_CDF|DATEDIFF|DAY|FLOAT|NULLIF)\\M|([0-9]+\\.?[0-9]*|\\.[0-9]+)([eE][+-]?[0-9]+)?|[+\\-*/(),\\\'":]+|\\s+', + '', + 'gi' + ); + + -- If anything is left in the validation expression, it's invalid + IF invalid_parts <> '' THEN + RAISE EXCEPTION 'Invalid expression contains invalid tokens "%" in expression: %', invalid_parts, expression; + END IF; + + -- Use the original expression (with ::FLOAT) for execution + EXECUTE format('SELECT (%s)::FLOAT', expression) INTO result; + + RETURN result; +END; +$$ +LANGUAGE plpgsql; diff --git a/testgen/template/dbupgrade/0190_incremental_upgrade.sql b/testgen/template/dbupgrade/0190_incremental_upgrade.sql new file mode 100644 index 00000000..a44bf9af --- /dev/null +++ b/testgen/template/dbupgrade/0190_incremental_upgrade.sql @@ -0,0 +1,4 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Widen project_user to accommodate Salesforce Data 360 Consumer Keys (86+ chars) +ALTER TABLE connections ALTER COLUMN 
project_user TYPE VARCHAR(256); diff --git a/testgen/template/dbupgrade/0191_incremental_upgrade.sql b/testgen/template/dbupgrade/0191_incremental_upgrade.sql new file mode 100644 index 00000000..84f389c7 --- /dev/null +++ b/testgen/template/dbupgrade/0191_incremental_upgrade.sql @@ -0,0 +1,27 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Add data retention settings to projects. +-- Existing projects start disabled (NULL days); new projects default to enabled at 180 days, +-- enforced via ALTER COLUMN SET DEFAULT after the initial backfill. + +ALTER TABLE projects + ADD COLUMN IF NOT EXISTS data_retention_enabled BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE projects + ALTER COLUMN data_retention_enabled SET DEFAULT TRUE; + +ALTER TABLE projects + ADD COLUMN IF NOT EXISTS data_retention_days INTEGER; + +ALTER TABLE projects + ALTER COLUMN data_retention_days SET DEFAULT 180; + +-- Indexes supporting data retention sweeps. +-- profiling_runs: retention filters by (project_code, profiling_starttime). +CREATE INDEX IF NOT EXISTS ix_prun_pc_starttime + ON profiling_runs(project_code, profiling_starttime); + +-- job_executions: supports retention queries filtering by +-- (project_code, completed_at). 
+CREATE INDEX IF NOT EXISTS idx_job_executions_project_completed + ON job_executions(project_code, completed_at); diff --git a/testgen/template/dbupgrade/0192_incremental_upgrade.sql b/testgen/template/dbupgrade/0192_incremental_upgrade.sql new file mode 100644 index 00000000..ce8c1158 --- /dev/null +++ b/testgen/template/dbupgrade/0192_incremental_upgrade.sql @@ -0,0 +1,4 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +ALTER TABLE auth_users + ADD COLUMN IF NOT EXISTS preferences JSONB NOT NULL DEFAULT '{}'; diff --git a/testgen/template/execution/get_current_freshness_signal.sql b/testgen/template/execution/get_current_freshness_signal.sql new file mode 100644 index 00000000..962d6f5c --- /dev/null +++ b/testgen/template/execution/get_current_freshness_signal.sql @@ -0,0 +1,12 @@ +-- Latest Freshness_Trend result_signal for a given table within the current run. Used +-- by Volume_Trend / Metric_Trend execution to detect whether the table has been updated +-- this run: result_signal = '0' means fingerprint changed, any other value means no +-- change (signal carries the interval-since-last-update). 
+SELECT result_signal +FROM test_results +WHERE test_run_id = :TEST_RUN_ID ::UUID + AND test_type = 'Freshness_Trend' + AND schema_name = :SCHEMA_NAME + AND table_name = :TABLE_NAME +ORDER BY test_time DESC +LIMIT 1; diff --git a/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql index ed6c227c..84944de7 100644 --- a/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql index aa9d2a87..a057fa6f 100644 --- a/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql index a14dc9a4..b8d45adb 100644 --- 
a/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql index 05724f8f..a1e753ef 100644 --- a/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql new file mode 100644 index 00000000..8ad665ae --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql @@ -0,0 +1,55 @@ +WITH latest_run AS ( + -- Latest complete profiling run before as-of-date + SELECT MAX(run_date) AS last_run_date + FROM profile_results + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + AND run_date::DATE <= :AS_OF_DATE ::DATE +), +selected_tables AS ( + SELECT profile_run_id, schema_name, table_name, + 
STRING_AGG(:QUOTE || column_name || :QUOTE, ', ' ORDER BY position) AS groupby_names + FROM profile_results p + INNER JOIN latest_run lr ON p.run_date = lr.last_run_date + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + -- Skip X types - Hyper does not support GROUP BY on JSON columns (and VARBINARY by extension) + AND general_type <> 'X' + GROUP BY profile_run_id, schema_name, table_name +) +INSERT INTO test_definitions ( + table_groups_id, test_suite_id, test_type, + schema_name, table_name, + test_active, last_auto_gen_date, profiling_as_of_date, profile_run_id, + groupby_names, skip_errors +) +SELECT + :TABLE_GROUPS_ID ::UUID AS table_groups_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, + 'Dupe_Rows' AS test_type, + s.schema_name, + s.table_name, + 'Y' AS test_active, + :RUN_DATE ::TIMESTAMP AS last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP AS profiling_as_of_date, + s.profile_run_id, + s.groupby_names, + 0 AS skip_errors +FROM selected_tables s + -- Only insert if test type is active +WHERE EXISTS (SELECT 1 FROM test_types WHERE test_type = 'Dupe_Rows' AND active = 'Y') + -- Only insert if test type is included in generation set + AND EXISTS (SELECT 1 FROM generation_sets WHERE test_type = 'Dupe_Rows' AND generation_set = :GENERATION_SET) + +-- Match "uix_td_autogen_table" unique index exactly +ON CONFLICT (test_suite_id, test_type, schema_name, table_name) +WHERE last_auto_gen_date IS NOT NULL + AND table_name IS NOT NULL + AND column_name IS NULL + +-- Update tests if they already exist +DO UPDATE SET + test_active = EXCLUDED.test_active, + last_auto_gen_date = EXCLUDED.last_auto_gen_date, + groupby_names = EXCLUDED.groupby_names, + skip_errors = EXCLUDED.skip_errors +-- Ignore locked tests +WHERE test_definitions.lock_refresh = 'N'; diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql new file mode 100644 index 
00000000..6abcd6d8 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql @@ -0,0 +1,212 @@ +WITH latest_run AS ( + -- Latest complete profiling run before as-of-date + SELECT MAX(run_date) AS last_run_date + FROM profile_results + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + AND run_date::DATE <= :AS_OF_DATE ::DATE +), +latest_results AS ( + -- Column results for latest run + SELECT p.profile_run_id, p.schema_name, p.table_name, p.column_name, + p.functional_data_type, p.general_type, + p.distinct_value_ct, p.record_ct, p.null_value_ct, + p.max_value, p.min_value, p.avg_value, p.stdev_value + FROM profile_results p + INNER JOIN latest_run lr ON p.run_date = lr.last_run_date + INNER JOIN data_table_chars dtc ON ( + dtc.table_groups_id = p.table_groups_id + AND dtc.schema_name = p.schema_name + AND dtc.table_name = p.table_name + -- Ignore dropped tables + AND dtc.drop_date IS NULL + ) + INNER JOIN data_column_chars dcc ON ( + dcc.table_groups_id = p.table_groups_id + AND dcc.schema_name = p.schema_name + AND dcc.table_name = p.table_name + AND dcc.column_name = p.column_name + -- Ignore dropped columns + AND dcc.drop_date IS NULL + ) + WHERE p.table_groups_id = :TABLE_GROUPS_ID ::UUID +), +-- IDs - TOP 2 +id_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN functional_data_type ILIKE 'ID-Unique%' THEN 1 + WHEN functional_data_type = 'ID-Secondary' THEN 2 + ELSE 3 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'ID%' +), +-- Process Date - TOP 1 +process_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + 
ORDER BY + CASE + WHEN column_name ILIKE '%mod%' THEN 1 + WHEN column_name ILIKE '%up%' THEN 1 + WHEN column_name ILIKE '%cr%' THEN 2 + WHEN column_name ILIKE '%in%' THEN 2 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'process%' +), +-- Transaction Date - TOP 1 +tran_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) +), +-- Numeric Measures +numeric_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, +/* + -- Subscores + distinct_value_ct * 1.0 / NULLIF(record_ct, 0) AS cardinality_score, + (max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS range_score, + LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2)) AS nontriviality_score, + stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS variability_score, + 1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1)) AS null_penalty, +*/ + -- Weighted score + ( + 0.25 * (distinct_value_ct * 1.0 / NULLIF(record_ct, 0)) + + 0.15 * ((max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2))) + + 0.40 * (stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1))) + ) AS change_detection_score + FROM latest_results + WHERE general_type = 'N' + AND ( + functional_data_type ILIKE 'Measure%' + OR functional_data_type IN ('Sequence', 'Constant') + ) +), 
+numeric_cols_ranked AS ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY change_detection_score DESC, column_name + ) AS rank + FROM numeric_cols + WHERE change_detection_score IS NOT NULL +), +combined AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + 'ID' AS element_type, general_type, 10 + rank AS fingerprint_order + FROM id_cols + WHERE rank <= 2 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_P' AS element_type, general_type, 20 + rank AS fingerprint_order + FROM process_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_T' AS element_type, general_type, 30 + rank AS fingerprint_order + FROM tran_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'MEAS' AS element_type, general_type, 40 + rank AS fingerprint_order + FROM numeric_cols_ranked + WHERE rank = 1 +), +selected_tables AS ( + SELECT profile_run_id, schema_name, table_name, + STRING_AGG(column_name, ',' ORDER BY element_type, fingerprint_order, column_name) AS column_names, + 'CAST(COUNT(*) AS VARCHAR) || ''|'' || ' || + STRING_AGG( + REPLACE( + CASE + WHEN general_type = 'D' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR)' + WHEN general_type = 'A' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR) || ''|'' || CAST(SUM(LENGTH(@@@)) AS VARCHAR)' + WHEN general_type = 'N' THEN 'CAST(COUNT(@@@) AS VARCHAR) || ''|'' || + CAST(COUNT(DISTINCT MOD(CAST(CAST(COALESCE(@@@,0) AS DECIMAL(38,6)) * 1000000 AS DECIMAL(38,0)), 1000003)) AS VARCHAR) || ''|'' || + COALESCE(CAST(CAST(MIN(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(CAST(MAX(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS 
DECIMAL(38,6)) * 1000000, 1000000007)), 0), 1000000007) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000009)), 0), 1000000009) AS VARCHAR), '''')' + END, + '@@@', '"' || column_name || '"' + ), + ' || ''|'' || ' + ORDER BY element_type, fingerprint_order, column_name + ) AS fingerprint + FROM combined + GROUP BY profile_run_id, schema_name, table_name +) +-- Insert tests for selected tables +INSERT INTO test_definitions ( + table_groups_id, test_suite_id, test_type, + schema_name, table_name, groupby_names, + test_active, last_auto_gen_date, profiling_as_of_date, profile_run_id, + history_calculation, history_lookback, custom_query +) +SELECT + :TABLE_GROUPS_ID ::UUID AS table_groups_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, + 'Freshness_Trend' AS test_type, + s.schema_name, + s.table_name, + s.column_names AS groupby_names, + 'Y' AS test_active, + :RUN_DATE ::TIMESTAMP AS last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP AS profiling_as_of_date, + s.profile_run_id, + 'PREDICT' AS history_calculation, + NULL AS history_lookback, + s.fingerprint AS custom_query +FROM selected_tables s + -- Only insert if test type is active +WHERE EXISTS (SELECT 1 FROM test_types WHERE test_type = 'Freshness_Trend' AND active = 'Y') + -- Only insert if test type is included in generation set + AND EXISTS (SELECT 1 FROM generation_sets WHERE test_type = 'Freshness_Trend' AND generation_set = :GENERATION_SET) + {TABLE_FILTER} + +-- Match "uix_td_autogen_table" unique index exactly +ON CONFLICT (test_suite_id, test_type, schema_name, table_name) +WHERE last_auto_gen_date IS NOT NULL + AND table_name IS NOT NULL + AND column_name IS NULL + +-- Update tests if they already exist +DO UPDATE SET + groupby_names = EXCLUDED.groupby_names, + test_active = EXCLUDED.test_active, + last_auto_gen_date = EXCLUDED.last_auto_gen_date, + profiling_as_of_date = EXCLUDED.profiling_as_of_date, + profile_run_id = 
EXCLUDED.profile_run_id, + history_calculation = EXCLUDED.history_calculation, + history_lookback = EXCLUDED.history_lookback, + custom_query = EXCLUDED.custom_query +-- Ignore locked tests +WHERE test_definitions.lock_refresh = 'N' + -- Don't update existing tests in "insert" mode + AND NOT COALESCE(:INSERT_ONLY, FALSE); diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql new file mode 100644 index 00000000..0fca64f2 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql @@ -0,0 +1,189 @@ +WITH latest_run AS ( + -- Latest complete profiling run before as-of-date + SELECT MAX(run_date) AS last_run_date + FROM profile_results + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + AND run_date::DATE <= :AS_OF_DATE ::DATE +), +latest_results AS ( + -- Column results for latest run + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, + distinct_value_ct, record_ct, null_value_ct, + max_value, min_value, avg_value, stdev_value + FROM profile_results p + INNER JOIN latest_run lr ON p.run_date = lr.last_run_date + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID +), +-- IDs - TOP 2 +id_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN functional_data_type ILIKE 'ID-Unique%' THEN 1 + WHEN functional_data_type = 'ID-Secondary' THEN 2 + ELSE 3 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'ID%' +), +-- Process Date - TOP 1 +process_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + 
PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN column_name ILIKE '%mod%' THEN 1 + WHEN column_name ILIKE '%up%' THEN 1 + WHEN column_name ILIKE '%cr%' THEN 2 + WHEN column_name ILIKE '%in%' THEN 2 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'process%' +), +-- Transaction Date - TOP 1 (OR alternatives are parenthesized so the general_type filter applies to each of them) +tran_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND (functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp') +), +-- Numeric Measures +numeric_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, +/* + -- Subscores + distinct_value_ct * 1.0 / NULLIF(record_ct, 0) AS cardinality_score, + (max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS range_score, + LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2)) AS nontriviality_score, + stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS variability_score, + 1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1)) AS null_penalty, +*/ + -- Weighted score + ( + 0.25 * (distinct_value_ct * 1.0 / NULLIF(record_ct, 0)) + + 0.15 * ((max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2))) + + 0.40 * (stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1))) + ) AS change_detection_score + FROM latest_results + WHERE general_type = 'N' + AND ( + functional_data_type ILIKE 'Measure%' + OR functional_data_type IN ('Sequence',
'Constant') + ) +), +numeric_cols_ranked AS ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY change_detection_score DESC, column_name + ) AS rank + FROM numeric_cols + WHERE change_detection_score IS NOT NULL +), +combined AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + 'ID' AS element_type, general_type, 10 + rank AS fingerprint_order + FROM id_cols + WHERE rank <= 2 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_P' AS element_type, general_type, 20 + rank AS fingerprint_order + FROM process_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_T' AS element_type, general_type, 30 + rank AS fingerprint_order + FROM tran_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'MEAS' AS element_type, general_type, 40 + rank AS fingerprint_order + FROM numeric_cols_ranked + WHERE rank = 1 +), +selected_tables AS ( + SELECT profile_run_id, schema_name, table_name, + 'CAST(COUNT(*) AS VARCHAR) || ''|'' || ' || + STRING_AGG( + REPLACE( + CASE + WHEN general_type = 'D' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR)' + WHEN general_type = 'A' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR) || ''|'' || CAST(SUM(LENGTH(@@@)) AS VARCHAR)' + WHEN general_type = 'N' THEN 'CAST(COUNT(@@@) AS VARCHAR) || ''|'' || + CAST(COUNT(DISTINCT MOD(CAST(CAST(COALESCE(@@@,0) AS DECIMAL(38,6)) * 1000000 AS DECIMAL(38,0)), 1000003)) AS VARCHAR) || ''|'' || + COALESCE(CAST(CAST(MIN(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(CAST(MAX(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000007)), 0), 1000000007) AS VARCHAR), '''') || 
''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000009)), 0), 1000000009) AS VARCHAR), '''')' + END, + '@@@', '"' || column_name || '"' + ), + ' || ''|'' || ' + ORDER BY element_type, fingerprint_order, column_name + ) AS fingerprint + FROM combined + GROUP BY profile_run_id, schema_name, table_name +) +-- Insert tests for selected tables +INSERT INTO test_definitions ( + table_groups_id, test_suite_id, test_type, + schema_name, table_name, + test_active, last_auto_gen_date, profiling_as_of_date, profile_run_id, + history_calculation, history_lookback, custom_query +) +SELECT + :TABLE_GROUPS_ID ::UUID AS table_groups_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, + 'Table_Freshness' AS test_type, + s.schema_name, + s.table_name, + 'Y' AS test_active, + :RUN_DATE ::TIMESTAMP AS last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP AS profiling_as_of_date, + s.profile_run_id, + 'Value' AS history_calculation, + 1 AS history_lookback, + s.fingerprint AS custom_query +FROM selected_tables s + -- Only insert if test type is active +WHERE EXISTS (SELECT 1 FROM test_types WHERE test_type = 'Table_Freshness' AND active = 'Y') + -- Only insert if test type is included in generation set + AND EXISTS (SELECT 1 FROM generation_sets WHERE test_type = 'Table_Freshness' AND generation_set = :GENERATION_SET) + +-- Match "uix_td_autogen_table" unique index exactly +ON CONFLICT (test_suite_id, test_type, schema_name, table_name) +WHERE last_auto_gen_date IS NOT NULL + AND table_name IS NOT NULL + AND column_name IS NULL + +-- Update tests if they already exist +DO UPDATE SET + test_active = EXCLUDED.test_active, + last_auto_gen_date = EXCLUDED.last_auto_gen_date, + profiling_as_of_date = EXCLUDED.profiling_as_of_date, + profile_run_id = EXCLUDED.profile_run_id, + history_calculation = EXCLUDED.history_calculation, + history_lookback = EXCLUDED.history_lookback, + custom_query = EXCLUDED.custom_query +-- Ignore locked tests +WHERE 
test_definitions.lock_refresh = 'N'; diff --git a/testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql b/testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql new file mode 100644 index 00000000..a16aeb1c --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql @@ -0,0 +1,247 @@ +WITH target_table AS ( +-- TG-IF do_sample + SELECT * FROM "{DATA_TABLE}" ORDER BY RANDOM() LIMIT {SAMPLE_SIZE} +-- TG-ELSE + SELECT * FROM "{DATA_TABLE}" +-- TG-ENDIF +) +SELECT + {CONNECTION_ID} AS connection_id, + '{PROJECT_CODE}' AS project_code, + '{TABLE_GROUPS_ID}' AS table_groups_id, + '{DATA_SCHEMA}' AS schema_name, + '{RUN_DATE}' AS run_date, + '{DATA_TABLE}' AS table_name, + {COL_POS} AS position, + '{COL_NAME_SANITIZED}' AS column_name, + '{COL_TYPE}' AS column_type, + '{DB_DATA_TYPE}' AS db_data_type, + '{COL_GEN_TYPE}' AS general_type, + COUNT(*) AS record_ct, + COUNT("{COL_NAME}") AS value_ct, + COUNT(DISTINCT "{COL_NAME}") AS distinct_value_ct, + SUM(CASE WHEN "{COL_NAME}" IS NULL THEN 1 ELSE 0 END) AS null_value_ct, +-- TG-IF is_type_ADN + MIN(LENGTH(CAST("{COL_NAME}" AS VARCHAR))) AS min_length, + MAX(LENGTH(CAST("{COL_NAME}" AS VARCHAR))) AS max_length, + AVG(CAST(NULLIF(LENGTH(CAST("{COL_NAME}" AS VARCHAR)), 0) AS DOUBLE)) AS avg_length, +-- TG-ELSE + NULL AS min_length, + NULL AS max_length, + NULL AS avg_length, +-- TG-ENDIF +-- TG-IF is_type_A + SUM(CASE + WHEN REGEXP_LIKE(TRIM("{COL_NAME}"), '^0(\.0*)?$') THEN 1 ELSE 0 + END) AS zero_value_ct, +-- TG-ENDIF +-- TG-IF is_type_N + SUM(CASE WHEN CAST("{COL_NAME}" AS DOUBLE) = 0 THEN 1 ELSE 0 END) AS zero_value_ct, +-- TG-ENDIF +-- TG-IF is_not_A_not_N + NULL AS zero_value_ct, +-- TG-ENDIF +-- TG-IF is_type_A + COUNT(DISTINCT UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COL_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', ''))) AS distinct_std_value_ct, + SUM(CASE + WHEN "{COL_NAME}" = '' THEN 1 + ELSE 0 + END) 
AS zero_length_ct, + SUM( CASE + WHEN "{COL_NAME}" BETWEEN ' !' AND '!' THEN 1 + ELSE 0 + END ) AS lead_space_ct, + SUM( CASE WHEN "{COL_NAME}" LIKE '"%"' OR "{COL_NAME}" LIKE '''%''' THEN 1 ELSE 0 END ) AS quoted_value_ct, + SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '[0-9]') THEN 1 ELSE 0 END ) AS includes_digit_ct, + SUM( CASE + WHEN REGEXP_LIKE(LOWER("{COL_NAME}"), '^(\.{1,}|-{1,}|\?{1,}|\s{1,}|0{2,}|9{2,}|x{2,}|z{2,})$') THEN 1 + WHEN LOWER("{COL_NAME}") IN ('blank','error','missing','tbd', + 'n/a','#na','none','null','unknown') THEN 1 + WHEN LOWER("{COL_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', + '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 + WHEN LOWER("{COL_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', + '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 + ELSE 0 + END ) AS filled_value_ct, + SUBSTR(MIN(NULLIF("{COL_NAME}", '')), 1, 100) AS min_text, + SUBSTR(MAX(NULLIF("{COL_NAME}", '')), 1, 100) AS max_text, + SUM(CASE + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Za-z]', '', 'g') = "{COL_NAME}" THEN 0 + WHEN REGEXP_REPLACE("{COL_NAME}", '[a-z]', '', 'g') = "{COL_NAME}" THEN 1 + ELSE 0 + END) AS upper_case_ct, + SUM(CASE + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Za-z]', '', 'g') = "{COL_NAME}" THEN 0 + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Z]', '', 'g') = "{COL_NAME}" THEN 1 + ELSE 0 + END) AS lower_case_ct, + SUM(CASE + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Za-z]', '', 'g') = "{COL_NAME}" THEN 1 + ELSE 0 + END) AS non_alpha_ct, + SUM(CASE WHEN REGEXP_REPLACE("{COL_NAME}", + '[' || CHR(160) || CHR(8201) || CHR(8203) || CHR(8204) || CHR(8205) || CHR(8206) || CHR(8207) || CHR(8239) || CHR(12288) || CHR(65279) || ']', + 'X', 'g') <> "{COL_NAME}" THEN 1 ELSE 0 END) AS non_printing_ct, + SUM(<%IS_NUM;SUBSTR("{COL_NAME}", 1, 31)%>) AS numeric_ct, + SUM(<%IS_DATE;SUBSTR("{COL_NAME}", 1, 26)%>) AS date_ct, + CASE + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", 
'^[0-9]{1,5}[a-zA-Z]?\s\w{1,5}\.?\s?\w*\s?\w*\s[a-zA-Z]{1,6}\.?\s?[0-9]{0,5}[A-Z]{0,1}$') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'STREET_ADDR' + WHEN SUM(CASE WHEN "{COL_NAME}" IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'STATE_USA' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^(\+1|1)?[ .\-]?(\([2-9][0-9]{2}\)|[2-9][0-9]{2})[ .\-]?[2-9][0-9]{2}[ .\-]?[0-9]{4}$') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'PHONE_USA' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}$') + AND "{COL_NAME}" NOT LIKE '%://%' + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'EMAIL' + WHEN SUM( CASE WHEN REGEXP_LIKE(REGEXP_REPLACE("{COL_NAME}", '[0-9]', '9', 'g'), '^(99999|999999999|99999-9999)$') + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'ZIP_USA' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[\w\s\-]+\.(txt|csv|tsv|dat|doc|pdf|xlsx)$') + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'FILE_NAME' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^([0-9]{4}[- ]?){3}[0-9]{4}$') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'CREDIT_CARD' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^([^,|\t]{1,20}[,|\t]){2,}[^,|\t]{0,20}([,|\t]{0,1}[^,|\t]{0,20})*$') + AND NOT REGEXP_LIKE("{COL_NAME}", '\s(and|but|or|yet)\s') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'DELIMITED_DATA' + WHEN SUM ( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[0-8][0-9]{2}-[0-9]{2}-[0-9]{4}$') + AND SUBSTR("{COL_NAME}", 1, 3) NOT BETWEEN '734' AND '749' + AND SUBSTR("{COL_NAME}", 1, 3) <> '666' THEN 1 END) > CAST(0.9 
* COUNT("{COL_NAME}") AS BIGINT) THEN 'SSN' + END AS std_pattern_match, +-- TG-ELSE + NULL AS distinct_std_value_ct, + NULL AS zero_length_ct, + NULL AS lead_space_ct, + NULL AS quoted_value_ct, + NULL AS includes_digit_ct, + NULL AS filled_value_ct, + NULL AS min_text, + NULL AS max_text, + NULL AS upper_case_ct, + NULL AS lower_case_ct, + NULL AS non_alpha_ct, + NULL AS non_printing_ct, + NULL AS numeric_ct, + NULL AS date_ct, + NULL AS std_pattern_match, +-- TG-ENDIF +-- TG-IF is_type_A + (SELECT SUBSTR(ARRAY_JOIN(ARRAY_AGG(pattern), ' | '), 1, 1000) AS concat_pats + FROM ( + SELECT CAST(COUNT(*) AS VARCHAR) || ' | ' || pattern AS pattern, + COUNT(*) AS ct + FROM ( SELECT REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( + "{COL_NAME}", '[a-z]', 'a', 'g'), + '[A-Z]', 'A', 'g'), + '[0-9]', 'N', 'g') AS pattern + FROM target_table + WHERE "{COL_NAME}" > ' ' AND (SELECT MAX(LENGTH("{COL_NAME}")) + FROM target_table) BETWEEN 3 and {MAX_PATTERN_LENGTH}) p + GROUP BY pattern + HAVING pattern > ' ' + ORDER BY COUNT(*) DESC + LIMIT 5 + ) ps) AS top_patterns, +-- TG-ELSE + NULL AS top_patterns, +-- TG-ENDIF +-- TG-IF is_type_N + MIN("{COL_NAME}") AS min_value, + MIN(CASE WHEN CAST("{COL_NAME}" AS DOUBLE) > 0 THEN "{COL_NAME}" ELSE NULL END) AS min_value_over_0, + MAX("{COL_NAME}") AS max_value, + AVG(CAST("{COL_NAME}" AS DOUBLE)) AS avg_value, + STDDEV(CAST("{COL_NAME}" AS DOUBLE)) AS stdev_value, + APPROX_PERCENTILE(CAST("{COL_NAME}" AS DOUBLE), 0.25) AS percentile_25, + APPROX_PERCENTILE(CAST("{COL_NAME}" AS DOUBLE), 0.50) AS percentile_50, + APPROX_PERCENTILE(CAST("{COL_NAME}" AS DOUBLE), 0.75) AS percentile_75, +-- TG-ELSE + NULL AS min_value, + NULL AS min_value_over_0, + NULL AS max_value, + NULL AS avg_value, + NULL AS stdev_value, + NULL AS percentile_25, + NULL AS percentile_50, + NULL AS percentile_75, +-- TG-ENDIF +-- TG-IF is_N_decimal + SUM(ROUND(ABS(MOD(CAST("{COL_NAME}" AS DOUBLE), 1)), 5)) AS fractional_sum, +-- TG-ELSE + NULL AS fractional_sum, +-- TG-ENDIF 
+-- TG-IF is_type_D + CASE + WHEN MIN("{COL_NAME}") IS NULL THEN NULL + ELSE GREATEST(MIN("{COL_NAME}"), CAST('0001-01-01' AS TIMESTAMP)) + END AS min_date, + MAX("{COL_NAME}") AS max_date, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 12 THEN 1 + ELSE 0 + END) AS before_1yr_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 60 THEN 1 + ELSE 0 + END) AS before_5yr_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 240 THEN 1 + ELSE 0 + END) AS before_20yr_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 1200 THEN 1 + ELSE 0 + END) AS before_100yr_date_ct, + SUM(CASE + WHEN DATEDIFF('day', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) BETWEEN 0 AND 365 THEN 1 + ELSE 0 + END) AS within_1yr_date_ct, + SUM(CASE + WHEN DATEDIFF('day', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) BETWEEN 0 AND 30 THEN 1 + ELSE 0 + END) AS within_1mo_date_ct, + SUM(CASE + WHEN "{COL_NAME}" > CAST('{RUN_DATE}' AS TIMESTAMP) THEN 1 ELSE 0 + END) AS future_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST('{RUN_DATE}' AS TIMESTAMP), CAST("{COL_NAME}" AS TIMESTAMP)) > 240 THEN 1 + ELSE 0 + END) AS distant_future_date_ct, + COUNT(DISTINCT DATEDIFF('day', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP))) AS date_days_present, + COUNT(DISTINCT DATEDIFF('week', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP))) AS date_weeks_present, + COUNT(DISTINCT DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP))) AS date_months_present, +-- TG-ELSE + NULL AS min_date, + NULL AS max_date, + NULL AS before_1yr_date_ct, + NULL AS before_5yr_date_ct, + NULL AS before_20yr_date_ct, + NULL AS before_100yr_date_ct, + NULL AS within_1yr_date_ct, + NULL AS 
within_1mo_date_ct, + NULL AS future_date_ct, + NULL AS distant_future_date_ct, + NULL AS date_days_present, + NULL AS date_weeks_present, + NULL AS date_months_present, +-- TG-ENDIF +-- TG-IF is_type_B + SUM(CAST("{COL_NAME}" AS INTEGER)) AS boolean_true_ct, +-- TG-ELSE + NULL AS boolean_true_ct, +-- TG-ENDIF +-- TG-IF is_type_A + (SELECT COUNT(DISTINCT REGEXP_REPLACE( REGEXP_REPLACE( REGEXP_REPLACE( + "{COL_NAME}", '[a-z]', 'a', 'g'), + '[A-Z]', 'A', 'g'), + '[0-9]', 'N', 'g') + ) AS pattern_ct + FROM target_table + WHERE "{COL_NAME}" > ' ' ) AS distinct_pattern_ct, + SUM(CASE WHEN LENGTH(TRIM("{COL_NAME}")) - LENGTH(REGEXP_REPLACE(TRIM("{COL_NAME}"), ' ', '', 'g')) > 0 THEN 1 ELSE 0 END) AS embedded_space_ct, + AVG(CAST(LENGTH(TRIM("{COL_NAME}")) - LENGTH(REGEXP_REPLACE(TRIM("{COL_NAME}"), ' ', '', 'g')) AS DOUBLE)) AS avg_embedded_spaces, +-- TG-ELSE + NULL AS distinct_pattern_ct, + NULL AS embedded_space_ct, + NULL AS avg_embedded_spaces, +-- TG-ENDIF + '{PROFILE_RUN_ID}' AS profile_run_id + FROM target_table diff --git a/testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql b/testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql new file mode 100644 index 00000000..3e575d78 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql @@ -0,0 +1,37 @@ +-- Get Freqs for selected columns +WITH target_table AS ( + SELECT * FROM "{DATA_TABLE}" +-- TG-IF do_sample_bool + ORDER BY RANDOM() LIMIT {SAMPLE_SIZE} +-- TG-ENDIF +), +ranked_vals AS ( + SELECT "{COL_NAME}", + COUNT(*) AS ct, + ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC, "{COL_NAME}") AS rn + FROM target_table + WHERE "{COL_NAME}" > ' ' + GROUP BY "{COL_NAME}" +), +consol_vals AS ( + SELECT COALESCE(CASE WHEN rn <= 10 THEN '| ' || "{COL_NAME}" || ' | ' || CAST(ct AS VARCHAR) + ELSE NULL + END, '| Other Values (' || CAST(COUNT(DISTINCT "{COL_NAME}") as VARCHAR) || ') | ' || 
CAST(SUM(ct) as VARCHAR) ) AS val, + MIN(rn) as min_rn + FROM ranked_vals + GROUP BY CASE WHEN rn <= 10 THEN '| ' || "{COL_NAME}" || ' | ' || CAST(ct AS VARCHAR) + ELSE NULL + END +) +SELECT '{PROJECT_CODE}' as project_code, + '{DATA_SCHEMA}' as schema_name, + '{RUN_DATE}' as run_date, + '{DATA_TABLE}' as table_name, + '{COL_NAME}' as column_name, + REPLACE(ARRAY_JOIN(ARRAY_AGG(val), '^#^'), '^#^', CHR(10)) AS top_freq_values, + ( SELECT MD5(ARRAY_JOIN(ARRAY_AGG(v), '|')) as dvh + FROM (SELECT DISTINCT NULLIF("{COL_NAME}", '') AS v + FROM target_table + WHERE NULLIF("{COL_NAME}", '') IS NOT NULL + ORDER BY v) sorted_vals ) as distinct_value_hash + FROM (SELECT * FROM consol_vals ORDER BY min_rn LIMIT 11) ordered_vals; diff --git a/testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml b/testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml new file mode 100644 index 00000000..7ae06f79 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml @@ -0,0 +1,98 @@ +IS_NUM: CASE + WHEN REGEXP_LIKE({$1}, '^\s*[+-]?\$?\s*[0-9]+(,[0-9]{3})*(\.[0-9]*)?[%]?\s*$') THEN 1 + ELSE 0 + END + +IS_DATE: CASE + /* YYYY-MM-DD HH:MM:SS SSSSSS or YYYY-MM-DD HH:MM:SS */ + WHEN REGEXP_LIKE({$1}, '^(\d{4})-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01])\s(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\s[0-9]{6})?$') + THEN CASE + WHEN CAST(SUBSTR({$1}, 1, 4) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( SUBSTRING ({$1}, 6, 2) IN ('01', '03', '05', '07', '08', + '10', '12') + AND CAST(SUBSTRING ({$1}, 9, 2) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( SUBSTRING ({$1}, 6, 2) IN ('04', '06', '09', '11') + AND CAST(SUBSTRING ({$1}, 9, 2) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( SUBSTRING ({$1}, 6, 2) = '02' + AND CAST(SUBSTRING ({$1}, 9, 2) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + /* YYYYMMDDHHMMSSSSSS or YYYYMMDDHH */ + WHEN REGEXP_LIKE({$1},
'^(\d{4})(0[1-9]|1[0-2])(0[1-9]|[12][0-9]|3[01])(2[0-3]|[01][0-9])([0-5][0-9])([0-5][0-9])([0-9]{6})$') + OR REGEXP_LIKE({$1}, '^(\d{4})(0[1-9]|1[0-2])(0[1-9]|[12][0-9]|3[01])(2[0-3]|[01][0-9])$') + THEN CASE + WHEN CAST(SUBSTR({$1}, 1, 4) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( SUBSTRING({$1}, 5, 2) IN ('01', '03', '05', '07', '08', + '10', '12') + AND CAST(SUBSTRING({$1}, 7, 2) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( SUBSTRING({$1}, 5, 2) IN ('04', '06', '09', '11') + AND CAST(SUBSTRING({$1}, 7, 2) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( SUBSTRING({$1}, 5, 2) = '02' + AND CAST(SUBSTRING({$1}, 7, 2) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + /* Exclude anything else long */ + WHEN LENGTH({$1}) > 11 THEN 0 + /* YYYY-MMM/MM-DD */ + WHEN REGEXP_LIKE(REGEXP_REPLACE(UPPER({$1}), '(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)', '12', 'g'), + '[12][09][0-9][0-9]-[0-1]?[0-9]-[0-3]?[0-9]') + THEN CASE + WHEN CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('01', '03', '05', '07', '08', + '1', '3', '5', '7', '8', '10', '12', + 'JAN', 'MAR', 'MAY', 'JUL', 'AUG', + 'OCT', 'DEC') + AND CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('04', '06', '09', '4', '6', '9', '11', + 'APR', 'JUN', 'SEP', 'NOV') + AND CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('02', '2', 'FEB') + AND CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + /* MM/-DD/-YY/YYYY */ + WHEN REGEXP_LIKE(REPLACE({$1}, '-', '/'), '^[0-1]?[0-9]/[0-3]?[0-9]/[12][09][0-9][0-9]$') + OR REGEXP_LIKE(REPLACE({$1}, '-', '/'), '^[0-1]?[0-9]/[0-3]?[0-9]/[0-9][0-9]$') + THEN + CASE + WHEN CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) BETWEEN 1 AND 12 + AND ( + ( CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) IN (1, 3, 5, 7, 8, 10, 12) + AND CAST(SPLIT_PART(REPLACE({$1},
'-', '/'), '/', 2) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) IN (4, 6, 9, 11) + AND CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 2) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) = 2 + AND CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 2) AS INTEGER) BETWEEN 1 AND 29) + ) + AND + CAST('20' || SUBSTRING(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 3), LENGTH(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 3)) - 1) AS INTEGER) BETWEEN 1800 AND 2200 + THEN 1 + ELSE 0 + END + /* DD-MMM-YYYY */ + WHEN REGEXP_LIKE(UPPER({$1}), '[0-3]?[0-9]-(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)-[12][09][0-9][0-9]') + THEN + CASE + WHEN CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('JAN', 'MAR', 'MAY', 'JUL', 'AUG', 'OCT', 'DEC') + AND CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('APR', 'JUN', 'SEP', 'NOV') + AND CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) = 'FEB' + AND CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + ELSE 0 + END diff --git a/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql index 06f09372..f9552a78 100644 --- a/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), 
-- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/gen_query_tests/gen_Freshness_Trend.sql index cc83e820..a3d5f718 100644 --- a/testgen/template/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/get_entities/get_test_generation_list.sql b/testgen/template/get_entities/get_test_generation_list.sql index b4322693..c14b1fa8 100644 --- a/testgen/template/get_entities/get_test_generation_list.sql +++ b/testgen/template/get_entities/get_test_generation_list.sql @@ -14,6 +14,7 @@ FROM test_definitions td JOIN test_suites ts ON td.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE AND ts.test_suite = :TEST_SUITE + AND ts.is_monitor IS NOT TRUE AND td.last_auto_gen_date IS NOT NULL GROUP BY ts.id, td.last_auto_gen_date, td.profiling_as_of_date, td.lock_refresh ORDER BY td.last_auto_gen_date desc; diff --git a/testgen/template/get_entities/get_test_info.sql b/testgen/template/get_entities/get_test_info.sql index 142ddc63..d07b4c78 100644 --- a/testgen/template/get_entities/get_test_info.sql +++ b/testgen/template/get_entities/get_test_info.sql @@ -39,6 +39,7 @@ INNER JOIN test_types tt ON td.test_type = tt.test_type INNER JOIN test_suites ts ON td.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE AND ts.test_suite = :TEST_SUITE + AND ts.is_monitor IS NOT TRUE ORDER BY td.schema_name, td.table_name, td.column_name, diff --git a/testgen/template/get_entities/get_test_run_list.sql 
b/testgen/template/get_entities/get_test_run_list.sql index 14079499..50f9ecc7 100644 --- a/testgen/template/get_entities/get_test_run_list.sql +++ b/testgen/template/get_entities/get_test_run_list.sql @@ -17,6 +17,7 @@ INNER JOIN test_results r ON tr.id = r.test_run_id INNER JOIN test_suites ts ON tr.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE AND ts.test_suite = :TEST_SUITE + AND ts.is_monitor IS NOT TRUE GROUP BY tr.id, ts.project_code, ts.test_suite, diff --git a/testgen/template/get_entities/get_test_suite.sql b/testgen/template/get_entities/get_test_suite.sql index fdbd9638..8d0a6c22 100644 --- a/testgen/template/get_entities/get_test_suite.sql +++ b/testgen/template/get_entities/get_test_suite.sql @@ -8,4 +8,5 @@ SELECT component_type FROM test_suites WHERE project_code = :PROJECT_CODE -AND test_suite = :TEST_SUITE; +AND test_suite = :TEST_SUITE +AND is_monitor IS NOT TRUE; diff --git a/testgen/template/get_entities/get_test_suite_list.sql b/testgen/template/get_entities/get_test_suite_list.sql index 4ba63e1f..1fe6e363 100644 --- a/testgen/template/get_entities/get_test_suite_list.sql +++ b/testgen/template/get_entities/get_test_suite_list.sql @@ -8,4 +8,6 @@ LEFT JOIN test_runs tr ON tr.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE + AND ts.is_monitor IS NOT TRUE + GROUP BY ts.id, ts.project_code, ts.test_suite, ts.connection_id, ts.test_suite_description ORDER BY ts.test_suite; diff --git a/testgen/template/prediction/get_freshness_fingerprint_events.sql b/testgen/template/prediction/get_freshness_fingerprint_events.sql new file mode 100644 index 00000000..ad892d7e --- /dev/null +++ b/testgen/template/prediction/get_freshness_fingerprint_events.sql @@ -0,0 +1,18 @@ +-- Fingerprint-change events from Freshness_Trend tests, used as secondary data for +-- freshness-gated SARIMAX prediction of Volume_Trend / Metric_Trend. 
+-- +-- Returns one row per detected fingerprint change (result_signal = '0'), ordered by +-- (schema, table, time). +SELECT DISTINCT + d.schema_name, + d.table_name, + r.test_run_id, + r.test_time +FROM test_results r +JOIN test_definitions d ON d.id = r.test_definition_id +WHERE r.test_suite_id = :TEST_SUITE_ID + AND d.test_suite_id = :TEST_SUITE_ID + AND d.test_type = 'Freshness_Trend' + AND d.test_active = 'Y' + AND r.result_signal = '0' +ORDER BY d.schema_name, d.table_name, r.test_time; diff --git a/testgen/template/prediction/get_historical_test_results.sql b/testgen/template/prediction/get_historical_test_results.sql index 800ecc10..cbf91a32 100644 --- a/testgen/template/prediction/get_historical_test_results.sql +++ b/testgen/template/prediction/get_historical_test_results.sql @@ -12,7 +12,10 @@ WITH filtered_defs AS ( AND history_calculation = 'PREDICT' ) SELECT r.test_definition_id, + r.test_run_id, d.test_type, + d.schema_name, + d.table_name, r.test_time, CASE WHEN r.result_signal ~ '^-?[0-9]*\.?[0-9]+$' THEN r.result_signal::NUMERIC diff --git a/testgen/template/quick_start/initial_data_seeding.sql b/testgen/template/quick_start/initial_data_seeding.sql index fb1283ca..0b5720dd 100644 --- a/testgen/template/quick_start/initial_data_seeding.sql +++ b/testgen/template/quick_start/initial_data_seeding.sql @@ -58,15 +58,14 @@ SELECT '823a1fef-9b6d-48d5-9d0f-2db9812cc318'::UUID AS id, 30 AS predict_min_lookback; INSERT INTO job_schedules - (id, project_code, key, args, kwargs, cron_expr, cron_tz, active) + (id, project_code, key, kwargs, cron_expr, cron_tz, active) SELECT 'eac9d722-d06a-4b1f-b8c4-bb2854bd4cfd'::UUID AS id, '{PROJECT_CODE}' AS project_code, 'run-monitors' AS key, - '[]'::JSONB AS args, '{"test_suite_id": "823a1fef-9b6d-48d5-9d0f-2db9812cc318"}'::JSONB AS kwargs, '0 */12 * * *' AS cron_expr, 'UTC' AS cron_tz, - TRUE AS TRUE; + TRUE AS active; UPDATE table_groups SET monitor_test_suite_id = '823a1fef-9b6d-48d5-9d0f-2db9812cc318'::UUID diff 
--git a/testgen/ui/app.py b/testgen/ui/app.py index be61a443..5bbd67eb 100644 --- a/testgen/ui/app.py +++ b/testgen/ui/app.py @@ -1,6 +1,7 @@ import logging import os from urllib.parse import urlparse +from uuid import uuid4 import streamlit as st @@ -8,14 +9,16 @@ from testgen.common import version_service from testgen.common.docker_service import check_basic_configuration from testgen.common.models import get_current_session, with_database_session -from testgen.common.models.project import Project from testgen.common.standalone_postgres import STANDALONE_URI_ENV_VAR, ensure_standalone_setup, is_standalone_mode from testgen.ui import bootstrap from testgen.ui.assets import get_asset_path from testgen.ui.components import widgets as testgen from testgen.ui.services import javascript_service +from testgen.ui.services.query_cache import select_projects_where from testgen.ui.session import session +LOG = logging.getLogger("testgen") + if is_standalone_mode() and (standalone_uri := os.environ.get(STANDALONE_URI_ENV_VAR)): ensure_standalone_setup(standalone_uri) @@ -72,7 +75,7 @@ def render(log_level: int = logging.INFO): with st.sidebar: testgen.sidebar( projects=[] if is_global_context else [ - p for p in Project.select_where() if session.auth.user_has_project_access(p.project_code) + p for p in select_projects_where() if session.auth.user_has_project_access(p.project_code) ], current_project=None if is_global_context else session.sidebar_project, menu=application.menu, @@ -84,6 +87,20 @@ def render(log_level: int = logging.INFO): ) application.router.run() + except Exception: + # Log the full traceback (tagged with a reference the user can quote) so it lands in app.log, + # which the in-app Application Logs dialog reads -- letting users download and share UI errors + # instead of needing container logs. Streamlit's rerun/stop signals are BaseException + # subclasses, so they pass through uncaught. 
+ error_reference = uuid4().hex[:8].upper() + LOG.exception( + "Unhandled error rendering page '%s' [ref=%s]", session.current_page or "unknown", error_reference + ) + try: + _render_error_message(error_reference) + except Exception: + # Never let the error message itself break the run -- fall back to a bare message. + st.error("Something went wrong. Use the menu on the left to navigate to another page.") finally: # Safety net: commit any flushed-but-uncommitted work (e.g., PersistedSetting writes) # before RerunException propagates and bypasses database_session()'s normal commit. @@ -97,6 +114,18 @@ def render(log_level: int = logging.INFO): db_session.rollback() +def _render_error_message(reference: str) -> None: + support_email = settings.SUPPORT_EMAIL + st.error( + "**Something went wrong.**\n\n" + "An unexpected error occurred while loading this page. Use the menu on the left to navigate to " + "another page.\n\n" + "If this keeps happening, download the logs from **Help → Application Logs** and send them to " + f"[{support_email}](mailto:{support_email}) with this reference: **{reference}**.", + icon=":material/error:", + ) + + @st.cache_resource(validate=lambda _: not settings.IS_DEBUG, show_spinner=False) def get_application(log_level: int = logging.INFO): return bootstrap.run(log_level=log_level) diff --git a/testgen/ui/assets/flavors/salesforce_data360.svg b/testgen/ui/assets/flavors/salesforce_data360.svg new file mode 100644 index 00000000..beacb0d9 --- /dev/null +++ b/testgen/ui/assets/flavors/salesforce_data360.svg @@ -0,0 +1,83 @@ + + + + Salesforce.com logo + A cloud computing company based in San Francisco, California, United States + + + + image/svg+xml + + Salesforce.com logo + + + + + + + + + + + + + + + diff --git a/testgen/ui/auth.py b/testgen/ui/auth.py index 2abe32e8..8bb5b788 100644 --- a/testgen/ui/auth.py +++ b/testgen/ui/auth.py @@ -9,7 +9,11 @@ from testgen.common.models.project_membership import RoleType from testgen.common.models.user 
import User from testgen.ui.services.javascript_service import execute_javascript -from testgen.ui.services.query_cache import get_membership_by_user_and_project +from testgen.ui.services.query_cache import ( + get_membership_by_user_and_project, + get_user, + select_users_where, +) from testgen.ui.session import session LOG = logging.getLogger("testgen") @@ -62,7 +66,7 @@ def get_jwt_hashing_key(self) -> bytes: st.stop() def get_credentials(self): - users = User.select_where() + users = select_users_where() usernames = {} for item in users: usernames[item.username.lower()] = { @@ -72,7 +76,7 @@ def get_credentials(self): return {"usernames": usernames} def login_user(self, username: str) -> None: - self.user = User.get(username) + self.user = get_user(username) self.user.save(update_latest_login=True) self.load_user_role() MixpanelService().send_event("login", include_usage=True, role=self.role) @@ -83,7 +87,7 @@ def load_user_session(self) -> None: if token is not None: try: payload = decode_jwt_token(token) - self.user = User.get(payload["username"]) + self.user = get_user(payload["username"]) self.load_user_role() except Exception: LOG.debug("Invalid auth token found on cookies", exc_info=True, stack_info=True) diff --git a/testgen/ui/components/frontend/js/pages/data_catalog.js b/testgen/ui/components/frontend/js/pages/data_catalog.js index 079a0ebe..3ed1c10c 100644 --- a/testgen/ui/components/frontend/js/pages/data_catalog.js +++ b/testgen/ui/components/frontend/js/pages/data_catalog.js @@ -220,7 +220,6 @@ const DataCatalog = (/** @type Properties */ props) => { value: getValue(props.table_group_filter_options)?.find((op) => op.selected)?.value ?? null, options: getValue(props.table_group_filter_options) ?? 
[], style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('TableGroupSelected', {payload: value}), }), div( diff --git a/testgen/ui/components/frontend/js/pages/feedback_widget.js b/testgen/ui/components/frontend/js/pages/feedback_widget.js new file mode 100644 index 00000000..9f804b7b --- /dev/null +++ b/testgen/ui/components/frontend/js/pages/feedback_widget.js @@ -0,0 +1,225 @@ +import van from '/app/static/js/van.min.js'; +import { createEmitter, isEqual, loadStylesheet } from '/app/static/js/utils.js'; +import { Button } from '/app/static/js/components/button.js'; +import { Icon } from '/app/static/js/components/icon.js'; +import { Input } from '/app/static/js/components/input.js'; +import { Textarea } from '/app/static/js/components/textarea.js'; +const { div, span } = van.tags; + +const RATINGS = [ + { value: 1, emoji: '\u{1F620}', label: 'Frustrated' }, // 😠 + { value: 2, emoji: '\u{1F615}', label: 'Dissatisfied' }, // 😕 + { value: 3, emoji: '\u{1F610}', label: 'Neutral' }, // 😐 + { value: 4, emoji: '\u{1F642}', label: 'Satisfied' }, // 🙂 + { value: 5, emoji: '\u{1F929}', label: 'Love it!' 
}, // 🤩 +]; + +const FeedbackWidget = (props) => { + loadStylesheet('feedback-widget', stylesheet); + + const selectedRating = van.state(0); + const comment = van.state(''); + const email = van.state(''); + const expanded = van.state(false); + const showSuccess = van.state(false); + const submitting = van.state(false); + + const handleClose = () => { + props.emit('FeedbackDismissed', {}); + }; + + const handleSubmit = () => { + if (selectedRating.val === 0 || submitting.val) return; + submitting.val = true; + props.emit('FeedbackSubmitted', { + payload: { + rating: selectedRating.val, + comment: comment.val, + email: email.val, + }, + }); + showSuccess.val = true; + setTimeout(() => { + submitting.val = false; + props.emit('FeedbackDismissed', {}); + }, 2000); + }; + + return div( + { class: 'feedback-widget' }, + + () => !showSuccess.val + ? div( + { class: 'flex-column' }, + div( + { class: 'flex-row fx-justify-space-between p-4 pb-0' }, + div( + { class: 'flex-column fx-gap-1' }, + div({ class: 'text-bold' }, "How's your experience?"), + div({ class: 'text-caption' }, 'Your feedback helps us improve TestGen'), + ), + Button({ type: 'icon', color: 'basic', icon: 'close', onclick: handleClose }), + ), + div( + { class: 'flex-row fx-justify-space-between p-4' }, + ...RATINGS.map(rating => + div( + { + class: () => `rating-option ${selectedRating.val === rating.value ? 'selected' : ''}`, + onclick: () => { selectedRating.val = rating.value; }, + }, + span({ class: 'rating-emoji' }, rating.emoji), + span({ class: 'text-caption' }, rating.label), + ) + ), + ), + div( + { class: 'p-4 pt-0 flex-column fx-gap-3' }, + div( + { class: 'expander-row flex-row fx-justify-space-between clickable', onclick: () => { expanded.val = !expanded.val; } }, + span({ class: 'text-caption' }, 'Add a comment (optional)'), + Icon({ size: 18, classes: 'text-secondary' }, () => expanded.val ? 
'keyboard_arrow_up' : 'keyboard_arrow_down'), + ), + div( + { class: 'flex-column fx-gap-3', style: () => expanded.val ? '' : 'display:none' }, + Textarea({ + label: 'Comment', + placeholder: "What's on your mind?", + value: comment, + onChange: (v) => { comment.val = v; }, + height: 64, + }), + Input({ + label: 'Email (optional)', + placeholder: 'you@company.com', + type: 'email', + value: email, + onChange: (v) => { email.val = v; }, + }), + ), + div( + { class: 'flex-row fx-justify-flex-end' }, + Button({ + type: 'flat', + color: 'primary', + label: 'Submit', + icon: 'send', + width: 'auto', + disabled: () => selectedRating.val === 0 || submitting.val, + onclick: handleSubmit, + }), + ), + ), + ) + : div( + { class: 'flex-column fx-align-flex-center p-5 feedback-success' }, + Icon({ size: 48, classes: 'text-green mb-3' }, 'check_circle'), + div({ class: 'text-bold mb-1' }, 'Thanks for your feedback!'), + div({ class: 'text-caption' }, 'We appreciate you taking the time.'), + ), + ); +}; + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.feedback-widget { + position: fixed; + bottom: 24px; + right: 24px; + width: 340px; + font-family: 'Roboto', 'Helvetica Neue', sans-serif; + font-size: 14px; + color: var(--primary-text-color); + background: var(--portal-background); + border: 1px solid var(--border-color); + border-radius: 12px; + box-shadow: var(--portal-box-shadow); + overflow: hidden; + transition: opacity .25s, transform .25s; + transform-origin: bottom right; + z-index: 9999; +} + +.feedback-widget.hidden { + opacity: 0; + transform: scale(.95) translateY(8px); + pointer-events: none; +} + +.rating-option { + flex: 1; + display: flex; + flex-direction: column; + align-items: center; + gap: 4px; + padding: 8px 4px; + border-radius: 8px; + cursor: pointer; + transition: .2s; + border: 2px solid transparent; +} + +.rating-option:hover { + background: var(--select-hover-background); +} + +.rating-option.selected { + background: 
var(--select-hover-background); + border-color: var(--primary-color); +} + +.rating-emoji { + font-size: 28px; + line-height: 1; + filter: saturate(.8); + transition: .15s; +} + +.rating-option:hover .rating-emoji, +.rating-option.selected .rating-emoji { + transform: scale(1.15); + filter: saturate(1); +} + +.rating-option.selected .text-caption { + color: var(--primary-color); + font-weight: 500; +} + +.expander-row { + padding: 4px; + border-radius: 6px; +} + +.expander-row:hover { + background: var(--select-hover-background); +} + +.feedback-success { + text-align: center; + min-height: 160px; +} +`); + +export default (component) => { + const { data, setTriggerValue, parentElement } = component; + + let componentState = parentElement.state; + if (componentState === undefined) { + componentState = {}; + for (const [key, value] of Object.entries(data)) { + componentState[key] = van.state(value); + } + parentElement.state = componentState; + componentState.emit = createEmitter(setTriggerValue); + van.add(parentElement, FeedbackWidget(componentState)); + } else { + for (const [key, value] of Object.entries(data)) { + if (!isEqual(componentState[key].val, value)) { + componentState[key].val = value; + } + } + } + + return () => { parentElement.state = null; }; +}; diff --git a/testgen/ui/components/frontend/js/pages/hygiene_issues.js b/testgen/ui/components/frontend/js/pages/hygiene_issues.js index 4ee6810b..2682f265 100644 --- a/testgen/ui/components/frontend/js/pages/hygiene_issues.js +++ b/testgen/ui/components/frontend/js/pages/hygiene_issues.js @@ -458,7 +458,7 @@ const HygieneIssues = (/** @type Properties */ props) => { // Table header bar (actions above the table) const tableHeader = div( - { class: 'flex-row fx-align-center fx-gap-2 p-2' }, + { 'data-testid': 'table-header', class: 'flex-row fx-align-center fx-gap-2 p-2' }, Toggle({ label: () => { return div( @@ -481,7 +481,7 @@ const HygieneIssues = (/** @type Properties */ props) => { if 
(!permissions.val.can_disposition) return ''; const disabled = allSelectedArePassed.val; return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'disposition-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'check_circle', tooltip: 'Confirm selected as relevant', disabled, onclick: () => onDisposition('Confirmed') }), Button({ type: 'icon', icon: 'cancel', tooltip: 'Dismiss selected as not relevant', disabled, onclick: () => onDisposition('Dismissed') }), Button({ type: 'icon', icon: 'notifications_off', tooltip: 'Mute selected for future runs', disabled, onclick: () => onDisposition('Inactive') }), @@ -563,14 +563,12 @@ const HygieneIssues = (/** @type Properties */ props) => { profilingColumn: van.derive(() => getValue(props.profiling_column) ?? null), onClose: () => emit('ProfilingClosed', {}), width: '50rem', - testId: 'profiling-dialog', }), SourceDataDialog({ emit, sourceData: van.derive(() => getValue(props.source_data) ?? null), onClose: () => emit('SourceDataClosed', {}), renderHeader: HygieneSourceDataHeader, width: '60rem', - testId: 'source-data-dialog', }), // Summary row @@ -578,14 +576,14 @@ const HygieneIssues = (/** @type Properties */ props) => { { class: 'flex-row fx-gap-5 fx-align-flex-end mb-3 fx-flex-wrap' }, () => othersSummary.val.length ? div( - { class: 'flex-column fx-gap-1' }, + { 'data-testid': 'hygiene-issues-summary', class: 'flex-column fx-gap-1' }, div({ class: 'text-caption' }, 'Hygiene Issues'), SummaryCounts({ items: othersSummary.val }), ) : '', () => piiSummary.val.length ? 
div( - { class: 'flex-column fx-gap-1' }, + { 'data-testid': 'hygiene-pii-summary', class: 'flex-column fx-gap-1' }, div({ class: 'text-caption' }, 'Potential PII (Risk)'), SummaryCounts({ items: piiSummary.val }), ) @@ -596,7 +594,7 @@ const HygieneIssues = (/** @type Properties */ props) => { div( { class: 'flex-column' }, div({ class: 'text-caption'}, 'Score'), - div({ style: 'font-size: 28px' }, score), + div({ 'data-testid': 'hygiene-score', style: 'font-size: 28px' }, score), ), Button({ type: 'icon', @@ -616,7 +614,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Likelihood', value: likelihoodFilter.val, options: LIKELIHOOD_OPTIONS, - testId: 'likelihood-filter', style: 'min-width: 160px', onChange: onLikelihoodChange, allowNull: true, @@ -625,7 +622,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Table', value: tableFilter.val, options: tableOptions.val, - testId: 'table-filter', style: 'min-width: 160px', filterable: true, onChange: onTableChange, @@ -635,7 +631,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Column', value: columnFilter.val, options: columnOptions.val, - testId: 'column-filter', style: 'min-width: 160px', filterable: true, acceptNewOptions: true, @@ -646,7 +641,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Issue Type', value: issueTypeFilter.val, options: issueTypeOptions.val, - testId: 'issue-type-filter', style: 'min-width: 200px', filterable: true, onChange: onIssueTypeChange, @@ -657,7 +651,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Action', value: actionFilter.val, options: ACTION_OPTIONS, - testId: 'action-filter', style: 'min-width: 160px', onChange: onActionChange, allowNull: true, @@ -675,7 +668,7 @@ const HygieneIssues = (/** @type Properties */ props) => { if (!sel) return ''; return div( - { class: 'tg-hi--detail flex-column fx-gap-4' }, + { 'data-testid': 'hygiene-issue-detail', class: 'tg-hi--detail 
flex-column fx-gap-4' }, div( { class: 'flex-row fx-gap-2 fx-justify-content-flex-end' }, sel.table_name !== '(multi-table)' diff --git a/testgen/ui/components/frontend/js/pages/monitors_dashboard.js b/testgen/ui/components/frontend/js/pages/monitors_dashboard.js index 5809d7ec..f3eb9612 100644 --- a/testgen/ui/components/frontend/js/pages/monitors_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/monitors_dashboard.js @@ -308,7 +308,6 @@ const MonitorsDashboard = (/** @type Properties */ props) => { })), allowNull: false, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('SetParamValues', {payload: {table_group_id: value, table_name: null}}), }), () => getValue(props.has_monitor_test_suite) @@ -371,7 +370,6 @@ const MonitorsDashboard = (/** @type Properties */ props) => { width: 230, style: 'font-size: 14px;', icon: 'search', - testId: 'search-tables', value: tableNameFilterValue, onChange: (value, state) => emit('SetParamValues', {payload: {table_name_filter: value, current_page: 0}}), }), diff --git a/testgen/ui/components/frontend/js/pages/notification_settings.js b/testgen/ui/components/frontend/js/pages/notification_settings.js index 9b879fb7..27d6b598 100644 --- a/testgen/ui/components/frontend/js/pages/notification_settings.js +++ b/testgen/ui/components/frontend/js/pages/notification_settings.js @@ -248,7 +248,6 @@ const NotificationSettings = (/** @type Properties */ props) => { title: newNotificationItemForm.isEdit.val ? 
span({ class: 'notifications--editing' }, 'Edit Notification') : span({ class: 'text-green' }, 'Add Notification'), - testId: 'notification-item-editor', expanded: newNotificationItemForm.isEdit.val, }, div( diff --git a/testgen/ui/components/frontend/js/pages/profiling_results.js b/testgen/ui/components/frontend/js/pages/profiling_results.js index 10adb4eb..e7c4263d 100644 --- a/testgen/ui/components/frontend/js/pages/profiling_results.js +++ b/testgen/ui/components/frontend/js/pages/profiling_results.js @@ -253,7 +253,6 @@ const ProfilingResults = (/** @type Properties */ props) => { label: 'Table', value: tableFilter.val, options: tableOptions.val, - testId: 'table-filter', style: 'min-width: 200px', filterable: true, acceptNewOptions: true, @@ -264,7 +263,6 @@ const ProfilingResults = (/** @type Properties */ props) => { label: 'Column', value: columnFilter.val, options: columnOptions.val, - testId: 'column-filter', style: 'min-width: 200px', filterable: true, acceptNewOptions: true, diff --git a/testgen/ui/components/frontend/js/pages/profiling_runs.js b/testgen/ui/components/frontend/js/pages/profiling_runs.js index 74bfa329..035976f3 100644 --- a/testgen/ui/components/frontend/js/pages/profiling_runs.js +++ b/testgen/ui/components/frontend/js/pages/profiling_runs.js @@ -199,7 +199,6 @@ const ProfilingRuns = (/** @type Properties */ props) => { checked: allSelected, indeterminate: partiallySelected, onChange: (checked) => items.forEach(item => selectedRuns[item.job_execution_id].val = checked), - testId: 'select-all-profiling-run', }) : '', ); @@ -359,7 +358,6 @@ const Toolbar = ( options: getValue(props.table_group_options) ?? 
[], allowNull: true, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('FilterApplied', { payload: { table_group_id: value } }), }), div( @@ -401,7 +399,6 @@ const Toolbar = ( tooltipPosition: 'left', style: 'background: var(--button-generic-background-color);', onclick: () => emit('RefreshData', {}), - testId: 'profiling-runs-refresh', }), ), ); @@ -426,7 +423,6 @@ const ProfilingRunItem = ( Checkbox({ checked: selected, onChange: (checked) => selected.val = checked, - testId: 'select-profiling-run', }), ) : '', diff --git a/testgen/ui/components/frontend/js/pages/project_dashboard.js b/testgen/ui/components/frontend/js/pages/project_dashboard.js index ea8bd237..4abc86f7 100644 --- a/testgen/ui/components/frontend/js/pages/project_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/project_dashboard.js @@ -100,7 +100,6 @@ const ProjectDashboard = (/** @type Properties */ props) => { icon: 'search', clearable: true, placeholder: 'Search table group names', - testId: 'table-groups-filter', onChange: (value) => tableGroupsSearchTerm.val = value, }), Select({ @@ -108,7 +107,6 @@ const ProjectDashboard = (/** @type Properties */ props) => { value: tableGroupsSortOption, options: props.table_groups_sort_options?.val ?? 
[], style: 'font-size: 14px;', - testId: 'table-groups-sort', }), ) : '', diff --git a/testgen/ui/components/frontend/js/pages/quality_dashboard.js b/testgen/ui/components/frontend/js/pages/quality_dashboard.js index a592678b..a2909520 100644 --- a/testgen/ui/components/frontend/js/pages/quality_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/quality_dashboard.js @@ -111,7 +111,6 @@ const Toolbar = ( placeholder: 'Search scorecards', value: filterBy, onChange: options?.onsearch, - testId: 'scorecards-filter', }), Select({ id: 'score-dashboard-sort', @@ -120,7 +119,6 @@ const Toolbar = ( value: sortedBy, options: sortOptions, onChange: options?.onsort, - testId: 'scorecards-sort', }), span({ style: 'margin: 0 auto;' }), Button({ @@ -132,7 +130,6 @@ const Toolbar = ( onclick: () => emit('LinkClicked', { href: 'quality-dashboard:explorer', params: { project_code: projectSummary.project_code }, - testId: 'scorecards-goto-explorer', }), }), Button({ @@ -142,7 +139,6 @@ const Toolbar = ( tooltipPosition: 'left', style: 'background: var(--button-generic-background-color);', onclick: () => emit('RefreshData', {}), - testId: 'scorecards-refresh', }), ); }; diff --git a/testgen/ui/components/frontend/js/pages/schedule_list.js b/testgen/ui/components/frontend/js/pages/schedule_list.js index 1ced8e54..31559eb5 100644 --- a/testgen/ui/components/frontend/js/pages/schedule_list.js +++ b/testgen/ui/components/frontend/js/pages/schedule_list.js @@ -75,7 +75,7 @@ const ScheduleList = (/** @type Properties */ props) => { const content = div( { id: domId, class: 'flex-column fx-gap-2', style: 'height: 100%; overflow-y: auto;' }, ExpansionPanel( - {title: span({ class: 'text-green' }, 'Add Schedule'), testId: 'scheduler-cron-editor'}, + {title: span({ class: 'text-green' }, 'Add Schedule')}, div( { class: 'flex-row fx-gap-2' }, () => Select({ diff --git a/testgen/ui/components/frontend/js/pages/score_details.js b/testgen/ui/components/frontend/js/pages/score_details.js 
index c8cf3ea0..94b8bc2f 100644 --- a/testgen/ui/components/frontend/js/pages/score_details.js +++ b/testgen/ui/components/frontend/js/pages/score_details.js @@ -70,7 +70,7 @@ const ScoreDetails = (/** @type {Properties} */ props) => { () => { const score = getValue(props.score); return getValue(props.permissions)?.can_edit ?? false ? div( - { class: 'flex-row tg-test-suites--card-actions' }, + { class: 'flex-row tg-score-details--card-actions' }, Button({ type: 'icon', icon: 'notifications', tooltip: 'Configure Notifications', onclick: () => emit('EditNotifications', {}) }), Button({ type: 'icon', icon: 'edit', tooltip: 'Edit Scorecard', onclick: () => emit('LinkClicked', { href: 'quality-dashboard:explorer', params: { definition_id: score.id, project_code: score.project_code } }) }), Button({ type: 'icon', icon: 'delete', tooltip: 'Delete Scorecard', onclick: () => { deleteDialogOpen.val = true; } }), @@ -171,6 +171,10 @@ stylesheet.replace(` .tg-score-details { min-height: 900px; } + +.tg-score-details--card-actions { + margin-top: -10px; +} `); export { ScoreDetails }; diff --git a/testgen/ui/components/frontend/js/pages/score_explorer.js b/testgen/ui/components/frontend/js/pages/score_explorer.js index f1744d34..addd28b3 100644 --- a/testgen/ui/components/frontend/js/pages/score_explorer.js +++ b/testgen/ui/components/frontend/js/pages/score_explorer.js @@ -396,13 +396,11 @@ const Toolbar = ( Checkbox({ label: 'Total Score', checked: displayTotalScore, - testId: 'include-total-score', onChange: (checked) => displayTotalScore.val = checked, }), Checkbox({ label: 'CDE Score', checked: displayCDEScore, - testId: 'include-cde-score', onChange: (checked) => displayCDEScore.val = checked, }), div( @@ -410,7 +408,6 @@ const Toolbar = ( Checkbox({ label: 'Category:', checked: displayCategory, - testId: 'include-category', onChange: (checked) => displayCategory.val = checked, }), Select({ @@ -419,7 +416,6 @@ const Toolbar = ( value: selectedCategory, options: 
categories.map((c) => ({ value: c, label: TRANSLATIONS[c] })), disabled: van.derive(() => !getValue(displayCategory)), - testId: 'category-selector', }), ), ), @@ -430,7 +426,6 @@ const Toolbar = ( label: 'Scorecard Name', height: 40, value: scoreName, - testId: 'scorecard-name-input', onChange: debounce((name) => scoreName.val = name, 300), }), () => { diff --git a/testgen/ui/components/frontend/js/pages/table_group_list.js b/testgen/ui/components/frontend/js/pages/table_group_list.js index e5c126ba..3f52a24a 100644 --- a/testgen/ui/components/frontend/js/pages/table_group_list.js +++ b/testgen/ui/components/frontend/js/pages/table_group_list.js @@ -84,7 +84,7 @@ const TableGroupList = (props) => { if (key !== wizardKey) { wizardContainer.innerHTML = ''; wizardKey = key; - van.add(wizardContainer, TableGroupWizard({ emit, + van.add(wizardContainer, TableGroupWizard({ emit, project_code: van.derive(() => getValue(props.wizard)?.project_code), connections: van.derive(() => getValue(props.wizard)?.connections), table_group: van.derive(() => getValue(props.wizard)?.table_group), @@ -115,7 +115,7 @@ const TableGroupList = (props) => { if (key !== editDialogKey) { editDialogContainer.innerHTML = ''; editDialogKey = key; - van.add(editDialogContainer, TableGroupEditDialog({ emit, + van.add(editDialogContainer, TableGroupEditDialog({ emit, dialog: van.derive(() => getValue(props.edit_dialog)?.dialog), connections: van.derive(() => getValue(props.edit_dialog)?.connections), table_group: van.derive(() => getValue(props.edit_dialog)?.table_group), @@ -208,7 +208,6 @@ const TableGroupList = (props) => { ? 
div( { class: 'flex-column fx-gap-4' }, ...tableGroups.map((tableGroup) => Card({ - testId: 'table-group-card', class: '', title: div( { class: 'flex-column fx-gap-2 tg-tablegroup--card-title', 'data-testid': 'tablegroup-card-title' }, @@ -226,7 +225,7 @@ const TableGroupList = (props) => { { class: 'flex-row fx-gap-3' }, div( { class: 'flex-column fx-flex fx-gap-3' }, - Link({ emit, + Link({ emit, label: 'View test suites', href: 'test-suites', params: { 'project_code': projectSummary.project_code, 'table_group_id': tableGroup.id }, @@ -239,7 +238,7 @@ const TableGroupList = (props) => { { class: 'flex-column fx-flex fx-gap-4' }, div( { class: 'flex-column fx-flex' }, - Caption({content: 'DB Schema', style: 'margin-bottom: 4px;'}), + Caption({content: tableGroup.connection.flavor.flavor === 'salesforce_data360' ? 'Data Space' : 'Schema', style: 'margin-bottom: 4px;'}), span(tableGroup.table_group_schema || '--'), ), div( @@ -448,7 +447,6 @@ const Toolbar = (permissions, connections, selectedConnection, tableGroupNameFil {class: 'flex-row fx-align-flex-end fx-gap-3'}, () => (getValue(connections) ?? [])?.length > 1 ? Select({ - testId: 'connection-select', label: 'Connection', allowNull: true, value: connection, @@ -460,7 +458,6 @@ const Toolbar = (permissions, connections, selectedConnection, tableGroupNameFil }) : '', Input({ - testId: 'table-groups-name-filter', icon: 'search', label: '', placeholder: 'Search table group names', diff --git a/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js b/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js index 8aa0891f..8e0c86cd 100644 --- a/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js +++ b/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js @@ -455,6 +455,10 @@ const ChartsSection = (props, { schemaChartSelection, getDataStructureLogs }) => originalUpperTolerance: e.upper_tolerance != undefined ? 
parseInt(e.upper_tolerance) : undefined, + // Freshness-gated baseline (only present on gated runs). + originalThreshold: e.threshold_value != undefined + ? parseFloat(e.threshold_value) + : undefined, label: 'Row count', isAnomaly: e.is_anomaly, isTraining: e.is_training, @@ -490,6 +494,7 @@ const ChartsSection = (props, { schemaChartSelection, getDataStructureLogs }) => originalY: e.value, originalLowerTolerance: e.lower_tolerance, originalUpperTolerance: e.upper_tolerance, + originalThreshold: e.threshold_value, isAnomaly: e.is_anomaly, isTraining: e.is_training, isPending: e.is_pending, diff --git a/testgen/ui/components/frontend/js/pages/test_definitions.js b/testgen/ui/components/frontend/js/pages/test_definitions.js index aeff2ab9..47bf21d2 100644 --- a/testgen/ui/components/frontend/js/pages/test_definitions.js +++ b/testgen/ui/components/frontend/js/pages/test_definitions.js @@ -75,6 +75,7 @@ const BLANK_PARAM_FIELDS = { const ClearFlagButton = ({ disabled, onclick }) => { return withTooltip(btn( { + 'data-testid': 'button', class: 'tg-button tg-icon-button tg-basic-button', disabled, onclick, @@ -394,7 +395,7 @@ const TestDefinitions = (/** @type object */ props) => { // Table header bar: multi-select toggle + edit buttons | dashed separator | disposition buttons + export const tableHeader = div( - { class: 'flex-row fx-align-center fx-gap-2 p-2 fx-flex-wrap' }, + { 'data-testid': 'table-header', class: 'flex-row fx-align-center fx-gap-2 p-2 fx-flex-wrap' }, () => canDisposition.val ? Toggle({ label: () => { @@ -429,7 +430,7 @@ const TestDefinitions = (/** @type object */ props) => { test_type: r.test_type, lock_refresh: r.lock_refresh, })); return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'edit-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'file_copy', tooltip: 'Copy/Move', disabled: !hasSelection, onclick: () => emit('CopyMoveDialogOpened', { payload: isAll ? 
'all' : minimalSelected() }) }), Button({ type: 'icon', icon: 'delete', tooltip: 'Delete', disabled: !hasSelection, @@ -462,7 +463,7 @@ const TestDefinitions = (/** @type object */ props) => { } }; return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'disposition-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'check_circle', tooltip: 'Activate selected', disabled: noSelection || allActive, onclick: () => emitAttribute('test_active', true) }), Button({ type: 'icon', icon: 'notifications_off', tooltip: 'Deactivate selected', disabled: noSelection || allInactive, onclick: () => emitAttribute('test_active', false) }), div({ class: 'td-header-separator' }), @@ -758,7 +759,7 @@ const TestDefinitions = (/** @type object */ props) => { const row = singleSelected.val; if (!row) return ''; return div( - { class: 'tg-td--detail flex-column fx-gap-4' }, + { 'data-testid': 'test-definition-detail', class: 'tg-td--detail flex-column fx-gap-4' }, div( { class: 'flex-row fx-gap-2 fx-justify-content-flex-end' }, canEdit.val ? Button({ @@ -853,6 +854,7 @@ const AddDialogComponent = ({ open, info, validateResult: validateResultProp, on const tableGroupsId = van.derive(() => getValue(info)?.table_groups_id ?? ''); const testSuite = van.derive(() => getValue(info)?.test_suite ?? {}); const tableColumns = van.derive(() => getValue(info)?.table_columns ?? []); + const qualifiesTableRefsWithSchema = van.derive(() => getValue(info)?.qualifies_table_refs_with_schema ?? true); const validateResult = van.derive(() => getValue(validateResultProp) ?? 
null); const scopeFilter = { @@ -959,6 +961,7 @@ const AddDialogComponent = ({ open, info, validateResult: validateResultProp, on formValues: fv, tableColumns: tableColumns.rawVal, testSuite: testSuite.rawVal, + qualifiesTableRefsWithSchema: qualifiesTableRefsWithSchema.rawVal, validateResult: vr, mode: 'add', onFormChange: (changes) => { @@ -978,6 +981,7 @@ const EditDialogComponent = ({ open, info, validateResult: validateResultProp, o const dialogInfo = van.derive(() => getValue(info) ?? null); const tableColumns = van.derive(() => dialogInfo.val?.table_columns ?? []); const testSuite = van.derive(() => dialogInfo.val?.test_suite ?? {}); + const qualifiesTableRefsWithSchema = van.derive(() => dialogInfo.val?.qualifies_table_refs_with_schema ?? true); const validateResult = van.derive(() => getValue(validateResultProp) ?? null); const formValues = van.state(null); @@ -1021,6 +1025,7 @@ const EditDialogComponent = ({ open, info, validateResult: validateResultProp, o formValues: fv, tableColumns: tableColumns.rawVal, testSuite: testSuite.rawVal, + qualifiesTableRefsWithSchema: qualifiesTableRefsWithSchema.rawVal, validateResult: vr, mode: 'edit', onFormChange: (changes) => { @@ -1036,7 +1041,7 @@ const EditDialogComponent = ({ open, info, validateResult: validateResultProp, o }; // Shared form content for add/edit dialogs -const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResult, mode, onFormChange, onValidate, onSave, onCancel }) => { +const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResult, mode, qualifiesTableRefsWithSchema, onFormChange, onValidate, onSave, onCancel }) => { const testScope = formValues.test_scope ?? 'column'; const runType = formValues.run_type ?? 'CAT'; const testType = formValues.test_type ?? 
''; @@ -1070,7 +1075,7 @@ const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResul { label: 'Regularity', value: 'Regularity' }, { label: 'Usability', value: 'Usability' }, ]; - const showImpactDimensionOverride = testType === 'CUSTOM' || testType === 'Condition_Flag' || testScope === 'referential'; + const showImpactDimensionOverride = ['custom', 'referential'].includes(testScope); const tableNameOptions = [ ...new Set((tableColumns ?? []).map(c => c.table_name).filter(Boolean)) @@ -1176,12 +1181,14 @@ const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResul ), // Schema (read-only) - Input({ - name: 'schema_name', - label: 'Schema', - value: formValues.schema_name ?? '', - disabled: true, - }), + qualifiesTableRefsWithSchema + ? Input({ + name: 'schema_name', + label: 'Schema', + value: formValues.schema_name ?? '', + disabled: true, + }) + : null, // Table name testScope !== 'tablegroup' @@ -1241,6 +1248,7 @@ const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResul { class: 'td-form-params-section' }, TestDefinitionForm({ definition: formValues, + qualifiesTableRefsWithSchema, onChange: (changes) => { if (Object.keys(changes).length === 0) return; const updated = { ...fv.rawVal, ...changes }; diff --git a/testgen/ui/components/frontend/js/pages/test_results.js b/testgen/ui/components/frontend/js/pages/test_results.js index d2a578fd..8fdc3bab 100644 --- a/testgen/ui/components/frontend/js/pages/test_results.js +++ b/testgen/ui/components/frontend/js/pages/test_results.js @@ -74,6 +74,7 @@ const STATUS_COLORS = { const ClearFlagButton = ({ disabled, onclick }) => { return withTooltip(btn( { + 'data-testid': 'button', class: 'tg-button tg-icon-button tg-basic-button', tooltip: 'Clear flag', disabled, @@ -540,7 +541,7 @@ const TestResults = (/** @type Properties */ props) => { // Table header bar const tableHeader = div( - { class: 'flex-row fx-align-center fx-gap-2 p-2' }, + { 'data-testid': 
'table-header', class: 'flex-row fx-align-center fx-gap-2 p-2' }, Toggle({ label: () => { return div( @@ -569,7 +570,7 @@ const TestResults = (/** @type Properties */ props) => { ? !isAll && count === 0 : (() => { const row = selectedRow.val; return !row || row.result_status === 'Passed'; })(); return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'disposition-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'check_circle', tooltip: 'Confirm selected as relevant', disabled, onclick: () => onDisposition('Confirmed') }), Button({ type: 'icon', icon: 'cancel', tooltip: 'Dismiss selected as not relevant', disabled, onclick: () => onDisposition('Dismissed') }), Button({ type: 'icon', icon: 'notifications_off', tooltip: 'Mute selected tests for future runs', disabled, onclick: () => onDisposition('Inactive') }), @@ -599,7 +600,7 @@ const TestResults = (/** @type Properties */ props) => { }; return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'flag-actions', class: 'flex-row fx-gap-1' }, span({ style: 'width: 0px; height: 24px; border-right: 1px dashed var(--border-color);'}, ''), Button({ type: 'icon', icon: 'flag', tooltip: 'Flag selected', disabled: noSelection, @@ -757,7 +758,7 @@ const TestResults = (/** @type Properties */ props) => { div( { class: 'tg-tr--score flex-column fx-align-center' }, small({ class: 'text-caption' }, 'Score'), - span({ class: 'tg-tr--score-value' }, () => getValue(props.score) ?? '--'), + span({ 'data-testid': 'test-run-score', class: 'tg-tr--score-value' }, () => getValue(props.score) ?? 
'--'), ), Button({ type: 'icon', @@ -776,7 +777,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Status', value: statusFilter.val, options: STATUS_FILTER_OPTIONS, - testId: 'status-filter', style: 'min-width: 160px', onChange: onStatusFilterChange, allowNull: true, @@ -785,7 +785,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Table', value: tableFilter.val, options: tableOptions.val, - testId: 'table-filter', style: 'min-width: 180px', filterable: true, onChange: onTableFilterChange, @@ -795,7 +794,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Column', value: columnFilter.val, options: columnOptions.val, - testId: 'column-filter', style: 'min-width: 180px', filterable: true, acceptNewOptions: true, @@ -806,7 +804,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Test Type', value: testTypeFilter.val, options: testTypeOptions.val, - testId: 'test-type-filter', style: 'min-width: 160px', filterable: true, onChange: onTestTypeFilterChange, @@ -816,7 +813,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Action', value: actionFilter.val, options: ACTION_FILTER_OPTIONS, - testId: 'action-filter', style: 'min-width: 140px', onChange: onActionFilterChange, allowNull: true, @@ -825,7 +821,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Flagged', value: flaggedFilter.val, options: FLAGGED_FILTER_OPTIONS, - testId: 'flagged-filter', style: 'min-width: 140px', onChange: onFlaggedFilterChange, allowNull: true, @@ -846,7 +841,7 @@ const TestResults = (/** @type Properties */ props) => { const hasData = si && si.test_result_id === row.test_result_id; return div( - { class: 'tg-tr--detail flex-column fx-gap-4' }, + { 'data-testid': 'test-result-detail', class: 'tg-tr--detail flex-column fx-gap-4' }, // Action buttons row div( @@ -885,7 +880,7 @@ const TestResults = (/** @type Properties */ props) => { { class: 'flex-column fx-flex', style: 'min-width: 
0' }, h3({ class: 'tg-tr--detail-title' }, row.test_name_short), row.test_description - ? p({ class: 'tg-tr--detail-desc' }, row.test_description) + ? p({ 'data-testid': 'test-result-description', class: 'tg-tr--detail-desc' }, row.test_description) : '', row.measure_uom_description ? small({ class: 'text-caption' }, row.measure_uom_description) @@ -901,7 +896,7 @@ const TestResults = (/** @type Properties */ props) => { { class: 'flex-column fx-flex', style: 'min-width: 0' }, hasData ? Tabs( - { testId: 'test-result-detail' }, + {}, Tab( { label: 'History' }, si.history?.length diff --git a/testgen/ui/components/frontend/js/pages/test_runs.js b/testgen/ui/components/frontend/js/pages/test_runs.js index 4a5b06a3..800c4b8c 100644 --- a/testgen/ui/components/frontend/js/pages/test_runs.js +++ b/testgen/ui/components/frontend/js/pages/test_runs.js @@ -199,7 +199,6 @@ const TestRuns = (/** @type Properties */ props) => { checked: allSelected, indeterminate: partiallySelected, onChange: (checked) => items.forEach(item => selectedRuns[item.job_execution_id].val = checked), - testId: 'select-all-test-run', }) : '', ); @@ -355,7 +354,6 @@ const Toolbar = ( options: getValue(props.table_group_options) ?? [], allowNull: true, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('FilterApplied', { payload: { table_group_id: value } }), }), () => Select({ @@ -364,7 +362,6 @@ const Toolbar = ( options: getValue(props.test_suite_options) ?? 
[], allowNull: true, style: 'font-size: 14px;', - testId: 'test-suite-filter', onChange: (value) => emit('FilterApplied', { payload: { test_suite_id: value } }), }), ), @@ -407,7 +404,6 @@ const Toolbar = ( tooltipPosition: 'left', style: 'background: var(--button-generic-background-color);', onclick: () => emit('RefreshData', {}), - testId: 'test-runs-refresh', }), ), ); @@ -433,7 +429,6 @@ const TestRunItem = ( Checkbox({ checked: selected, onChange: (checked) => selected.val = checked, - testId: 'select-test-run', }), ) : '', diff --git a/testgen/ui/components/frontend/js/pages/test_suites.js b/testgen/ui/components/frontend/js/pages/test_suites.js index f9899747..c73f98f3 100644 --- a/testgen/ui/components/frontend/js/pages/test_suites.js +++ b/testgen/ui/components/frontend/js/pages/test_suites.js @@ -135,14 +135,12 @@ const TestSuites = (/** @type Properties */ props) => { options: getValue(props.table_group_filter_options) ?? [], allowNull: true, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => { console.log(value) emit('FilterApplied', { payload: { table_group_id: value } }) }, }), () => Input({ - testId: 'test-suite-name-filter', icon: 'search', label: '', placeholder: 'Search test suite names', diff --git a/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js b/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js index 6f8d9140..4f87ef43 100644 --- a/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js +++ b/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js @@ -10,7 +10,6 @@ import { ColumnProfilingResults } from '../data_profiling/column_profiling_resul * @param {object} props.profilingColumn - reactive state: set to column data to open, null to close * @param {function} props.onClose - called when dialog is closed * @param {string} [props.width='52rem'] - * @param {string} [props.testId] */ const ProfilingResultsDialog = (props) => { const emit = 
props.emit; @@ -31,7 +30,7 @@ const ProfilingResultsDialog = (props) => { const columnJson = van.derive(() => columnData.val ? JSON.stringify(columnData.val) : null); return Dialog( - { title: 'Column Profiling Results', open, onClose, width: props.width || '52rem', testId: props.testId }, + { title: 'Column Profiling Results', open, onClose, width: props.width || '52rem' }, () => columnJson.val ? ColumnProfilingResults({ emit, column: columnJson }) : '', ); }; diff --git a/testgen/ui/components/frontend/js/shared/source_data_dialog.js b/testgen/ui/components/frontend/js/shared/source_data_dialog.js index 8645c18c..72a3775b 100644 --- a/testgen/ui/components/frontend/js/shared/source_data_dialog.js +++ b/testgen/ui/components/frontend/js/shared/source_data_dialog.js @@ -15,7 +15,6 @@ const { div, h4, small } = van.tags; * @param {function} props.onClose - called when dialog is closed * @param {function} [props.renderHeader] - (data) => VanJS node for page-specific metadata header * @param {string} [props.width='70rem'] - * @param {string} [props.testId] */ const SourceDataDialog = (props) => { const emit = props.emit; @@ -35,7 +34,7 @@ const SourceDataDialog = (props) => { }; return Dialog( - { title: 'Source Data', open, onClose, width: props.width || '70rem', testId: props.testId }, + { title: 'Source Data', open, onClose, width: props.width || '70rem' }, () => { const d = data.val; if (!d) return ''; diff --git a/testgen/ui/components/frontend/standalone/project_settings/index.js b/testgen/ui/components/frontend/standalone/project_settings/index.js index 447641f8..3f291eb5 100644 --- a/testgen/ui/components/frontend/standalone/project_settings/index.js +++ b/testgen/ui/components/frontend/standalone/project_settings/index.js @@ -5,10 +5,14 @@ import van from '/app/static/js/van.min.js'; import { Card } from '/app/static/js/components/card.js'; import { Input } from '/app/static/js/components/input.js'; import { Button } from 
'/app/static/js/components/button.js'; -import { required } from '/app/static/js/form_validators.js'; +import { numberBetween, required } from '/app/static/js/form_validators.js'; import { Alert } from '/app/static/js/components/alert.js'; import { Checkbox } from '/app/static/js/components/checkbox.js'; -import { createEmitter, getValue, isEqual } from '/app/static/js/utils.js'; +import { CrontabInput } from '/app/static/js/components/crontab_input.js'; +import { Select } from '/app/static/js/components/select.js'; +import { timezones } from '/app/static/js/values.js'; +import { formatTimestamp } from '/app/static/js/display_utils.js'; +import { createEmitter, debounce, getValue, isEqual } from '/app/static/js/utils.js'; const { div, span } = van.tags; @@ -26,29 +30,112 @@ const { div, span } = van.tags; * @property {VanState} observability_api_url * @property {VanState} observability_api_key * @property {VanState} observability_test_results + * @property {VanState} data_retention_enabled + * @property {VanState} data_retention_days + * @property {VanState} retention_cron_expr + * @property {VanState} retention_cron_tz + * @property {VanState} retention_cron_sample + * @property {VanState} retention_last_run + * @property {VanState<{profiling_count: number, test_count: number}?>} retention_preview * * @param {Properties} props */ const ProjectSettings = (props) => { const { emit } = props; + // Persisted values are reactive: after a Save, the props update with the + // newly-stored values and these derives recompute, letting + // `showRetentionConfirmation` settle back to a clean state. + const browserTz = Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC'; + const persistedName = van.derive(() => props.name.val ?? ''); + const persistedUseWeights = van.derive(() => props.use_dq_score_weights.val ?? true); + const persistedObsUrl = van.derive(() => props.observability_api_url.val ?? 
''); + const persistedObsKey = van.derive(() => props.observability_api_key.val ?? ''); + const persistedRetentionEnabled = van.derive(() => props.data_retention_enabled.val ?? false); + const persistedRetentionDays = van.derive(() => props.data_retention_days.val ?? 180); + const persistedRetentionCron = van.derive(() => props.retention_cron_expr.val ?? '0 1 * * *'); + const persistedRetentionTz = van.derive(() => props.retention_cron_tz.val ?? browserTz); const /** @type Properties */ form = { name: van.state(props.name.rawVal ?? ''), use_dq_score_weights: van.state(props.use_dq_score_weights.rawVal ?? true), observability_api_key: van.state(props.observability_api_key.rawVal ?? ''), observability_api_url: van.state(props.observability_api_url.rawVal ?? ''), + data_retention_enabled: van.state(persistedRetentionEnabled.val), + data_retention_days: van.state(persistedRetentionDays.val), + retention_cron_expr: van.state(persistedRetentionCron.val), + retention_cron_tz: van.state(persistedRetentionTz.val), }; const formValidity = { name: van.state(!!form.name.rawVal), observability_api_key: van.state(true), observability_api_url: van.state(true), + data_retention_days: van.state(Number.isFinite(form.data_retention_days.rawVal)), }; - const saveDisabled = van.derive(() => !formValidity.name.val || !formValidity.observability_api_url.val || !formValidity.observability_api_key.val); + // Retention is unchanged when the enabled flag matches the persisted value and, + // while enabled, the days/cron/tz also match. When retention is off, days/cron/tz + // are hidden and the backend clears them, so they don't count as unsaved changes — + // only the enabled flag matters. 
+ const retentionUnchanged = van.derive(() => { + if (form.data_retention_enabled.val !== persistedRetentionEnabled.val) return false; + if (!form.data_retention_enabled.val) return true; + return form.data_retention_days.val === persistedRetentionDays.val + && form.retention_cron_expr.val === persistedRetentionCron.val + && form.retention_cron_tz.val === persistedRetentionTz.val; + }); + // No unsaved changes when every field matches its persisted value. Because the + // persisted derives are reactive, this settles back to `true` after a Save once + // the props update with the stored values, disabling the button again. + const noChanges = van.derive(() => form.name.val === persistedName.val + && form.use_dq_score_weights.val === persistedUseWeights.val + && form.observability_api_url.val === persistedObsUrl.val + && form.observability_api_key.val === persistedObsKey.val + && retentionUnchanged.val); + const saveDisabled = van.derive(() => !formValidity.name.val + || !formValidity.observability_api_url.val + || !formValidity.observability_api_key.val + || (form.data_retention_enabled.val && !formValidity.data_retention_days.val) + || noChanges.val); const testObservabilityDisabled = van.derive(() => form.observability_api_url.val.length <= 0 || form.observability_api_key.val.length <= 0); + const retentionCronEditorValue = van.derive(() => { + if (form.retention_cron_expr.val && form.retention_cron_tz.val && form.data_retention_enabled.val) { + emit('GetCronSample', { + payload: { cron_expr: form.retention_cron_expr.val, tz: form.retention_cron_tz.val }, + }); + } + return { + timezone: form.retention_cron_tz.val, + expression: form.retention_cron_expr.val, + }; + }); + // True when the form would enlarge the next cleanup's delete set — + // turning retention on, or shortening the retention period of a project + // that already has it on. Both cases warrant a delete-preview confirmation + // before saving. 
+ const showRetentionConfirmation = van.derive(() => { + if (!form.data_retention_enabled.val) return false; + if (!persistedRetentionEnabled.val) return true; + return form.data_retention_days.val < persistedRetentionDays.val; + }); + // Debounce so rapid days edits collapse to a single round-trip. + const previewPending = van.state(false); + const emitPreviewRequest = debounce((days) => { + emit('GetRetentionPreview', { payload: { retention_days: days } }); + }, 300); + van.derive(() => { + if (showRetentionConfirmation.val && formValidity.data_retention_days.val) { + previewPending.val = true; + emitPreviewRequest(form.data_retention_days.val); + } + }); + van.derive(() => { + if (getValue(props.retention_preview) !== null && getValue(props.retention_preview) !== undefined) { + previewPending.val = false; + } + }); return div( { class: 'flex-column fx-gap-3' }, div( - { class: 'flex-column fx-gap-1' }, + { class: 'flex-column fx-gap-1', style: 'max-width: 700px;' }, span({ class: 'body m' }, 'Project Info'), Card({ class: 'mb-0', @@ -74,7 +161,7 @@ const ProjectSettings = (props) => { }), ), div( - { class: 'flex-column fx-gap-1' }, + { class: 'flex-column fx-gap-1', style: 'max-width: 700px;' }, span({ class: 'body m' }, 'Observability Integration'), Card({ class: 'mb-0', @@ -130,7 +217,98 @@ const ProjectSettings = (props) => { }), ), div( - { class: 'flex-row fx-justify-content-flex-end' }, + { class: 'flex-column fx-gap-1', style: 'max-width: 700px;' }, + span({ class: 'body m' }, 'Data Retention'), + Card({ + class: 'mb-0', + border: true, + content: div( + { class: 'flex-column fx-gap-3' }, + Checkbox({ + label: 'Automatically delete old profiling and test history', + checked: form.data_retention_enabled, + help: 'Old profiling and test runs are permanently deleted to keep the database from growing without bound. 
The most recent run in each test suite and table group is always kept.', + onChange: (checked) => { form.data_retention_enabled.val = checked; }, + }), + () => form.data_retention_enabled.val + ? div( + { class: 'flex-column fx-gap-3' }, + Input({ + label: 'Delete history older than (days)', + value: form.data_retention_days, + type: 'number', + step: 1, + validators: [ required, numberBetween(30, 9999, 0) ], + onChange: (value, validity) => { + form.data_retention_days.val = value === '' ? NaN : parseInt(value); + formValidity.data_retention_days.val = validity.valid; + }, + }), + () => { + const days = form.data_retention_days.val; + return days >= 30 && days < 60 + ? span( + { class: 'text-caption', style: 'color: var(--purple);' }, + 'Monitors perform better with more historical data — at least two months is recommended.', + ) + : ''; + }, + div( + { class: 'flex-row fx-gap-3 fx-flex-wrap fx-align-flex-start' }, + () => Select({ + label: 'Timezone', + options: timezones.map((tz_) => ({ label: tz_, value: tz_ })), + value: form.retention_cron_tz, + allowNull: false, + filterable: true, + onChange: (value) => { form.retention_cron_tz.val = value; }, + portalClass: 'short-select-portal', + style: 'flex: auto;', + }), + div( + { style: 'flex: auto;' }, + CrontabInput({ + emit, + name: 'data_retention_schedule', + sample: props.retention_cron_sample, + value: retentionCronEditorValue, + modes: ['x_hours', 'x_days'], + hideExpression: true, + onChange: (value) => { form.retention_cron_expr.val = value; }, + }), + ), + ), + () => { + const lastRun = getValue(props.retention_last_run); + const sample = getValue(props.retention_cron_sample) ?? {}; + const nextSample = (sample.samples ?? [])[0]; + return div( + { class: 'flex-column fx-gap-1 text-caption' }, + span(`Last cleanup ran: ${lastRun ? formatTimestamp(lastRun) : 'never'}`), + nextSample ? 
span(`Next cleanup: ${nextSample}`) : '', + ); + }, + () => { + if (!showRetentionConfirmation.val) return ''; + const preview = getValue(props.retention_preview); + const profilingCt = preview?.profiling_count ?? 0; + const testCt = preview?.test_count ?? 0; + const showing = preview !== null && preview !== undefined && !previewPending.val; + const message = !showing + ? 'Calculating impact…' + : `This will delete approximately ${profilingCt} profiling run${profilingCt === 1 ? '' : 's'} and ${testCt} test run${testCt === 1 ? '' : 's'} at the next cleanup. Deleted data cannot be recovered.`; + return Alert( + { type: 'warn' }, + span(message), + ); + }, + ) + : '', + ), + }), + ), + div( + { class: 'flex-row fx-justify-content-flex-end', style: 'max-width: 700px;' }, Button({ type: 'stroked', color: 'primary', diff --git a/testgen/ui/components/widgets/__init__.py b/testgen/ui/components/widgets/__init__.py index 6b3f23a9..6fdf28ce 100644 --- a/testgen/ui/components/widgets/__init__.py +++ b/testgen/ui/components/widgets/__init__.py @@ -146,3 +146,10 @@ js="pages/sidebar.js", isolate_styles=False, )) + +feedback_widget = component_v2_wrapped(components_v2.component( + name="dataops-testgen.feedback_widget", + js="pages/feedback_widget.js", + isolate_styles=False, +)) + diff --git a/testgen/ui/components/widgets/page.py b/testgen/ui/components/widgets/page.py index 737f4a55..9775306f 100644 --- a/testgen/ui/components/widgets/page.py +++ b/testgen/ui/components/widgets/page.py @@ -9,6 +9,7 @@ import testgen.common.logs as logs from testgen import settings from testgen.common import version_service +from testgen.common.mixpanel_service import MixpanelService from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session @@ -44,6 +45,9 @@ def page_header( st.html('
') + # Feedback widget (bottom-right) + render_feedback_widget() + # Render app logs dialog widget (outside the header container) logs_data = st.session_state.get(APP_LOGS_DIALOG_KEY) if logs_data: @@ -114,6 +118,10 @@ def open_app_logs(): close_help() st.session_state[APP_LOGS_DIALOG_KEY] = _read_log_data() + def open_feedback(): + close_help() + session.show_feedback_popup = True + with help_container.container(): flex_row_end() with st.popover("Help"): @@ -127,9 +135,11 @@ def open_app_logs(): "version": version.__dict__, "permissions": { "can_edit": session.auth.user_has_permission("edit"), + "is_logged_in": session.auth.is_logged_in, }, }, on_AppLogsClicked_change=lambda _: open_app_logs(), + on_FeedbackClicked_change=lambda _: open_feedback(), on_ExternalLinkClicked_change=lambda _: close_help(rerun=True), ) @@ -175,3 +185,36 @@ def _apply_html(html: str, container: DeltaGenerator | None = None): container.html(html) else: st.html(html) + + +def render_feedback_widget(): + """Render the feedback popup widget in the bottom-right corner. + + Visibility is driven by session.show_feedback_popup: + - set by router on session start (30-day eligibility gate) + - set when the user manually clicks "Give Feedback" + + Feedback submissions are sent to MixPanel. 
+ """ + if not bool(session.show_feedback_popup): + return + + def on_dismissed(_): + session.show_feedback_popup = False + + def on_submitted(payload): + if payload: + MixpanelService().send_feedback( + rating=int(payload.get("rating", 0)), + comment=payload.get("comment") or None, + email=payload.get("email") or None, + ) + + from testgen.ui.components.widgets import feedback_widget + feedback_widget( + key="feedback_widget", + data={}, + on_FeedbackDismissed_change=on_dismissed, + on_FeedbackSubmitted_change=on_submitted, + ) + diff --git a/testgen/ui/navigation/router.py b/testgen/ui/navigation/router.py index c53c8759..c86636cd 100644 --- a/testgen/ui/navigation/router.py +++ b/testgen/ui/navigation/router.py @@ -2,11 +2,14 @@ import logging import time +from datetime import UTC, datetime import streamlit as st import testgen.ui.navigation.page +from testgen import settings from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models.user import FEEDBACK_POPUP_INTERVAL, PreferenceKey from testgen.ui.session import session from testgen.utils.singleton import Singleton @@ -31,6 +34,25 @@ def _init_session(self, url: str): source = st.query_params.pop("source", None) MixpanelService().send_event(f"nav-{url}", page_load=True, source=source) + def _evaluate_feedback_popup(self) -> None: + session.show_feedback_popup = False + try: + if settings.DISABLE_FEEDBACK_POPUP or not (user := session.auth.user): + return + + if (last_popup_str := user.get_preference(PreferenceKey.LAST_FEEDBACK_POPUP)): + try: + last_popup_dt = datetime.fromisoformat(last_popup_str) + if datetime.now(UTC) - last_popup_dt < FEEDBACK_POPUP_INTERVAL: + return + except (ValueError, TypeError): + pass # Corrupted value — treat as no prior popup + + user.set_preference(PreferenceKey.LAST_FEEDBACK_POPUP, datetime.now(UTC).isoformat()) + session.show_feedback_popup = True + except Exception: + LOG.exception("Error evaluating feedback popup eligibility") + def run(self) -> 
None: streamlit_pages = [route.streamlit_page for route in self._routes.values()] @@ -63,6 +85,9 @@ def run(self) -> None: st.query_params.from_dict(session.page_args_pending_router) session.page_args_pending_router = None + if session.show_feedback_popup is None and session.auth.is_logged_in: + self._evaluate_feedback_popup() + session.current_page = current_page.url_path current_page.run() else: diff --git a/testgen/ui/pdf/test_result_report.py b/testgen/ui/pdf/test_result_report.py index a7485c7c..9b57ff73 100644 --- a/testgen/ui/pdf/test_result_report.py +++ b/testgen/ui/pdf/test_result_report.py @@ -59,7 +59,7 @@ def build_summary_table(document, tr_data): *[ (cmd[0], *coords, *cmd[1:]) for coords in ( - ((3, 3), (3, -3)), + ((3, 3), (3, -4)), ((0, 0), (0, -2)) ) for cmd in ( @@ -83,10 +83,11 @@ def build_summary_table(document, tr_data): ("SPAN", (1, 6), (2, 6)), ("SPAN", (4, 6), (5, 6)), ("SPAN", (1, 7), (5, 7)), - ("SPAN", (0, 8), (5, 8)), + ("SPAN", (1, 8), (5, 8)), + ("SPAN", (0, 9), (5, 9)), # Link cell - ("BACKGROUND", (0, 8), (5, 8), colors.white), + ("BACKGROUND", (0, 9), (5, 9), colors.white), # Measure cell ("FONT", (1, 1), (1, 1), "Helvetica-Bold"), diff --git a/testgen/ui/queries/table_group_queries.py b/testgen/ui/queries/table_group_queries.py index 8cc32c73..0db27e12 100644 --- a/testgen/ui/queries/table_group_queries.py +++ b/testgen/ui/queries/table_group_queries.py @@ -5,11 +5,14 @@ import streamlit as st -from testgen.commands.queries.refresh_data_chars_query import ColumnChars, RefreshDataCharsSQL +from testgen.commands.queries.refresh_data_chars_query import RefreshDataCharsSQL from testgen.commands.run_refresh_data_chars import write_data_chars +from testgen.common.database.column_chars import ColumnChars +from testgen.common.database.flavor.flavor_service import resolve_connection_params from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from 
testgen.ui.services.database_service import fetch_from_target_db +from testgen.ui.services.query_cache import get_connection class StatsPreview(TypedDict): @@ -54,7 +57,7 @@ def get_table_group_preview( if connection or table_group.connection_id: try: - connection = connection or Connection.get(table_group.connection_id) + connection = connection or get_connection(table_group.connection_id) table_group_preview, data_chars, sql_generator = _get_preview(table_group, connection) def save_data_chars(table_group_id: UUID) -> None: @@ -109,8 +112,13 @@ def _get_preview( connection: Connection, ) -> tuple[TableGroupPreview, list[ColumnChars], RefreshDataCharsSQL]: sql_generator = RefreshDataCharsSQL(connection, table_group) - data_chars = fetch_from_target_db(connection, *sql_generator.get_schema_ddf()) - data_chars = [ColumnChars(**column) for column in data_chars] + if sql_generator.flavor_service.metadata_via_api: + params = resolve_connection_params(connection.__dict__) + api_columns = sql_generator.flavor_service.get_schema_columns(params, table_group.table_group_schema) or [] + data_chars = sql_generator.filter_schema_columns(api_columns) + else: + rows = fetch_from_target_db(connection, *sql_generator.get_schema_ddf()) + data_chars = [ColumnChars(**column) for column in rows] preview: TableGroupPreview = { "stats": { diff --git a/testgen/ui/scripts/patch_streamlit.py b/testgen/ui/scripts/patch_streamlit.py index 16de43cc..88476737 100644 --- a/testgen/ui/scripts/patch_streamlit.py +++ b/testgen/ui/scripts/patch_streamlit.py @@ -1,4 +1,3 @@ -# ruff: noqa: TRY002 import pathlib import re diff --git a/testgen/ui/services/query_cache.py b/testgen/ui/services/query_cache.py index 7dca4918..55be1e67 100644 --- a/testgen/ui/services/query_cache.py +++ b/testgen/ui/services/query_cache.py @@ -12,14 +12,28 @@ import streamlit as st -from testgen.common.models.connection import Connection -from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary 
+from testgen.common.models.connection import Connection, ConnectionMinimal +from testgen.common.models.entity import ENTITY_HASH_FUNCS +from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunMinimal, ProfilingRunSummary from testgen.common.models.project import Project, ProjectSummary from testgen.common.models.project_membership import ProjectMembership -from testgen.common.models.table_group import TableGroup, TableGroupStats, TableGroupSummary -from testgen.common.models.test_definition import TestType, TestTypeSummary -from testgen.common.models.test_run import TestRun, TestRunSummary -from testgen.common.models.test_suite import TestSuite, TestSuiteSummary +from testgen.common.models.scheduler import RUN_MONITORS_JOB_KEY, JobSchedule +from testgen.common.models.table_group import ( + TableGroup, + TableGroupMinimal, + TableGroupStats, + TableGroupSummary, +) +from testgen.common.models.test_definition import ( + TestDefinition, + TestDefinitionMinimal, + TestDefinitionSummary, + TestType, + TestTypeSummary, +) +from testgen.common.models.test_run import TestRun, TestRunMinimal, TestRunSummary +from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal, TestSuiteSummary +from testgen.common.models.user import User # -- Project ------------------------------------------------------------------ @@ -126,3 +140,173 @@ def get_profiling_run_summaries( page_size: int = 20, ) -> tuple[list[ProfilingRunSummary], int]: return ProfilingRun.select_summary(project_code, table_group_id, page=page, page_size=page_size) + + +# -- JobSchedule -------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_monitor_schedule(monitor_suite_id: str | UUID) -> JobSchedule | None: + return JobSchedule.get( + JobSchedule.key == RUN_MONITORS_JOB_KEY, + JobSchedule.kwargs["test_suite_id"].astext == str(monitor_suite_id), + ) + + +# -- Connection --------------------------------------------------------------- 
+ +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_connection(identifier: str | int | UUID, *clauses) -> Connection | None: + return Connection.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_connections_where(*clauses, order_by=None) -> list[Connection]: + return list(Connection.select_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False) +def get_connection_minimal(identifier: int) -> ConnectionMinimal | None: + return Connection.get_minimal(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_connections_minimal_where(*clauses, order_by=None) -> list[ConnectionMinimal]: + if order_by is None: + return list(Connection.select_minimal_where(*clauses)) + return list(Connection.select_minimal_where(*clauses, order_by=order_by)) + + +# -- User --------------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_user(identifier: str) -> User | None: + return User.get(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_users_where(*clauses, order_by=None) -> list[User]: + return list(User.select_where(*clauses, order_by=order_by)) + + +# -- TableGroup --------------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_table_group(identifier: str | UUID, *clauses) -> TableGroup | None: + return TableGroup.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False) +def get_table_group_minimal(identifier: str | UUID) -> TableGroupMinimal | None: + return TableGroup.get_minimal(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_table_groups_minimal_where(*clauses, order_by=None) -> list[TableGroupMinimal]: + if order_by is None: + return list(TableGroup.select_minimal_where(*clauses)) + return 
list(TableGroup.select_minimal_where(*clauses, order_by=order_by)) + + +# -- TestSuite ---------------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_test_suite(identifier: str | UUID, *clauses) -> TestSuite | None: + return TestSuite.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False) +def get_test_suite_minimal(identifier: int) -> TestSuiteMinimal | None: + return TestSuite.get_minimal(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_suites_minimal_where(*clauses, order_by=None) -> list[TestSuiteMinimal]: + if order_by is None: + return list(TestSuite.select_minimal_where(*clauses)) + return list(TestSuite.select_minimal_where(*clauses, order_by=order_by)) + + +# -- TestRun ------------------------------------------------------------------ + +@st.cache_data(show_spinner=False) +def get_test_run_minimal(run_id: str | UUID) -> TestRunMinimal | None: + return TestRun.get_minimal(run_id) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_runs_where(*clauses, order_by=None) -> list[TestRun]: + return list(TestRun.select_where(*clauses, order_by=order_by)) + + +# -- ProfilingRun ------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_profiling_run_minimal(run_id: str | UUID) -> ProfilingRunMinimal | None: + return ProfilingRun.get_minimal(run_id) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_profiling_runs_where(*clauses, order_by=None) -> list[ProfilingRun]: + return list(ProfilingRun.select_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_profiling_runs_minimal_where(*clauses, order_by=None) -> list[ProfilingRunMinimal]: + if order_by is None: + return list(ProfilingRun.select_minimal_where(*clauses)) + return 
list(ProfilingRun.select_minimal_where(*clauses, order_by=order_by)) + + +# -- TestDefinition ----------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_test_definition(identifier: str | UUID) -> TestDefinitionSummary | None: + return TestDefinition.get(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_definitions_where(*clauses, order_by=None) -> list[TestDefinitionSummary]: + if order_by is None: + return list(TestDefinition.select_where(*clauses)) + return list(TestDefinition.select_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_definitions_minimal_where(*clauses, order_by=None) -> list[TestDefinitionMinimal]: + if order_by is None: + return list(TestDefinition.select_minimal_where(*clauses)) + return list(TestDefinition.select_minimal_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_definitions_page( + *clauses, + order_by=None, + page: int = 1, + limit: int = 500, +) -> tuple[list[TestDefinitionSummary], int]: + return TestDefinition.select_page(*clauses, order_by=order_by, page=page, limit=limit) + + +# -- Project ------------------------------------------------------------------ + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_project(identifier: str, *clauses) -> Project | None: + return Project.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_projects_where(*clauses, order_by=None) -> list[Project]: + return list(Project.select_where(*clauses, order_by=order_by)) + + +# -- ProjectMembership -------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_project_membership(identifier: str | UUID, *clauses) -> ProjectMembership | None: + return 
ProjectMembership.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_project_memberships_where(*clauses, order_by=None) -> list[ProjectMembership]: + return list(ProjectMembership.select_where(*clauses, order_by=order_by)) diff --git a/testgen/ui/session.py b/testgen/ui/session.py index 9f50ed33..5b8389f5 100644 --- a/testgen/ui/session.py +++ b/testgen/ui/session.py @@ -35,6 +35,8 @@ class TestgenSession(Singleton): add_project: bool version: Version | None + show_feedback_popup: bool | None + testgen_event_id: ClassVar[dict[str, str]] = {} sidebar_event_id: str | None link_event_id: str | None diff --git a/testgen/ui/static/css/style.css b/testgen/ui/static/css/style.css index 01dee345..66b56aa5 100644 --- a/testgen/ui/static/css/style.css +++ b/testgen/ui/static/css/style.css @@ -422,6 +422,10 @@ Use as testgen.text("text", "extra_styles") */ div[data-testid="stPopoverBody"]:has(i.tg-header--help-wrapper) { padding: 0; } + +.st-key-feedback_widget { + z-index: 9999; +} /* */ /* Summary bar component */ diff --git a/testgen/ui/static/js/components/alert.js b/testgen/ui/static/js/components/alert.js index c01f2fc8..76d4e7a5 100644 --- a/testgen/ui/static/js/components/alert.js +++ b/testgen/ui/static/js/components/alert.js @@ -7,7 +7,6 @@ * @property {string?} class * @property {'info'|'success'|'warn'|'error'} type * @property {Function?} onClose - * @property {string?} testId */ import van from '../van.min.js'; import { getValue, loadStylesheet, getRandomId } from '../utils.js'; @@ -32,7 +31,7 @@ const Alert = (/** @type Properties */ props, /** @type Array */ .. { ...props, id: elementId, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'alert', class: () => `tg-alert flex-row ${getValue(props.class) ?? 
''} tg-alert-${getValue(props.type)}`, role: 'alert', }, diff --git a/testgen/ui/static/js/components/attribute.js b/testgen/ui/static/js/components/attribute.js index a7bb60eb..8e8ab4bc 100644 --- a/testgen/ui/static/js/components/attribute.js +++ b/testgen/ui/static/js/components/attribute.js @@ -19,9 +19,13 @@ const Attribute = (/** @type Properties */ props) => { loadStylesheet('attribute', stylesheet); return div( - { style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, class: props.class }, + { + 'data-testid': 'attribute', + style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, + class: props.class, + }, div( - { class: 'flex-row fx-gap-1 text-caption mb-1' }, + { 'data-testid': 'attribute-label', class: 'flex-row fx-gap-1 text-caption mb-1' }, props.label, () => getValue(props.help) ? withTooltip( @@ -31,7 +35,7 @@ const Attribute = (/** @type Properties */ props) => { : null, ), div( - { class: 'attribute-value' }, + { 'data-testid': 'attribute-value', class: 'attribute-value' }, () => { const value = getValue(props.value); if (value === PII_REDACTED) { diff --git a/testgen/ui/static/js/components/breadcrumbs.js b/testgen/ui/static/js/components/breadcrumbs.js index 5280dd22..94f7f71a 100644 --- a/testgen/ui/static/js/components/breadcrumbs.js +++ b/testgen/ui/static/js/components/breadcrumbs.js @@ -8,7 +8,6 @@ * @typedef Properties * @type {object} * @property {Array.} breadcrumbs - * @property {string?} testId */ import van from '../van.min.js'; import { getValue, loadStylesheet } from '../utils.js'; @@ -18,10 +17,8 @@ const { a, div, span } = van.tags; const Breadcrumbs = (/** @type Properties */ props) => { loadStylesheet('breadcrumbs', stylesheet); - const testId = getValue(props.testId) ?? 
''; - return div( - { class: 'tg-breadcrumbs-wrapper', 'data-testid': testId }, + { class: 'tg-breadcrumbs-wrapper', 'data-testid': 'breadcrumbs' }, () => { const breadcrumbs = getValue(props.breadcrumbs) || []; @@ -30,7 +27,7 @@ const Breadcrumbs = (/** @type Properties */ props) => { breadcrumbs.reduce((items, b, idx) => { const isLastItem = idx === breadcrumbs.length - 1; items.push(a({ - 'data-testid': testId ? `${testId}-item-${idx}` : '', + 'data-testid': 'breadcrumb-item', class: `tg-breadcrumbs--${ isLastItem ? 'current' : 'active'}`, onclick: (event) => { event.preventDefault(); diff --git a/testgen/ui/static/js/components/button.js b/testgen/ui/static/js/components/button.js index e839fc88..31f3b870 100644 --- a/testgen/ui/static/js/components/button.js +++ b/testgen/ui/static/js/components/button.js @@ -15,7 +15,6 @@ * @property {(bool)} loading * @property {('normal' | 'small')?} size * @property {string?} style - * @property {string?} testId */ import { getValue, loadStylesheet } from '../utils.js'; import van from '../van.min.js'; @@ -49,7 +48,7 @@ const Button = (/** @type Properties */ props) => { style: () => `width: ${isIconOnly ? '' : (width ?? '100%')}; ${getValue(props.style)}`, onclick: onClickHandler, disabled: isDisabled, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'button', }, span({class: 'tg-button-focus-state-indicator'}, ''), props.icon ? 
i({ diff --git a/testgen/ui/static/js/components/caption.js b/testgen/ui/static/js/components/caption.js index 8f7f21f4..e1820356 100644 --- a/testgen/ui/static/js/components/caption.js +++ b/testgen/ui/static/js/components/caption.js @@ -13,7 +13,7 @@ const Caption = (/** @type Properties */ props) => { loadStylesheet('caption', stylesheet); return span( - { class: 'tg-caption', style: props.style }, + { class: 'tg-caption', style: props.style, 'data-testid': 'caption' }, props.content ); } diff --git a/testgen/ui/static/js/components/card.js b/testgen/ui/static/js/components/card.js index 9102947e..54c1ee8b 100644 --- a/testgen/ui/static/js/components/card.js +++ b/testgen/ui/static/js/components/card.js @@ -19,7 +19,7 @@ const Card = (/** @type Properties */ props) => { return div( { id: props.id ?? '', - 'data-testid': props.testId ?? '', + 'data-testid': props.testId ?? 'card', class: () => { const classes = ['tg-card']; if (getValue(props.border)) { diff --git a/testgen/ui/static/js/components/checkbox.js b/testgen/ui/static/js/components/checkbox.js index da7ed63f..35c69150 100644 --- a/testgen/ui/static/js/components/checkbox.js +++ b/testgen/ui/static/js/components/checkbox.js @@ -8,7 +8,6 @@ * @property {boolean?} indeterminate * @property {function(boolean, Event)?} onChange * @property {number?} width - * @property {string?} testId * @property {boolean?} disabled */ import van from '../van.min.js'; @@ -31,7 +30,7 @@ const Checkbox = (/** @type Properties */ props) => { return label( { class: 'flex-row fx-gap-2 clickable', - 'data-testid': props.testId ?? props.name ?? '', + 'data-testid': 'checkbox', style: () => `width: ${props.width ? 
getValue(props.width) + 'px' : 'auto'}`, }, input({ diff --git a/testgen/ui/static/js/components/code.js b/testgen/ui/static/js/components/code.js index 414bd968..bada687d 100644 --- a/testgen/ui/static/js/components/code.js +++ b/testgen/ui/static/js/components/code.js @@ -2,7 +2,6 @@ * @typedef Options * @type {object} * @property {string?} id - * @property {string?} testId * @property {string?} class * @property {string?} language - Language for syntax highlighting (e.g. 'sql', 'html'). Omit for no highlighting. */ @@ -32,7 +31,7 @@ const Code = (options, ...children) => { ); const el = div( - { id: domId, class: `tg-code ${options.class ?? ''}`, 'data-testid': options.testId ?? '' }, + { id: domId, class: `tg-code ${options.class ?? ''}`, 'data-testid': 'code' }, pre({}, codeEl), Icon( { diff --git a/testgen/ui/static/js/components/connection_form.js b/testgen/ui/static/js/components/connection_form.js index 4c5ec91c..1fadaacf 100644 --- a/testgen/ui/static/js/components/connection_form.js +++ b/testgen/ui/static/js/components/connection_form.js @@ -87,6 +87,9 @@ const defaultPorts = { sap_hana: '39015', }; +// Salesforce Data 360's Hyper engine has a lower expression-depth limit than other databases +const defaultMaxQueryChars = (flavorCode) => flavorCode === 'salesforce_data360' ? 15000 : 20000; + /** * * @param {Properties} props @@ -114,7 +117,7 @@ const ConnectionForm = (props, saveButton) => { const connectionFlavor = van.state(connection?.sql_flavor_code); const connectionName = van.state(connection?.connection_name ?? ''); const connectionMaxThreads = van.state(connection?.max_threads ?? 4); - const connectionQueryChars = van.state(connection?.max_query_chars ?? 20000); + const connectionQueryChars = van.state(connection?.max_query_chars ?? defaultMaxQueryChars(connection?.sql_flavor_code)); const privateKeyFile = van.state(getValue(props.cachedPrivateKeyFile) ?? 
null); const serviceAccountKeyFile = van.state(getValue(props.cachedServiceAccountKeyFile) ?? null); @@ -139,7 +142,7 @@ const ConnectionForm = (props, saveButton) => { sql_flavor_code: connectionFlavor.rawVal ?? '', connection_name: connectionName.rawVal ?? '', max_threads: connectionMaxThreads.rawVal ?? 4, - max_query_chars: connectionQueryChars.rawVal ?? 20000, + max_query_chars: connectionQueryChars.rawVal ?? defaultMaxQueryChars(connectionFlavor.rawVal), }); const dynamicConnectionUrl = van.state(props.dynamicConnectionUrl?.rawVal ?? ''); @@ -179,6 +182,7 @@ const ConnectionForm = (props, saveButton) => { setFieldValidity('redshift_spectrum_form', isValid); }, connection, + dynamicConnectionUrl, ), azure_mssql: () => AzureMSSQLForm( updatedConnection, @@ -274,6 +278,17 @@ const ConnectionForm = (props, saveButton) => { connection, getValue(props.cachedServiceAccountKeyFile) ?? null ), + salesforce_data360: () => SalesforceData360Form( + updatedConnection, + getValue(props.flavors).find(f => f.value === connectionFlavor.rawVal), + (formValue, fileValue, isValid) => { + updatedConnection.val = {...updatedConnection.val, ...formValue}; + privateKeyFile.val = fileValue; + setFieldValidity('salesforce_data360_form', isValid); + }, + connection, + getValue(props.cachedPrivateKeyFile) ?? 
null, + ), }; const setFieldValidity = (field, validity) => { @@ -287,17 +302,6 @@ const ConnectionForm = (props, saveButton) => { return authenticationForms[flavor.value](); }); - van.derive(() => { - const selectedFlavorCode = connectionFlavor.val; - const previousFlavorCode = connectionFlavor.oldVal; - const updatedConnection_ = updatedConnection.rawVal; - - const isCustomPort = updatedConnection_?.project_port !== defaultPorts[previousFlavorCode]; - if (selectedFlavorCode !== previousFlavorCode && (!isCustomPort || !updatedConnection_?.project_port)) { - updatedConnection.val = {...updatedConnection_, project_port: defaultPorts[selectedFlavorCode]}; - } - }); - van.derive(() => { const selectedFlavor = connectionFlavor.val; const flavorObject = getValue(props.flavors).find(f => f.value === selectedFlavor); @@ -331,7 +335,6 @@ const ConnectionForm = (props, saveButton) => { options: props.flavors, disabled: props.disableFlavor, help: 'Type of database server to connect to. This determines the database driver and SQL dialect that will be used by TestGen.', - testId: 'sql_flavor', }), Input({ name: 'connection_name', @@ -410,7 +413,6 @@ const ConnectionForm = (props, saveButton) => { /** * @param {VanState} connection * @param {Flavor} flavor - * @param {boolean} maskPassword * @param {(params: Partial, isValid: boolean) => void} onChange * @param {Connection?} originalConnection * @param {VanState} dynamicConnectionUrl @@ -789,7 +791,6 @@ const MSSQLForm = RedshiftForm; /** * @param {VanState} connection * @param {Flavor} flavor - * @param {boolean} maskPassword * @param {(params: Partial, isValid: boolean) => void} onChange * @param {Connection?} originalConnection * @param {VanState} dynamicConnectionUrl @@ -1031,7 +1032,6 @@ const DatabricksForm = ( /** * @param {VanState} connection * @param {Flavor} flavor - * @param {boolean} maskPassword * @param {(params: Partial, fileValue: FileValue, isValid: boolean) => void} onChange * @param {Connection?} 
originalConnection * @param {string?} cachedFile @@ -1328,6 +1328,168 @@ const SnowflakeForm = ( ); }; +/** + * @param {VanState} connection + * @param {Flavor} flavor + * @param {(params: Partial, fileValue: FileValue, isValid: boolean) => void} onChange + * @param {Connection?} originalConnection + * @param {string?} cachedFile + * @returns {HTMLElement} + */ +const SalesforceData360Form = ( + connection, + flavor, + onChange, + originalConnection, + cachedFile, +) => { + const isValid = van.state(false); + const authMethod = van.state( + originalConnection?.connection_id + ? (connection.rawVal.connect_by_key ? 'jwt' : 'client_credentials') + : 'jwt' + ); + const loginUrl = van.state(connection.rawVal.project_host ?? ''); + const consumerKey = van.state(connection.rawVal.project_user ?? ''); + const consumerSecret = van.state(connection.rawVal?.project_pw_encrypted ?? ''); + const permittedUser = van.state(connection.rawVal.project_db ?? ''); + const connectionPrivateKey = van.state(connection.rawVal?.private_key ?? 
''); + + const validityPerField = {}; + + const privateKeyFileRaw = van.state(cachedFile); + + van.derive(() => { + onChange({ + project_host: loginUrl.val, + project_user: consumerKey.val, + project_pw_encrypted: consumerSecret.val, + project_db: permittedUser.val, + connect_by_key: authMethod.val === 'jwt', + private_key: connectionPrivateKey.val, + }, privateKeyFileRaw.val, isValid.val); + }); + + return div( + { class: 'flex-column fx-gap-3 fx-flex' }, + div( + { class: 'flex-column border border-radius-1 p-3 mt-1 fx-gap-1', style: 'position: relative;' }, + Caption({ content: 'Org', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), + Input({ + name: 'login_url', + label: 'Login URL', + help: 'My Domain URL of the Salesforce org', + value: loginUrl, + onChange: (value, state) => { + loginUrl.val = value; + validityPerField['login_url'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [required, maxLength(250)], + }), + ), + div( + { class: 'flex-column border border-radius-1 p-3 mt-1 fx-gap-1', style: 'position: relative;' }, + Caption({ content: 'Authentication', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), + RadioGroup({ + label: 'Connection Strategy', + options: [ + { label: 'JWT Bearer Flow', value: 'jwt' }, + { label: 'Client Credentials Flow', value: 'client_credentials' }, + ], + value: authMethod, + onChange: (value) => { + authMethod.val = value; + if (value === 'jwt') { + delete validityPerField['consumer_secret']; + } else { + delete validityPerField['permitted_user']; + delete validityPerField['private_key']; + } + isValid.val = Object.values(validityPerField).every(v => v); + }, + layout: 'inline', + }), + Input({ + name: 'consumer_key', + label: 'Consumer Key', + help: 'Consumer key from the Salesforce external client app', + value: consumerKey, + onChange: (value, state) => { + 
consumerKey.val = value; + validityPerField['consumer_key'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [required, maxLength(250)], + }), + () => { + if (authMethod.val === 'jwt') { + return div( + { class: 'flex-column fx-gap-3' }, + Input({ + name: 'permitted_user', + label: 'Username', + help: 'Salesforce user the JWT token will impersonate. Must be pre-authorized on the external client app.', + value: permittedUser, + onChange: (value, state) => { + permittedUser.val = value; + validityPerField['permitted_user'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [required, maxLength(250)], + }), + FileInput({ + name: 'private_key', + label: 'Upload private key (.pem, .key)', + placeholder: (originalConnection?.connection_id && originalConnection?.private_key) + ? 'Drop file here or browse files to replace existing key' + : undefined, + value: privateKeyFileRaw, + onChange: (value, state) => { + let isFieldValid = state.valid; + + privateKeyFileRaw.val = value; + try { + if (value?.content) { + connectionPrivateKey.val = value.content.split(',')?.[1] ?? ''; + } + } catch (err) { + console.error(err); + isFieldValid = false; + } + + validityPerField['private_key'] = isFieldValid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [ + requiredIf(() => !originalConnection?.connection_id || !originalConnection?.private_key), + sizeLimit(200 * 1024 * 1024), + ], + }), + ); + }; + return Input({ + name: 'consumer_secret', + label: 'Consumer Secret', + help: 'Consumer secret from the Salesforce external client app', + type: 'password', + passwordSuggestions: false, + value: consumerSecret, + placeholder: (originalConnection?.connection_id && originalConnection?.project_pw_encrypted) ? 
secretsPlaceholder : '', + onChange: (value, state) => { + consumerSecret.val = value; + validityPerField['consumer_secret'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [ + requiredIf(() => !originalConnection?.connection_id || !originalConnection?.project_pw_encrypted), + ], + }); + }, + ), + ); +}; + /** * @param {VanState} connection * @param {Flavor} flavor diff --git a/testgen/ui/static/js/components/crontab_input.js b/testgen/ui/static/js/components/crontab_input.js index cd60de89..ee607310 100644 --- a/testgen/ui/static/js/components/crontab_input.js +++ b/testgen/ui/static/js/components/crontab_input.js @@ -16,7 +16,7 @@ * @type {object} * @property {(string|null)} id * @property {(string|null)} name - * @property {string?} testId + * @property {string?} class * @property {CronSample?} sample * @property {InitialValue?} value @@ -68,7 +68,7 @@ const CrontabInput = (/** @type Options */ props) => { id: domId, class: () => `tg-crontab-input ${getValue(props.class) ?? ''}`, style: 'position: relative', - 'data-testid': getValue(props.testId) ?? 
null, + 'data-testid': 'crontab-input', }, div( {onclick: () => { diff --git a/testgen/ui/static/js/components/dialog.js b/testgen/ui/static/js/components/dialog.js index be8f09ef..465c3faf 100644 --- a/testgen/ui/static/js/components/dialog.js +++ b/testgen/ui/static/js/components/dialog.js @@ -5,7 +5,6 @@ * @property {import('../van.min.js').State} open - Reactive open state * @property {Function} onClose - Called when the dialog is closed (backdrop click or X button) * @property {string} [width] - CSS width value, default '30rem' - * @property {string?} testId */ import van from '../van.min.js'; import { getValue, loadStylesheet } from '../utils.js'; @@ -29,22 +28,19 @@ const { button, div, i, span } = van.tags; * @param {DialogProps} props * @param {...(Element | string)} children - Content rendered in the dialog body */ -const Dialog = ({ title, open, onClose, width = '30rem', testId }, ...children) => { +const Dialog = ({ title, open, onClose, width = '30rem' }, ...children) => { loadStylesheet('dialog', stylesheet); - const testIdValue = getValue(testId) ?? ''; - const overlay = div( { class: 'tg-dialog-overlay', - 'data-testid': testIdValue ? `${testIdValue}-backdrop` : '', style: () => open.val ? '' : 'display: none', onclick: () => onClose(), }, div( { class: 'tg-dialog', - 'data-testid': testIdValue, + 'data-testid': 'dialog', role: 'dialog', 'aria-modal': 'true', tabindex: '-1', @@ -53,13 +49,13 @@ const Dialog = ({ title, open, onClose, width = '30rem', testId }, ...children) }, div( { class: 'tg-dialog-header' }, - span({ 'data-testid': testIdValue ? `${testIdValue}-title` : '', class: 'tg-dialog-title' }, title), + span({ 'data-testid': 'dialog-title', class: 'tg-dialog-title' }, title), ), - div({ class: 'tg-dialog-content' }, ...children), + div({ 'data-testid': 'dialog-content', class: 'tg-dialog-content' }, ...children), button( { class: 'tg-dialog-close', - 'data-testid': testIdValue ? 
`${testIdValue}-close` : '', + 'data-testid': 'dialog-close', 'aria-label': 'Close', onclick: () => onClose(), }, diff --git a/testgen/ui/static/js/components/dropdown_button.js b/testgen/ui/static/js/components/dropdown_button.js index e97fdce6..b3045466 100644 --- a/testgen/ui/static/js/components/dropdown_button.js +++ b/testgen/ui/static/js/components/dropdown_button.js @@ -45,10 +45,11 @@ const DropdownButton = (props) => { () => { const items = typeof props.items === 'function' ? props.items() : props.items; return div( - { class: 'tg-dropdown-button--menu' }, + { class: 'tg-dropdown-button--menu', 'data-testid': 'dropdown-menu' }, ...items.map(item => div({ class: 'tg-dropdown-button--item', + 'data-testid': 'dropdown-item', style: item.separator ? 'border-top: var(--button-stroked-border);' : '', onclick: () => { menuOpen.val = false; item.onclick(); }, }, item.label), diff --git a/testgen/ui/static/js/components/empty_state.js b/testgen/ui/static/js/components/empty_state.js index d5240c7a..f8497b67 100644 --- a/testgen/ui/static/js/components/empty_state.js +++ b/testgen/ui/static/js/components/empty_state.js @@ -17,7 +17,7 @@ * @property {Link?} link * @property {any?} button * @property {string?} class -* @property {string?} testId + */ import van from '../van.min.js'; import { Card } from '../components/card.js'; @@ -70,7 +70,6 @@ const EmptyState = (/** @type Properties */ props) => { loadStylesheet('empty-state', stylesheet); return Card({ - testId: getValue(props.testId), class: `tg-empty-state flex-column fx-align-flex-center ${getValue(props.class ?? 
'')}`, content: [ span({ class: 'tg-empty-state--title mb-5' }, props.label), diff --git a/testgen/ui/static/js/components/expansion_panel.js b/testgen/ui/static/js/components/expansion_panel.js index 2cd5dd21..f584b127 100644 --- a/testgen/ui/static/js/components/expansion_panel.js +++ b/testgen/ui/static/js/components/expansion_panel.js @@ -2,7 +2,6 @@ * @typedef Options * @type {object} * @property {string} title - * @property {string?} testId * @property {bool} expanded */ @@ -46,7 +45,7 @@ const ExpansionPanel = (options, ...children) => { }); return div( - { class: 'tg-expansion-panel', 'data-testid': options.testId ?? '' }, + { class: 'tg-expansion-panel', 'data-testid': 'expansion-panel' }, titleDiv, contentDiv, ); diff --git a/testgen/ui/static/js/components/file_input.js b/testgen/ui/static/js/components/file_input.js index 5845c319..3abb66d4 100644 --- a/testgen/ui/static/js/components/file_input.js +++ b/testgen/ui/static/js/components/file_input.js @@ -113,7 +113,7 @@ const FileInput = (options) => { }; return div( - { class: cssClass }, + { class: cssClass, 'data-testid': 'file-input' }, div( { class: 'tg-file-uploader--label text-caption flex-row fx-gap-1' }, options.label, diff --git a/testgen/ui/static/js/components/frequency_bars.js b/testgen/ui/static/js/components/frequency_bars.js index d26073ce..1a65ea38 100644 --- a/testgen/ui/static/js/components/frequency_bars.js +++ b/testgen/ui/static/js/components/frequency_bars.js @@ -36,6 +36,7 @@ const FrequencyBars = (/** @type Properties */ props) => { }); return () => div( + { 'data-testid': 'frequency-bars' }, div( { class: 'mb-2 text-secondary' }, props.title, diff --git a/testgen/ui/static/js/components/help_menu.js b/testgen/ui/static/js/components/help_menu.js index 2a2fd9cb..33d04f44 100644 --- a/testgen/ui/static/js/components/help_menu.js +++ b/testgen/ui/static/js/components/help_menu.js @@ -8,6 +8,7 @@ * @typedef Permissions * @type {object} * @property {boolean} can_edit + * @property 
{boolean} is_logged_in * * @typedef Properties * @type {object} @@ -71,6 +72,13 @@ const HelpMenu = (/** @type Properties */ props) => { ) : null, span({ class: 'help-divider' }), + getValue(props.permissions)?.is_logged_in + ? div( + { class: 'help-item', onclick: () => emit('FeedbackClicked') }, + Icon({ classes: 'help-item-icon' }, 'rate_review'), + 'Give Feedback', + ) + : null, HelpLink(slackUrl, 'Slack Community', 'group'), getValue(props.support_email) ? HelpLink( diff --git a/testgen/ui/static/js/components/input.js b/testgen/ui/static/js/components/input.js index 1efb0924..75663e35 100644 --- a/testgen/ui/static/js/components/input.js +++ b/testgen/ui/static/js/components/input.js @@ -30,7 +30,6 @@ * @property {string?} style * @property {string?} type * @property {string?} class - * @property {string?} testId * @property {any?} prefix * @property {number} step * @property {Array?} validators @@ -103,7 +102,7 @@ const Input = (/** @type Properties */ props) => { id: domId, class: () => `flex-column fx-gap-1 tg-input--label ${getValue(props.class) ?? ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': props.testId ?? props.name ?? '', + 'data-testid': 'input', }, div( { class: 'flex-row fx-gap-1 text-caption' }, @@ -137,6 +136,7 @@ const Input = (/** @type Properties */ props) => { name: props.name ?? '', type: inputType, disabled: props.disabled, + ...(inputType.val !== 'password' ? {'data-value': value} : {}), ...(inputType.val === 'number' ? {step: getValue(props.step)} : {}), ...(props.readonly ? {readonly: true} : {}), ...(props.passwordSuggestions ?? true ? 
{} : {autocomplete: 'off', 'data-op-ignore': true}), diff --git a/testgen/ui/static/js/components/line_chart.js b/testgen/ui/static/js/components/line_chart.js index fd16bd06..722588ce 100644 --- a/testgen/ui/static/js/components/line_chart.js +++ b/testgen/ui/static/js/components/line_chart.js @@ -223,7 +223,6 @@ const LineChart = ( tooltipExtraStyle.val = ''; showTooltip.val = false; }, - testId: lineId, }, line, ) diff --git a/testgen/ui/static/js/components/link.js b/testgen/ui/static/js/components/link.js index 630d6d76..c78e821c 100644 --- a/testgen/ui/static/js/components/link.js +++ b/testgen/ui/static/js/components/link.js @@ -18,7 +18,6 @@ * @property {string?} tooltipPosition * @property {boolean?} disabled * @property {((event: any) => void)?} onClick - * @property {string?} testId */ import { getValue, loadStylesheet } from '../utils.js'; import van from '../van.min.js'; @@ -38,7 +37,7 @@ const Link = (/** @type Properties */ props) => { return a( { - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'link', class: `tg-link ${getValue(props.underline) ? 'tg-link--underline' : ''} ${getValue(props.disabled) ? 'disabled' : ''} diff --git a/testgen/ui/static/js/components/monitoring_sparkline.js b/testgen/ui/static/js/components/monitoring_sparkline.js index f81251e9..4eb2c61a 100644 --- a/testgen/ui/static/js/components/monitoring_sparkline.js +++ b/testgen/ui/static/js/components/monitoring_sparkline.js @@ -24,6 +24,7 @@ * @property {boolean?} isPending * @property {number?} lowerTolerance * @property {number?} upperTolerance + * @property {number?} originalThreshold * * @typedef PredictionPoint * @type {Object} @@ -253,6 +254,9 @@ const MonitoringSparklineChartTooltip = (point) => { {class: 'flex-column'}, span({class: 'text-left mb-1'}, formatTimestamp(point.originalX)), span({class: 'text-left text-small'}, `${point.label || 'Value'}: ${formatNumber(point.originalY)}`), + point.originalThreshold != undefined + ? 
span({class: 'text-left text-small'}, `Baseline: ${formatNumber(point.originalThreshold)}`) + : '', point.lowerTolerance != undefined ? span({class: 'text-left text-small'}, `Lower bound: ${formatNumber(point.originalLowerTolerance)}`) : '', diff --git a/testgen/ui/static/js/components/notification_settings.js b/testgen/ui/static/js/components/notification_settings.js index b3cf9bce..d644f467 100644 --- a/testgen/ui/static/js/components/notification_settings.js +++ b/testgen/ui/static/js/components/notification_settings.js @@ -231,7 +231,6 @@ const NotificationSettings = (/** @type Properties */ props) => { title: () => newNotificationItemForm.isEdit.val ? span({ class: 'notifications--editing' }, 'Edit Notification') : 'Add Notification', - testId: 'notification-item-editor', expanded: panelExpanded, }, div( diff --git a/testgen/ui/static/js/components/paginator.js b/testgen/ui/static/js/components/paginator.js index cd4a4be8..663c5add 100644 --- a/testgen/ui/static/js/components/paginator.js +++ b/testgen/ui/static/js/components/paginator.js @@ -5,7 +5,6 @@ * @property {number} pageSize * @property {number?} pageIndex * @property {function(number)?} onChange - * @property {string?} testId */ import van from '../van.min.js'; @@ -18,7 +17,6 @@ const Paginator = (/** @type Properties */ props) => { loadStylesheet('paginator', stylesheet); const { count, pageSize } = props; - const testId = getValue(props.testId) ?? ''; const pageIndexState = van.derive(() => getValue(props.pageIndex) ?? 0); van.derive(() => { @@ -27,9 +25,9 @@ const Paginator = (/** @type Properties */ props) => { }); return div( - { class: 'tg-paginator', 'data-testid': testId }, + { class: 'tg-paginator', 'data-testid': 'paginator' }, span( - { class: 'tg-paginator--label', 'data-testid': testId ? 
`${testId}-info` : '' }, + { class: 'tg-paginator--label', 'data-testid': 'paginator-info' }, () => { const pageIndex = pageIndexState.val; const countValue = getValue(count); @@ -40,7 +38,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-first` : '', + 'aria-label': 'First page', onclick: () => pageIndexState.val = 0, disabled: () => pageIndexState.val === 0, }, @@ -49,7 +47,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-prev` : '', + 'aria-label': 'Previous page', onclick: () => pageIndexState.val--, disabled: () => pageIndexState.val === 0, }, @@ -58,7 +56,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-next` : '', + 'aria-label': 'Next page', onclick: () => pageIndexState.val++, disabled: () => pageIndexState.val === Math.ceil(getValue(count) / getValue(pageSize)) - 1, }, @@ -67,7 +65,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-last` : '', + 'aria-label': 'Last page', onclick: () => pageIndexState.val = Math.ceil(getValue(count) / getValue(pageSize)) - 1, disabled: () => pageIndexState.val === Math.ceil(getValue(count) / getValue(pageSize)) - 1, }, diff --git a/testgen/ui/static/js/components/portal.js b/testgen/ui/static/js/components/portal.js index aca28080..fa7a88c2 100644 --- a/testgen/ui/static/js/components/portal.js +++ b/testgen/ui/static/js/components/portal.js @@ -97,6 +97,7 @@ const Portal = (/** @type Options */ options, ...args) => { return div( { id, + 'data-testid': 'portal', class: getValue(options.class) ?? '', style: `position: ${fixed ? 'fixed' : 'absolute'}; z-index: ${zIndex}; ${coords} ${getValue(options.style) ?? 
''}`, }, diff --git a/testgen/ui/static/js/components/radio_group.js b/testgen/ui/static/js/components/radio_group.js index 97aef2df..1d4b29c2 100644 --- a/testgen/ui/static/js/components/radio_group.js +++ b/testgen/ui/static/js/components/radio_group.js @@ -31,7 +31,7 @@ const RadioGroup = (/** @type Properties */ props) => { const disabled = getValue(props.disabled) ?? false; return div( - { class: () => `tg-radio-group--wrapper ${layout}${disabled ? ' disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}` }, + { class: () => `tg-radio-group--wrapper ${layout}${disabled ? ' disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, 'data-testid': 'radio-group' }, div( { class: 'text-caption tg-radio-group--label flex-row fx-gap-1' }, props.label, diff --git a/testgen/ui/static/js/components/schedule_list.js b/testgen/ui/static/js/components/schedule_list.js index ccd75e63..a81d590b 100644 --- a/testgen/ui/static/js/components/schedule_list.js +++ b/testgen/ui/static/js/components/schedule_list.js @@ -78,7 +78,7 @@ const ScheduleList = (/** @type Properties */ props) => { const content = div( { id: domId, class: 'flex-column fx-gap-2', style: 'height: 100%; overflow-y: auto;' }, ExpansionPanel( - {title: span({ class: 'text-green' }, 'Add Schedule'), testId: 'scheduler-cron-editor'}, + {title: span({ class: 'text-green' }, 'Add Schedule')}, div( { class: 'flex-row fx-gap-2' }, () => Select({ diff --git a/testgen/ui/static/js/components/score_breakdown.js b/testgen/ui/static/js/components/score_breakdown.js index 717c4d36..0b843405 100644 --- a/testgen/ui/static/js/components/score_breakdown.js +++ b/testgen/ui/static/js/components/score_breakdown.js @@ -33,7 +33,6 @@ const ScoreBreakdown = (score, breakdown, category, scoreType, onViewDetails, em .map(([value, label]) => ({ value, label })), height: 32, onChange: (value) => emit('CategoryChanged', { payload: value }), - testId: 
'groupby-selector', }); }, span('for'), @@ -50,7 +49,6 @@ const ScoreBreakdown = (score, breakdown, category, scoreType, onViewDetails, em options: scoreTypeOptions.map((s) => ({ label: SCORE_TYPE_LABEL[s], value: s })), height: 32, onChange: (value) => emit('ScoreTypeChanged', { payload: value }), - testId: 'score-type-selector', }); }, ), diff --git a/testgen/ui/static/js/components/score_card.js b/testgen/ui/static/js/components/score_card.js index 130bc470..c76bced4 100644 --- a/testgen/ui/static/js/components/score_card.js +++ b/testgen/ui/static/js/components/score_card.js @@ -90,7 +90,7 @@ const ScoreCard = (score, actions, options) => { : '', (score_.cde_score && categories.length > 0) ? i({ class: 'mr-4 ml-4' }) : '', categories.length > 0 ? div( - { class: 'flex-column' }, + { class: 'flex-column tg-score-card--breakdown' }, span({ class: 'mb-2 text-caption' }, categoriesLabel), div( { class: 'tg-score-card--categories' }, @@ -164,13 +164,17 @@ stylesheet.replace(` margin-bottom: unset !important; } +.tg-score-card--breakdown { + margin-top: -12px; +} + .tg-score-card--categories { display: flex; flex-direction: column; flex-wrap: wrap; - row-gap: 8px; + row-gap: 4px; column-gap: 16px; - max-height: 100px; + max-height: 140px; overflow-y: auto; } .tg-score-card--categories > div { diff --git a/testgen/ui/static/js/components/score_legend.js b/testgen/ui/static/js/components/score_legend.js index e5b53281..13265d6d 100644 --- a/testgen/ui/static/js/components/score_legend.js +++ b/testgen/ui/static/js/components/score_legend.js @@ -6,7 +6,7 @@ const { div, span } = van.tags; const ScoreLegend = (/** @type string */ style) => { return div( - { class: 'flex-row fx-gap-3 text-secondary', style }, + { 'data-testid': 'score-legend', class: 'flex-row fx-gap-3 text-secondary', style }, span({ class: 'fx-flex' }), LegendItem('N/A', NaN), LegendItem('0-85', 0), diff --git a/testgen/ui/static/js/components/select.js b/testgen/ui/static/js/components/select.js index 
70bc57bd..5a7776cb 100644 --- a/testgen/ui/static/js/components/select.js +++ b/testgen/ui/static/js/components/select.js @@ -21,7 +21,6 @@ * @property {number?} width * @property {number?} height * @property {string?} style - * @property {string?} testId * @property {number?} portalClass * @property {('top' | 'bottom')?} portalPosition * @property {boolean?} filterable @@ -189,7 +188,7 @@ const Select = (/** @type {Properties} */ props) => { id: domId, class: () => `flex-column fx-gap-1 text-caption tg-select--label ${getValue(props.disabled) ? 'disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'select', onclick: (/** @type Event */ event) => { event.stopPropagation(); event.stopImmediatePropagation(); @@ -307,7 +306,7 @@ const MultiSelect = (props) => { id: domId, class: () => `flex-column fx-gap-1 text-caption tg-select--label ${getValue(props.disabled) ? 'disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': getValue(props.testId) ?? 
'', + 'data-testid': 'select', onclick: (/** @type Event */ event) => { event.stopPropagation(); event.stopImmediatePropagation(); diff --git a/testgen/ui/static/js/components/slider.js b/testgen/ui/static/js/components/slider.js index 2582fc8b..e59fc6f4 100644 --- a/testgen/ui/static/js/components/slider.js +++ b/testgen/ui/static/js/components/slider.js @@ -25,7 +25,7 @@ const Slider = (/** @type Properties */ props) => { }; return label( - { class: 'flex-col fx-gap-1 clickable tg-slider--label text-caption' }, + { class: 'flex-col fx-gap-1 clickable tg-slider--label text-caption', 'data-testid': 'slider' }, props.label, input({ type: "range", diff --git a/testgen/ui/static/js/components/spark_line.js b/testgen/ui/static/js/components/spark_line.js index 89985808..ee0f9a4f 100644 --- a/testgen/ui/static/js/components/spark_line.js +++ b/testgen/ui/static/js/components/spark_line.js @@ -8,7 +8,7 @@ * @property {boolean?} interactive * @property {Function?} onPointMouseEnter * @property {Function?} onPointMouseLeave - * @property {string?} testId + * * @typedef Point * @type {object} @@ -35,7 +35,7 @@ const SparkLine = ( ) => { const display = van.derive(() => getValue(options.hidden) === true ? 'none' : ''); return g( - { fill: 'none', opacity: options.opacity ?? 1, style: 'overflow: visible;', 'data-testid': options.testId, display }, + { fill: 'none', opacity: options.opacity ?? 1, style: 'overflow: visible;', 'data-testid': 'sparkline', display }, polyline({ points: line.map(point => `${point.x} ${point.y}`).join(', '), style: `stroke: ${options.color}; stroke-width: ${options.stroke ?? 
1};`, diff --git a/testgen/ui/static/js/components/summary_bar.js b/testgen/ui/static/js/components/summary_bar.js index c16dcc61..c133453f 100644 --- a/testgen/ui/static/js/components/summary_bar.js +++ b/testgen/ui/static/js/components/summary_bar.js @@ -25,6 +25,7 @@ const SummaryBar = (/** @type Properties */ props) => { const total = van.derive(() => getValue(props.items).reduce((sum, item) => sum + item.value, 0)); return div( + { 'data-testid': 'summary-bar' }, () => props.label ? div( { class: 'tg-summary-bar--label' }, props.label, diff --git a/testgen/ui/static/js/components/summary_counts.js b/testgen/ui/static/js/components/summary_counts.js index 46f5533a..d307f58e 100644 --- a/testgen/ui/static/js/components/summary_counts.js +++ b/testgen/ui/static/js/components/summary_counts.js @@ -19,13 +19,13 @@ const SummaryCounts = (/** @type Properties */ props) => { loadStylesheet('summaryCounts', stylesheet); return div( - { class: 'flex-row fx-gap-5 fx-flex-wrap' }, + { class: 'flex-row fx-gap-5 fx-flex-wrap', 'data-testid': 'summary-counts' }, getValue(props.items).map(item => div( - { class: 'flex-row fx-align-stretch fx-gap-2' }, + { 'data-testid': 'summary-count', class: 'flex-row fx-align-stretch fx-gap-2' }, div({ class: 'tg-summary-counts--bar', style: `background-color: ${colorMap[item.color] || item.color};` }), div( - div({ class: 'text-caption' }, item.label), - div({ class: 'tg-summary-counts--count' }, formatNumber(item.value)), + div({ 'data-testid': 'summary-count-label', class: 'text-caption' }, item.label), + div({ 'data-testid': 'summary-count-value', class: 'tg-summary-counts--count' }, formatNumber(item.value)), ) )), ); diff --git a/testgen/ui/static/js/components/table.js b/testgen/ui/static/js/components/table.js index 8dbe9712..ae365fba 100644 --- a/testgen/ui/static/js/components/table.js +++ b/testgen/ui/static/js/components/table.js @@ -181,6 +181,7 @@ const Table = (options, rows) => { return div( { + 'data-testid': 'table', 
class: () => `tg-table flex-column border border-radius-1 ${getValue(options.highDensity) ? 'tg-table-high-density' : ''} ${getValue(options.dynamicWidth) ? 'tg-table-dynamic-width' : ''} ${(getValue(options.uppercaseHeader) ?? true) ? 'tg-table-uppercase-header' : ''} ${options.selection?.onRowsSelected ? 'tg-table-hoverable' : ''}`, style: () => `height: ${getValue(options.height) ? getValue(options.height) : defaultHeight}; ${getValue(options.maxHeight) ? 'max-height: ' + getValue(options.maxHeight) + ';' : ''}`, }, @@ -227,7 +228,7 @@ const Table = (options, rows) => { const rows_ = getValue(rows); if (rows_.length <= 0 && options.emptyState) { return tbody( - {class: 'tg-table-empty-state-body'}, + {'data-testid': 'table-empty', class: 'tg-table-empty-state-body'}, tr( td( {colspan: dataColumns.val.length}, @@ -418,7 +419,7 @@ const Paginatior = ( const sizeOptions = (pageSizeOptions ?? defaultPageSizeOptions).map(n => ({ label: String(n), value: n })); return div( - {class: `tg-table-paginator flex-row fx-justify-content-flex-end ${highDensity ? '' : 'p-1'} text-secondary`}, + {'data-testid': 'table-paginator', class: `tg-table-paginator flex-row fx-justify-content-flex-end ${highDensity ? '' : 'p-1'} text-secondary`}, leftContent, leftContent != undefined ? 
span({class: 'fx-flex'}) : '', @@ -426,7 +427,6 @@ const Paginatior = ( span({class: 'mr-2'}, 'Rows per page:'), Select({ triggerStyle: 'inline', - testId: 'items-per-page', value: itemsPerPage, options: sizeOptions, portalPosition: 'top', diff --git a/testgen/ui/static/js/components/table_group_form.js b/testgen/ui/static/js/components/table_group_form.js index 8fc96dc2..93becaa9 100644 --- a/testgen/ui/static/js/components/table_group_form.js +++ b/testgen/ui/static/js/components/table_group_form.js @@ -1,6 +1,6 @@ /** * @import { Connection } from './connection_form.js'; - * + * * @typedef TableGroup * @type {object} * @property {string?} id @@ -29,12 +29,12 @@ * @property {string?} stakeholder_group * @property {string?} transform_level * @property {string?} data_product - * + * * @typedef FormState * @type {object} * @property {boolean} dirty * @property {boolean} valid - * + * * @typedef Properties * @type {object} * @property {TableGroup} tableGroup @@ -42,6 +42,7 @@ * @property {boolean?} showConnectionSelector * @property {boolean?} disableConnectionSelector * @property {boolean?} disableSchemaField + * @property {string?} sqlFlavor * @property {boolean?} disablePiiFlag * @property {(tg: TableGroup, state: FormState) => void} onChange */ @@ -65,9 +66,9 @@ const normalizeTableSet = (value) => { } /** - * - * @param {Properties} props - * @returns + * + * @param {Properties} props + * @returns */ const TableGroupForm = (props) => { loadStylesheet('table-group-form', stylesheet); @@ -111,6 +112,13 @@ const TableGroupForm = (props) => { const showConnectionSelector = getValue(props.showConnectionSelector) ?? false; const disableSchemaField = van.derive(() => getValue(props.disableSchemaField) ?? false) + const isSalesforce = van.derive(() => { + const connections = getValue(props.connections) ?? []; + const selected = connections.find(c => c.connection_id === tableGroupConnectionId.val); + const flavor = selected?.sql_flavor ?? 
getValue(props.sqlFlavor); + return flavor === 'salesforce_data360'; + }); + const updatedTableGroup = van.derive(() => { return { id: tableGroup.id, @@ -176,7 +184,7 @@ const TableGroupForm = (props) => { }) : undefined, MainForm( - { disableSchemaField, setValidity: setFieldValidity }, + { disableSchemaField, isSalesforce, setValidity: setFieldValidity }, tableGroupsName, tableGroupSchema, ), @@ -238,12 +246,14 @@ const MainForm = ( }, validators: [ required ], }), - Input({ + () => Input({ name: 'table_group_schema', - label: 'Schema', + label: getValue(options.isSalesforce) ? 'Data Space' : 'Schema', value: tableGroupSchema, class: 'tg-column-flex', - help: 'Database schema containing the tables for the Table Group', + help: getValue(options.isSalesforce) + ? 'Salesforce data space containing the tables for the Table Group' + : 'Database schema containing the tables for the Table Group', helpPlacement: 'bottom-left', disabled: options.disableSchemaField, onChange: (value, state) => { @@ -340,7 +350,7 @@ const SettingsForm = ( ) => { return div( { class: 'flex-row fx-gap-3 fx-flex-wrap fx-align-flex-start border border-radius-1 p-3 mt-1', style: 'position: relative;' }, - Caption({content: 'Settings', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), + Caption({content: 'Settings', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), div( { class: 'tg-column-flex flex-column fx-gap-3' }, Checkbox({ @@ -400,7 +410,7 @@ const SamplingForm = ( profileSampleMinCount, ) => { return ExpansionPanel( - { title: 'Sampling Parameters', testId: 'sampling-panel' }, + { title: 'Sampling Parameters' }, div( { class: 'flex-column fx-gap-3' }, Checkbox({ @@ -454,7 +464,7 @@ const TaggingForm = ( dataProduct, ) => { return ExpansionPanel( - { title: 'Table Group Tags', testId: 'tags-panel' }, + { title: 'Table Group Tags' }, Input({ name: 'description', class: 'fx-flex mb-3', 
diff --git a/testgen/ui/static/js/components/table_group_wizard.js b/testgen/ui/static/js/components/table_group_wizard.js index 05873061..03610738 100644 --- a/testgen/ui/static/js/components/table_group_wizard.js +++ b/testgen/ui/static/js/components/table_group_wizard.js @@ -461,7 +461,7 @@ const TableGroupWizard = (props) => { return div( { class: 'flex-column' }, div( - { class: 'flex-column fx-gap-4 mb-4 p-5 border border-radius-2' }, + { class: 'flex-column fx-gap-4 mb-4 p-5 border border-radius-2', 'data-testid': 'wizard-success-panel' }, div( { class: 'flex-row fx-gap-2' }, Icon({ style: 'color: var(--green);' }, 'check_circle'), diff --git a/testgen/ui/static/js/components/tabs.js b/testgen/ui/static/js/components/tabs.js index 23d315d6..d6cbc51c 100644 --- a/testgen/ui/static/js/components/tabs.js +++ b/testgen/ui/static/js/components/tabs.js @@ -19,7 +19,6 @@ const Tab = ({ label }, ...children) => ({ /** * @typedef {Object} TabsProps - * @property {string?} testId * @property {string?} class * * @param {TabsProps} props @@ -28,8 +27,7 @@ const Tab = ({ label }, ...children) => ({ const Tabs = (props, ...tabs) => { loadStylesheet('tabs', stylesheet); - const { testId: testIdProp, ...restProps } = props; - const testId = getValue(testIdProp) ?? ''; + const { ...restProps } = props; const activeTab = van.state(0); @@ -52,7 +50,7 @@ const Tabs = (props, ...tabs) => { ...tabs.map((tab, i) => button({ class: () => `tg-tabs--tab--label ${i === activeTab.val ? 'active' : ''}`, - 'data-testid': testId ? `${testId}-tab-${i}` : '', + 'data-testid': 'tab', onclick: () => (activeTab.val = i), }, tab.label @@ -60,9 +58,9 @@ const Tabs = (props, ...tabs) => { highlightEl, ); - const tabsContainerEl = div({ ...restProps, 'data-testid': testId, class: () => `${getValue(restProps.class) ?? ''} tg-tabs--container` }, + const tabsContainerEl = div({ ...restProps, 'data-testid': 'tabs', class: () => `${getValue(restProps.class) ?? 
''} tg-tabs--container` }, labelsContainerEl, - div({ class: "tg-tabs--content", 'data-testid': testId ? `${testId}-panel` : '' }, () => div({class: "tg-tabs--content-inner"}, tabs[activeTab.val].children)), + div({ class: "tg-tabs--content", 'data-testid': 'tab-panel' }, () => div({class: "tg-tabs--content-inner"}, tabs[activeTab.val].children)), ); van.derive(() => { diff --git a/testgen/ui/static/js/components/test_definition_form.js b/testgen/ui/static/js/components/test_definition_form.js index 18b173dc..e7fa14d6 100644 --- a/testgen/ui/static/js/components/test_definition_form.js +++ b/testgen/ui/static/js/components/test_definition_form.js @@ -60,6 +60,7 @@ * @type {object} * @property {TestDefinition} definition * @property {string?} class + * @property {boolean} qualifiesTableRefsWithSchema * @property {(changes: object, valid: boolean) => void} onChange */ @@ -83,7 +84,7 @@ const thresholdColumns = [ ]; // Columns using the default { type: 'text' } do not need to be specified here -const PARAMETER_CONFIG = { +const PARAMETER_CONFIG = { custom_query: { type: 'textarea' }, lower_tolerance: { type: 'number' }, upper_tolerance: { type: 'number' }, @@ -94,6 +95,7 @@ const TestDefinitionForm = (/** @type Properties */ props) => { loadStylesheet('test-definition-form', stylesheet); const definition = getValue(props.definition); + const qualifiesTableRefsWithSchema = getValue(props.qualifiesTableRefsWithSchema) ?? true; const paramColumns = (definition.default_parm_columns || '').split(',').map(v => v.trim()); const paramLabels = (definition.default_parm_prompts || '').split(',').map(v => v.trim()); @@ -110,6 +112,8 @@ const TestDefinitionForm = (/** @type Properties */ props) => { validators: paramRequired[index] ? 
[required] : undefined, })) .filter(config => !hasThresholds || !thresholdColumns.includes(config.column)) + // Drop the field for flavors whose SQL doesn't qualify table refs with a schema + .filter(config => qualifiesTableRefsWithSchema || config.column !== 'match_schema_name') const updatedDefinition = van.state({ ...definition }); const validityPerField = van.state({}); @@ -258,9 +262,9 @@ const historyCalcOptions = [ * @type {object} * @property {(updatedValues: object) => void} setFieldValues * @property {(field: string, valid: boolean) => void} setFieldValidity - * - * @param {ThresholdFormOptions} options - * @param {TestDefinition} definition + * + * @param {ThresholdFormOptions} options + * @param {TestDefinition} definition */ const ThresholdForm = (options, definition) => { const { setFieldValues, setFieldValidity } = options; diff --git a/testgen/ui/static/js/components/textarea.js b/testgen/ui/static/js/components/textarea.js index bdfc411a..5a004a7b 100644 --- a/testgen/ui/static/js/components/textarea.js +++ b/testgen/ui/static/js/components/textarea.js @@ -22,7 +22,6 @@ * @property {string?} class * @property {number?} width * @property {number?} height - * @property {string?} testId * @property {Array?} validators */ import van from '../van.min.js'; @@ -68,7 +67,7 @@ const Textarea = (/** @type Properties */ props) => { id: domId, class: () => `flex-column fx-gap-1 ${getValue(props.class) ?? ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': props.testId ?? props.name ?? '', + 'data-testid': 'textarea', }, div( { class: 'flex-row fx-gap-1 text-caption' }, @@ -87,6 +86,7 @@ const Textarea = (/** @type Properties */ props) => { class: () => `tg-textarea--field ${getValue(props.disabled) ? 'tg-textarea--disabled' : ''}`, style: () => `min-height: ${getValue(props.height) || defaultHeight}px;`, value, + 'data-value': value, name: props.name ?? 
'', disabled: props.disabled, placeholder: () => getValue(props.placeholder) ?? '', diff --git a/testgen/ui/static/js/components/toggle.js b/testgen/ui/static/js/components/toggle.js index eb723c38..4911d953 100644 --- a/testgen/ui/static/js/components/toggle.js +++ b/testgen/ui/static/js/components/toggle.js @@ -19,7 +19,7 @@ const Toggle = (/** @type Properties */ props) => { const disabled = props.disabled?.val ?? props.disabled ?? false; return label( - { class: `flex-row fx-gap-2 ${disabled ? '' : 'clickable'}`, style: props.style ?? '', 'data-testid': props.name ?? '' }, + { class: `flex-row fx-gap-2 ${disabled ? '' : 'clickable'}`, style: props.style ?? '', 'data-testid': 'toggle' }, input({ type: 'checkbox', role: 'switch', diff --git a/testgen/ui/static/js/components/tree.js b/testgen/ui/static/js/components/tree.js index b9902269..2e8dcd84 100644 --- a/testgen/ui/static/js/components/tree.js +++ b/testgen/ui/static/js/components/tree.js @@ -90,6 +90,7 @@ const Tree = (/** @type Properties */ props, /** @type any? 
*/ searchOptionsCont { id: props.id, class: () => `flex-column ${getValue(props.classes)}`, + 'data-testid': 'tree', }, Toolbar(treeNodes, multiSelect, props, searchOptionsContent, filtersContent, emit), div( diff --git a/testgen/ui/utils.py b/testgen/ui/utils.py index fe097761..270de953 100644 --- a/testgen/ui/utils.py +++ b/testgen/ui/utils.py @@ -1,20 +1,10 @@ -import zoneinfo from collections.abc import Callable -from datetime import datetime from typing import TypedDict -import cron_converter -import cron_descriptor - +from testgen.common.cron_service import get_cron_sample from testgen.ui.session import temp_value -class CronSample(TypedDict): - id: str | None - error: str | None - samples: list[str] | list[int] | None - readable_expr: str | None - class CronSampleHandlerPayload(TypedDict): tz: str cron_expr: str @@ -23,33 +13,6 @@ class CronSampleHandlerPayload(TypedDict): CronSampleCallback = Callable[[CronSampleHandlerPayload], None] -def get_cron_sample( - cron_expr: str, - cron_tz: str, - sample_count: int, - *, - reference_time: datetime | None = None, - formatted: bool = False, -) -> CronSample: - try: - cron_obj = cron_converter.Cron(cron_expr) - cron_schedule = cron_obj.schedule(reference_time or datetime.now(zoneinfo.ZoneInfo(cron_tz))) - readble_cron_schedule = cron_descriptor.get_description(cron_expr) - if formatted: - samples = [cron_schedule.next().strftime("%a %b %-d, %-I:%M %p") for _ in range(sample_count)] - else: - samples = [int(cron_schedule.next().timestamp()) for _ in range(sample_count)] - except ValueError as e: - return {"error": str(e)} - except Exception as e: - return {"error": "Error validating the Cron expression"} - else: - return { - "samples": samples, - "readable_expr": readble_cron_schedule, - } - - def get_cron_sample_handler(key: str, *, sample_count: int = 3) -> tuple[dict | None, CronSampleCallback]: cron_sample_result, set_cron_sample = temp_value(key, default={}) diff --git a/testgen/ui/views/connections.py 
b/testgen/ui/views/connections.py index 35fa58cf..619e10ed 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -19,6 +19,7 @@ from testgen import settings from testgen.common.database.database_service import empty_cache, get_flavor_service from testgen.common.database.flavor.flavor_service import resolve_connection_params +from testgen.common.enums import JobSource from testgen.common.models import get_current_session, with_database_session from testgen.common.models.connection import Connection, ConnectionMinimal from testgen.common.models.job_execution import JobExecution @@ -29,6 +30,11 @@ from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page +from testgen.ui.services.query_cache import ( + get_connection, + select_connections_where, + select_table_groups_minimal_where, +) from testgen.ui.session import session, temp_value from testgen.ui.utils import get_cron_sample_handler @@ -70,14 +76,14 @@ def render(self, project_code: str, **_kwargs) -> None: "connect-your-database/manage-connections/", ) - connections = Connection.select_where(Connection.project_code == project_code) + connections = select_connections_where(Connection.project_code == project_code) connection: Connection = connections[0] if len(connections) > 0 else Connection( sql_flavor="postgresql", sql_flavor_code="postgresql", project_code=project_code, ) has_table_groups = ( - connection.id and len(TableGroup.select_minimal_where(TableGroup.connection_id == connection.connection_id) or []) > 0 + connection.id and len(select_table_groups_minimal_where(TableGroup.connection_id == connection.connection_id) or []) > 0 ) user_is_admin = session.auth.user_has_permission("administer") @@ -185,8 +191,8 @@ def on_setup_table_group_clicked(*_args) -> None: success = True try: connection.save() - Connection.select_where.clear() - Connection.get.clear() + select_connections_where.clear() 
+ get_connection.clear() message = "Changes have been saved successfully." except Exception as error: message = "Something went wrong while creating the connection." @@ -426,7 +432,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_TESTS_JOB_KEY, cron_expr=standard_test_suite_data["schedule"], cron_tz=standard_test_suite_data["timezone"], - args=[], kwargs={"test_suite_id": str(standard_test_suite.id)}, ).save() @@ -458,7 +463,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_MONITORS_JOB_KEY, cron_expr=monitor_test_suite_data.get("schedule"), cron_tz=monitor_test_suite_data.get("timezone"), - args=[], kwargs={"test_suite_id": str(monitor_test_suite.id)}, ).save() @@ -473,7 +477,7 @@ def on_close_clicked(_params: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group.id)}, - source="ui", + source=JobSource.ui, project_code=table_group.project_code, ) message = f"Profiling run started for table group {table_group.table_groups_name}." 
@@ -639,6 +643,12 @@ class ConnectionFlavor: flavor="sap_hana", icon=get_asset_data_url("flavors/sap_hana.svg"), ), + ConnectionFlavor( + label="Salesforce Data 360", + value="salesforce_data360", + flavor="salesforce_data360", + icon=get_asset_data_url("flavors/salesforce_data360.svg"), + ), ConnectionFlavor( label="Snowflake", value="snowflake", diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 80ebb16b..d316f237 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -11,6 +11,7 @@ from streamlit.delta_generator import DeltaGenerator from testgen.common.database.database_service import get_flavor_service +from testgen.common.enums import JobSource from testgen.common.models import database_session, with_database_session from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution @@ -23,6 +24,7 @@ mask_profiling_pii, mask_source_data_pii, ) +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -45,7 +47,14 @@ get_tables_by_table_group, ) from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db, fetch_from_target_db -from testgen.ui.services.query_cache import get_profiling_run_summaries, get_project_summary, get_table_group_stats +from testgen.ui.services.query_cache import ( + get_profiling_run_summaries, + get_project_summary, + get_table_group, + get_table_group_stats, + select_profiling_runs_minimal_where, + select_table_groups_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.import_metadata_dialog import ( apply_metadata_import, @@ -98,7 +107,7 @@ def render(self, project_code: str, table_group_id: str | None = None, selected: project_summary = get_project_summary(project_code) user_can_navigate = 
session.auth.user_has_permission("view") - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) if not table_group_id or table_group_id not in [ str(item.id) for item in table_groups ]: table_group_id = str(table_groups[0].id) if table_groups else None @@ -115,7 +124,7 @@ def render(self, project_code: str, table_group_id: str | None = None, selected: selected_item["connection_id"] = str(selected_table_group.connection_id) else: on_item_selected(None) - + def on_run_profiling_clicked(_) -> None: if selected_table_group: st.session_state[DC_RUN_PROFILING_DIALOG_KEY] = str(selected_table_group.id) @@ -143,7 +152,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -201,7 +210,7 @@ def on_import_confirmed(_) -> None: try: apply_metadata_import(preview, tg_id) from testgen.ui.queries.profiling_queries import get_column_by_id, get_table_by_id - for func in [get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, TableGroup.select_minimal_where]: + for func in [get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, select_table_groups_minimal_where]: func.clear() st.session_state["data_catalog:last_saved_timestamp"] = datetime.now().timestamp() parts = [] @@ -464,7 +473,7 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: Ta include_tags=True, include_active_tests=True, ) - + data = pd.DataFrame(table_data + column_data) @@ -499,13 +508,12 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: Ta axis=1, ) data["top_freq_values"] = data["top_freq_values"].apply( - lambda val: "\n".join([f"{part.split(' | ')[1]} | {part.split(' | ')[0]}" for part in 
val[2:].split("\n| ")]) + lambda val: "\n".join(f"{count} | {value}" for value, count in parse_top_freq_values(val)) if not pd.isna(val) and val != PII_REDACTED else val ) - nl = "\n" # For Python 3.11 compatibility data["top_patterns"] = data["top_patterns"].apply( - lambda val: "".join([f"{part}{nl if index % 2 else ' | '}" for index, part in enumerate(val.split(" | "))]) + lambda val: "\n".join(f"{count} | {pattern}" for pattern, count in parse_top_patterns(val)) if not pd.isna(val) and val != PII_REDACTED else val ) @@ -663,7 +671,7 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA if disable_flags: table_group_id = st.query_params.get("table_group_id") if table_group_id: - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) changed = False if "profile_flag_cdes" in disable_flags and table_group.profile_flag_cdes: table_group.profile_flag_cdes = False @@ -674,7 +682,7 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA if changed: table_group.save() - for func in [ get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, TableGroup.select_minimal_where ]: + for func in [ get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, select_table_groups_minimal_where ]: func.clear() st.session_state["data_catalog:last_saved_timestamp"] = datetime.now().timestamp() @@ -683,7 +691,7 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA def get_table_group_columns(table_group_id: str) -> list[dict]: if not is_uuid4(table_group_id): return [] - + query = f""" SELECT CONCAT('column_', column_chars.column_id) AS column_id, CONCAT('table_', table_chars.table_id) AS table_id, @@ -773,7 +781,7 @@ def get_latest_test_issues(table_group_id: str, table_name: str, column_name: st test_results.test_type = test_types.test_type ) WHERE test_suites.table_groups_id = :table_group_id - AND test_suites.is_monitor = 
false + AND test_suites.is_monitor IS NOT TRUE AND table_name = :table_name {"AND column_names = :column_name" if column_name else ""} AND result_status NOT IN ('Passed', 'Log') @@ -808,7 +816,7 @@ def get_related_test_suites(table_group_id: str, table_name: str, column_name: s test_definitions.test_suite_id = test_suites.id ) WHERE test_suites.table_groups_id = :table_group_id - AND test_suites.is_monitor = false + AND test_suites.is_monitor IS NOT TRUE AND table_name = :table_name {"AND column_name = :column_name" if column_name else ""} GROUP BY test_suites.id @@ -831,7 +839,7 @@ def _build_history_dialog_data( column_name: str, add_date: int, ) -> dict | None: - profiling_runs = ProfilingRun.select_minimal_where( + profiling_runs = select_profiling_runs_minimal_where( ProfilingRun.table_groups_id == table_group_id, ProfilingRun.profiling_starttime >= sa_func.to_timestamp(add_date), ) @@ -899,15 +907,15 @@ def get_preview_data( return {"title": title, "status": "ERR", "message": "Connection not found."} flavor_service = get_flavor_service(connection.sql_flavor) - row_limiting = flavor_service.row_limiting_clause + prefix, suffix = flavor_service.row_limit_clauses(100) quote = flavor_service.quote_character + table_ref = flavor_service.get_table_ref(schema_name, table_name) query = f""" SELECT DISTINCT - {"TOP 100" if row_limiting == "top" else ""} + {prefix} {f"{quote}{column_name}{quote}" if column_name else "*"} - FROM {quote}{schema_name}{quote}.{quote}{table_name}{quote} - {"LIMIT 100" if row_limiting == "limit" else ""} - {"FETCH FIRST 100 ROWS ONLY" if row_limiting == "fetch" else ""} + FROM {table_ref} + {suffix} """ try: diff --git a/testgen/ui/views/dialogs/generate_tests_dialog.py b/testgen/ui/views/dialogs/generate_tests_dialog.py index e894dcb8..cf4cc715 100644 --- a/testgen/ui/views/dialogs/generate_tests_dialog.py +++ b/testgen/ui/views/dialogs/generate_tests_dialog.py @@ -1,5 +1,6 @@ import streamlit as st +from testgen.common.enums import 
JobSource from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.test_suite import TestSuiteMinimal @@ -42,7 +43,7 @@ def on_generate_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-test-generation", kwargs={"test_suite_id": str(test_suite_id), "generation_set": selected_set}, - source="ui", + source=JobSource.ui, project_code=test_suite.project_code, ) st.session_state[RESULT_KEY] = {"success": True, "message": f"Test generation started for test suite '{test_suite_name}'."} diff --git a/testgen/ui/views/dialogs/import_metadata_dialog.py b/testgen/ui/views/dialogs/import_metadata_dialog.py index 524750ea..b6b1f4c8 100644 --- a/testgen/ui/views/dialogs/import_metadata_dialog.py +++ b/testgen/ui/views/dialogs/import_metadata_dialog.py @@ -4,9 +4,9 @@ import pandas as pd -from testgen.common.models.table_group import TableGroup from testgen.ui.queries.profiling_queries import TAG_FIELDS from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db +from testgen.ui.services.query_cache import get_table_group from testgen.ui.session import session LOG = logging.getLogger("testgen") @@ -192,7 +192,7 @@ def _match_and_validate( matched_columns = sum(1 for r in preview_rows if r.get("column_name") and r.get("_status") in _importable) skipped = sum(1 for r in preview_rows if r.get("_status") not in _importable) - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) return { "table_rows": table_rows, @@ -328,7 +328,7 @@ def apply_metadata_import(preview: dict, table_group_id: str | None = None) -> d def _disable_autoflags(table_group_id: str, metadata_columns: list[str]) -> None: - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) changed = False if "critical_data_element" in metadata_columns and table_group.profile_flag_cdes: 
table_group.profile_flag_cdes = False diff --git a/testgen/ui/views/dialogs/manage_schedules.py b/testgen/ui/views/dialogs/manage_schedules.py index 4caf4c59..4d917f7c 100644 --- a/testgen/ui/views/dialogs/manage_schedules.py +++ b/testgen/ui/views/dialogs/manage_schedules.py @@ -33,7 +33,7 @@ def get_arg_value(self, job): def get_arg_value_options(self) -> list[dict[str, str]]: raise NotImplementedError - def get_job_arguments(self, arg_value: str) -> tuple[list[Any], dict[str, Any]]: + def get_job_arguments(self, arg_value: str) -> dict[str, Any]: raise NotImplementedError def build_data(self) -> dict: @@ -86,7 +86,7 @@ def on_resume(self, item: dict) -> None: st.session_state.pop(RESULT_KEY, None) def on_cron_sample(self, payload: dict) -> None: - from testgen.ui.utils import get_cron_sample + from testgen.common.cron_service import get_cron_sample sample = get_cron_sample(payload["cron_expr"], payload["tz"], CRON_SAMPLE_COUNT, formatted=True) st.session_state[CRON_SAMPLE_KEY] = sample @@ -98,15 +98,13 @@ def on_add(self, payload: dict) -> None: is_form_valid = bool(arg_value) and bool(cron_tz) and bool(cron_expr) if is_form_valid: cron_obj = cron_converter.Cron(cron_expr) - args, kwargs = self.get_job_arguments(arg_value) sched_model = JobSchedule( project_code=self.project_code, key=self.job_key, cron_expr=cron_obj.to_string(), cron_tz=cron_tz, active=True, - args=args, - kwargs=kwargs, + kwargs=self.get_job_arguments(arg_value), ) with_database_session(sched_model.save)() st.session_state[RESULT_KEY] = {"success": True, "message": "Schedule added"} diff --git a/testgen/ui/views/dialogs/run_profiling_dialog.py b/testgen/ui/views/dialogs/run_profiling_dialog.py index 9dfbb3ff..80693b94 100644 --- a/testgen/ui/views/dialogs/run_profiling_dialog.py +++ b/testgen/ui/views/dialogs/run_profiling_dialog.py @@ -2,6 +2,7 @@ import streamlit as st +from testgen.common.enums import JobSource from testgen.common.models import database_session from 
testgen.common.models.job_execution import JobExecution from testgen.ui.components import widgets as testgen @@ -32,7 +33,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: diff --git a/testgen/ui/views/dialogs/run_tests_dialog.py b/testgen/ui/views/dialogs/run_tests_dialog.py index 29b224e9..0c4a9d72 100644 --- a/testgen/ui/views/dialogs/run_tests_dialog.py +++ b/testgen/ui/views/dialogs/run_tests_dialog.py @@ -1,11 +1,12 @@ import streamlit as st +from testgen.common.enums import JobSource from testgen.common.models import database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.test_suite import TestSuite from testgen.ui.components import widgets as testgen from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_test_run_summaries +from testgen.ui.services.query_cache import get_test_run_summaries, select_test_suites_minimal_where from testgen.ui.session import session LINK_HREF = "test-runs" @@ -18,7 +19,7 @@ def run_tests_dialog_widget( on_close: callable, test_suite_id: str | None = None, ) -> None: - test_suites = TestSuite.select_minimal_where( + test_suites = select_test_suites_minimal_where( TestSuite.project_code == project_code, TestSuite.is_monitor.isnot(True), ) @@ -34,7 +35,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as e: diff --git a/testgen/ui/views/hygiene_issues.py b/testgen/ui/views/hygiene_issues.py index 50dbcf50..521a9de4 100644 --- a/testgen/ui/views/hygiene_issues.py +++ b/testgen/ui/views/hygiene_issues.py @@ -10,7 +10,6 @@ from testgen.common.mixpanel_service import 
MixpanelService from testgen.common.models import with_database_session from testgen.common.models.hygiene_issue import HygieneIssue -from testgen.common.models.profiling_run import ProfilingRun from testgen.common.pii_masking import get_pii_columns, mask_hygiene_detail, mask_profiling_pii from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( @@ -26,6 +25,7 @@ from testgen.ui.queries.profiling_queries import get_profiling_anomalies from testgen.ui.queries.source_data_queries import get_hygiene_issue_source_data, get_hygiene_issue_source_query from testgen.ui.services.database_service import execute_db_query +from testgen.ui.services.query_cache import get_profiling_run_minimal from testgen.ui.session import session from testgen.utils import friendly_score, make_json_safe @@ -92,7 +92,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - run = ProfilingRun.get_minimal(run_id) + run = get_profiling_run_minimal(run_id) if not run: self.router.navigate_with_warning( f"Profiling run with ID '{run_id}' does not exist. 
Redirecting to list of Profiling Runs ...", @@ -260,7 +260,7 @@ def on_view_source_data(row_id: str) -> None: anomaly_df = profiling_queries.get_profiling_anomalies_by_ids([row_id]) if anomaly_df.empty: return - row = make_json_safe(anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records")[0]) + row = anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records")[0] MixpanelService().send_event( "view-source-data", @@ -335,10 +335,7 @@ def on_download_report(payload: dict) -> None: anomaly_df = profiling_queries.get_profiling_anomalies_by_ids(ids) if anomaly_df.empty: return - selected_items = [ - make_json_safe(record) - for record in anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records") - ] + selected_items = anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records") MixpanelService().send_event( "download-issue-report", diff --git a/testgen/ui/views/monitors_dashboard.py b/testgen/ui/views/monitors_dashboard.py index 3007c173..e4fd3719 100644 --- a/testgen/ui/views/monitors_dashboard.py +++ b/testgen/ui/views/monitors_dashboard.py @@ -7,6 +7,7 @@ import streamlit as st from testgen.commands.test_generation import run_monitor_generation +from testgen.common.cron_service import get_cron_sample from testgen.common.freshness_service import add_business_minutes, get_schedule_params, resolve_holiday_dates from testgen.common.models import get_current_session, with_database_session from testgen.common.models.notification_settings import ( @@ -24,10 +25,19 @@ from testgen.ui.navigation.router import Router from testgen.ui.queries.profiling_queries import get_tables_by_table_group from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db, fetch_one_from_db -from testgen.ui.services.query_cache import get_project_summary, get_test_type_summaries +from testgen.ui.services.query_cache import ( + get_monitor_schedule, + get_project_summary, + get_table_group, + get_test_definition, + get_test_suite, + 
get_test_type_summaries, + select_table_groups_minimal_where, + select_test_definitions_where, +) from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session, temp_value -from testgen.ui.utils import dict_from_kv, get_cron_sample, get_cron_sample_handler +from testgen.ui.utils import dict_from_kv, get_cron_sample_handler from testgen.ui.views.dialogs.manage_notifications import NotificationSettingsDialogBase from testgen.utils import make_json_safe @@ -85,7 +95,7 @@ def render( ) project_summary = get_project_summary(project_code) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) if not table_group_id or table_group_id not in [ str(item.id) for item in table_groups ]: table_group_id = str(table_groups[0].id) if table_groups else None @@ -107,10 +117,7 @@ def render( if monitor_suite_id: with st.spinner(text="Loading data ..."): - monitor_schedule = JobSchedule.get( - JobSchedule.key == RUN_MONITORS_JOB_KEY, - JobSchedule.kwargs["test_suite_id"].astext == str(monitor_suite_id), - ) + monitor_schedule = get_monitor_schedule(monitor_suite_id) anomaly_type_filter = [t for t in anomaly_type_filter.split(",") if t in ANOMALY_TYPE_FILTERS] if anomaly_type_filter else None if sort_field and sort_field not in ALLOWED_SORT_FIELDS: @@ -597,7 +604,7 @@ def build_edit_monitor_settings_data( monitor_suite_id = table_group.monitor_test_suite_id if monitor_suite_id: - monitor_suite = TestSuite.get(monitor_suite_id) + monitor_suite = get_test_suite(monitor_suite_id) else: monitor_suite = TestSuite( project_code=table_group.project_code, @@ -638,7 +645,6 @@ def on_save_settings_clicked(payload: dict) -> None: new_schedule = JobSchedule( project_code=table_group.project_code, key=RUN_MONITORS_JOB_KEY, - args=[], kwargs={"test_suite_id": str(monitor_suite.id)}, **new_schedule_config, ) @@ -648,7 +654,7 @@ def 
on_save_settings_clicked(payload: dict) -> None: JobSchedule.update_active(schedule.id, new_schedule_config["active"]) if is_new: - updated_table_group = TableGroup.get(table_group.id) + updated_table_group = get_table_group(table_group.id) updated_table_group.monitor_test_suite_id = monitor_suite.id updated_table_group.save() # Commit needed to make test suite visible to run_monitor_generation's separate DB connection @@ -680,7 +686,7 @@ def on_save_settings_clicked(payload: dict) -> None: @with_database_session def delete_monitor_suite(table_group: TableGroupMinimal) -> None: try: - monitor_suite = TestSuite.get(table_group.monitor_test_suite_id) + monitor_suite = get_test_suite(table_group.monitor_test_suite_id) TestSuite.cascade_delete([monitor_suite.id]) st.cache_data.clear() except Exception: @@ -749,7 +755,7 @@ def on_close_trends(_payload=None): lookback_multiplier = 3 if extended_history else 1 events = get_monitor_events_for_table(table_group.monitor_test_suite_id, table_name, lookback_multiplier) - definitions = TestDefinition.select_where( + definitions = select_test_definitions_where( TestDefinition.test_suite_id == table_group.monitor_test_suite_id, TestDefinition.table_name == table_name, TestDefinition.test_type.in_(["Freshness_Trend", "Volume_Trend", "Metric_Trend"]), @@ -757,11 +763,8 @@ def on_close_trends(_payload=None): predictions = {} if len(definitions) > 0: - test_suite = TestSuite.get(table_group.monitor_test_suite_id) - monitor_schedule = JobSchedule.get( - JobSchedule.key == RUN_MONITORS_JOB_KEY, - JobSchedule.kwargs["test_suite_id"].astext == str(table_group.monitor_test_suite_id), - ) + test_suite = get_test_suite(table_group.monitor_test_suite_id) + monitor_schedule = get_monitor_schedule(table_group.monitor_test_suite_id) monitor_lookback = test_suite.monitor_lookback predict_sensitivity = test_suite.predict_sensitivity or PredictSensitivity.medium @@ -946,6 +949,7 @@ def get_monitor_events_for_table(test_suite_id: str, table_name: 
str, lookback_m "is_pending": not bool(event["result_id"]), "lower_tolerance": params.get("lower_tolerance") if params.get("lower_tolerance") else None, "upper_tolerance": params.get("upper_tolerance") if params.get("upper_tolerance") else None, + "threshold_value": params.get("threshold_value") if params.get("threshold_value") else None, }) return { @@ -1021,7 +1025,7 @@ def build_edit_table_monitors_data( table_group: TableGroupMinimal, payload: dict, dialog: dict | None = None, ) -> tuple[dict, dict]: table_name = payload.get("table_name") - definitions = TestDefinition.select_where( + definitions = select_test_definitions_where( TestDefinition.test_suite_id == table_group.monitor_test_suite_id, TestDefinition.table_name == table_name, TestDefinition.test_type.in_(["Freshness_Trend", "Volume_Trend", "Metric_Trend"]), @@ -1045,7 +1049,7 @@ def on_save_test_definition(payload: dict) -> None: valid_columns = {col.name for col in TestDefinition.__table__.columns} for updated_def in get_updated_definitions(): - current_def: TestDefinitionSummary = TestDefinition.get(updated_def.get("id")) + current_def: TestDefinitionSummary = get_test_definition(updated_def.get("id")) if current_def: merged = {key: getattr(current_def, key, None) for key in valid_columns} merged.update({key: value for key, value in updated_def.items() if key in valid_columns}) diff --git a/testgen/ui/views/profiling_results.py b/testgen/ui/views/profiling_results.py index 83608a12..f779394a 100644 --- a/testgen/ui/views/profiling_results.py +++ b/testgen/ui/views/profiling_results.py @@ -8,7 +8,6 @@ from testgen.common import date_service from testgen.common.date_service import parse_fuzzy_date from testgen.common.models import with_database_session -from testgen.common.models.profiling_run import ProfilingRun from testgen.common.pii_masking import ( PII_REDACTED, get_pii_columns, @@ -16,6 +15,7 @@ mask_profiling_pii, mask_source_data_pii, ) +from testgen.common.profile_top_values import 
parse_top_freq_values, parse_top_patterns from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -25,9 +25,10 @@ ) from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router +from testgen.ui.services.query_cache import get_profiling_run_minimal from testgen.ui.session import session from testgen.ui.views.data_catalog import get_preview_data -from testgen.utils import make_json_safe +from testgen.utils import dataframe_to_json_records, make_json_safe PAGE_SIZE = 500 @@ -86,7 +87,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - run = ProfilingRun.get_minimal(run_id) + run = get_profiling_run_minimal(run_id) if not run: self.router.navigate_with_warning( f"Profiling run with ID '{run_id}' does not exist. Redirecting to list of Profiling Runs ...", @@ -162,15 +163,14 @@ def render( pii_columns = get_pii_columns(str(run.table_groups_id)) mask_profiling_pii(df, pii_columns) - # Use pandas JSON serialization to safely handle NaN/NaT -> null, timestamps -> epoch seconds - items = json.loads(df.to_json(orient="records", date_unit="s")) + items = dataframe_to_json_records(df) selected_item = st.session_state.get(SELECTED_ITEM_KEY) # Load selected item if URL has a selection but session cache is missing or stale if selected and (selected_item is None or selected_item.get("id") != selected): row_df = df[df["id"] == selected] if not row_df.empty: - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] row["hygiene_issues"] = profiling_queries.get_hygiene_issues( run_id, row["table_name"], row.get("column_name") ) @@ -188,7 +188,7 @@ def on_row_selected(item_id: str) -> None: row_df = df[df["id"] == item_id] if row_df.empty: return - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] row["hygiene_issues"] = 
profiling_queries.get_hygiene_issues( run_id, row["table_name"], row.get("column_name") ) @@ -325,21 +325,12 @@ def get_excel_report_data( def _format_top_freq_values(val): if not val or val == PII_REDACTED: return val - lines = [] - for part in val[2:].split("\n| "): - left, right = part.split(" | ") - lines.append(f"{right} | {left}") - return "\n".join(lines) + return "\n".join(f"{count} | {value}" for value, count in parse_top_freq_values(val)) def _format_top_patterns(val): if not val or val == PII_REDACTED: return val - parts = val.split(" | ") - formatted = [] - for index, part in enumerate(parts): - separator = "\n" if index % 2 else " | " - formatted.append(f"{part}{separator}") - return "".join(formatted) + return "\n".join(f"{count} | {pattern}" for pattern, count in parse_top_patterns(val)) data["top_freq_values"] = data["top_freq_values"].apply(_format_top_freq_values) data["top_patterns"] = data["top_patterns"].apply(_format_top_patterns) diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py index e612d6e6..5707838a 100644 --- a/testgen/ui/views/profiling_runs.py +++ b/testgen/ui/views/profiling_runs.py @@ -13,8 +13,9 @@ RUN_NOTIFICATIONS_DIALOG_OPEN_COUNT_KEY = "pr:run_notifications_dialog_open_count" import testgen.ui.services.form_service as fm +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, get_current_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import ( ProfilingRunNotificationSettings, ProfilingRunNotificationTrigger, @@ -26,7 +27,13 @@ from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_profiling_run_summaries, get_project_summary, get_table_group_stats 
+from testgen.ui.services.query_cache import ( + get_profiling_run_summaries, + get_project_summary, + get_table_group_stats, + select_profiling_runs_where, + select_table_groups_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.manage_notifications import NotificationSettingsDialogBase from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog @@ -61,7 +68,7 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs with st.spinner("Loading data ..."): project_summary = get_project_summary(project_code) profiling_runs, total_count = get_profiling_run_summaries(project_code, table_group_id, page=page) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) schedule_obj = ProfilingScheduleDialog(project_code) ns_obj = ProfilingRunNotificationSettingsDialog( @@ -114,7 +121,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -224,7 +231,7 @@ class ProfilingScheduleDialog(ScheduleDialog): table_groups: Iterable[TableGroupMinimal] | None = None def init(self) -> None: - self.table_groups = TableGroup.select_minimal_where(TableGroup.project_code == self.project_code) + self.table_groups = select_table_groups_minimal_where(TableGroup.project_code == self.project_code) def get_arg_value(self, job): return next(item.table_groups_name for item in self.table_groups if str(item.id) == job.kwargs["table_group_id"]) @@ -235,8 +242,8 @@ def get_arg_value_options(self) -> list[dict[str, str]]: for table_group in self.table_groups ] - def get_job_arguments(self, arg_value: str) -> tuple[list[typing.Any], dict[str, typing.Any]]: - return [], {"table_group_id": str(arg_value)} + def 
get_job_arguments(self, arg_value: str) -> dict[str, typing.Any]: + return {"table_group_id": str(arg_value)} class ProfilingRunNotificationSettingsDialog(NotificationSettingsDialogBase): @@ -258,7 +265,7 @@ def _model_to_item_attrs(self, model: ProfilingRunNotificationSettings) -> dict[ def _get_component_props(self) -> dict[str, typing.Any]: table_group_options = [ (str(tg.id), tg.table_groups_name) - for tg in TableGroup.select_minimal_where(TableGroup.project_code == self.ns_attrs["project_code"]) + for tg in select_table_groups_minimal_where(TableGroup.project_code == self.ns_attrs["project_code"]) ] table_group_options.insert(0, (None, "All Table Groups")) trigger_labels = { @@ -300,7 +307,7 @@ def on_delete_runs(job_execution_ids: list[str]) -> None: continue if job_exec.status in (JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED): job_exec.request_cancel() - profiling_run = next(iter(ProfilingRun.select_where(ProfilingRun.job_execution_id == je_id)), None) + profiling_run = next(iter(select_profiling_runs_where(ProfilingRun.job_execution_id == je_id)), None) if profiling_run: ProfilingRun.cascade_delete([str(profiling_run.id)]) get_current_session().delete(job_exec) diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index 8ed1a45e..ec93229b 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -1,17 +1,31 @@ import random import typing from dataclasses import asdict, dataclass, field +from datetime import UTC, datetime, timedelta import streamlit as st +from sqlalchemy import select from testgen.commands.run_observability_exporter import test_observability_exporter -from testgen.common.models import with_database_session +from testgen.common.enums import JobKey, JobSource, JobStatus +from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution +from 
testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.project import Project +from testgen.common.models.scheduler import ( + DEFAULT_DATA_CLEANUP_CRON, + DEFAULT_RETENTION_CRON_TZ, + JobSchedule, +) +from testgen.common.models.test_run import TestRun from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page +from testgen.ui.services.query_cache import get_project, select_projects_where from testgen.ui.session import session, temp_value +from testgen.ui.utils import get_cron_sample_handler + +DEFAULT_RETENTION_DAYS = 180 PAGE_TITLE = "Project Settings" @@ -34,7 +48,12 @@ class ProjectSettingsPage(Page): existing_names: list[str] | None = None def render(self, project_code: str | None = None, **_kwargs) -> None: - self.project = Project.get(project_code) + self.project = get_project(project_code) + retention_schedule = JobSchedule.get( + JobSchedule.project_code == project_code, + JobSchedule.key == JobKey.run_data_cleanup, + ) + retention_last_run = self._get_last_cleanup_timestamp(project_code) testgen.page_header( PAGE_TITLE, @@ -42,11 +61,21 @@ def render(self, project_code: str | None = None, **_kwargs) -> None: ) get_test_results, set_test_results = temp_value(f"project_settings:{project_code}", default=None) + cron_sample_result, on_cron_sample = get_cron_sample_handler( + f"project_settings:cron_sample:{project_code}", sample_count=2, + ) + # Persistent session_state (not pop-on-read) so rapid days edits don't lose the response. 
+ retention_preview_key = f"project_settings:retention_preview:{project_code}" def on_observability_connection_test(payload: dict) -> None: results = self.test_observability_connection(project_code, payload) set_test_results(asdict(results)) + def on_retention_preview(payload: dict) -> None: + st.session_state[retention_preview_key] = self._get_retention_preview( + project_code, payload.get("retention_days"), + ) + return testgen.project_settings( key="project_settings", data={ @@ -55,15 +84,58 @@ def on_observability_connection_test(payload: dict) -> None: "observability_api_url": self.project.observability_api_url, "observability_api_key": self.project.observability_api_key, "observability_test_results": get_test_results(), + "data_retention_enabled": self.project.data_retention_enabled, + "data_retention_days": self.project.data_retention_days or DEFAULT_RETENTION_DAYS, + "retention_cron_expr": retention_schedule.cron_expr if retention_schedule else DEFAULT_DATA_CLEANUP_CRON, + "retention_cron_tz": retention_schedule.cron_tz if retention_schedule else None, + "retention_cron_sample": cron_sample_result(), + "retention_last_run": int(retention_last_run.timestamp() * 1000) if retention_last_run else None, + "retention_preview": st.session_state.get(retention_preview_key), }, on_TestObservabilityClicked_change=on_observability_connection_test, + on_GetCronSample_change=on_cron_sample, + on_GetRetentionPreview_change=on_retention_preview, on_SaveClicked_change=lambda payload: self.update_project(project_code, payload), ) + @staticmethod + def _get_last_cleanup_timestamp(project_code: str) -> datetime | None: + with database_session() as session_: + return session_.scalar( + select(JobExecution.completed_at) + .where( + JobExecution.project_code == project_code, + JobExecution.job_key == JobKey.run_data_cleanup, + JobExecution.status == JobStatus.COMPLETED, + JobExecution.completed_at.isnot(None), + ) + .order_by(JobExecution.completed_at.desc()) + .limit(1) + ) + + 
@staticmethod + def _get_retention_preview(project_code: str, retention_days: int | None) -> dict | None: + if not retention_days or retention_days < 1: + return None + cutoff = datetime.now(UTC) - timedelta(days=retention_days) + with database_session(): + protected_profiling_ids = ProfilingRun.find_latest_per_table_group(project_code) + protected_test_ids = TestRun.find_latest_per_test_suite(project_code) + return { + "profiling_count": ProfilingRun.delete_older_than( + cutoff, project_code, protected_profiling_ids, dry_run=True, + ), + "test_count": TestRun.delete_older_than( + cutoff, project_code, protected_test_ids, dry_run=True, + ), + # Tiebreaker: identical counts for different days otherwise deep-equal and suppress the prop update. + "_": random.random(), # noqa: S311 + } + @with_database_session def update_project(self, project_code: str, edited_project: dict) -> None: existing_names = [ - p.project_name.lower() for p in Project.select_where(Project.project_code != project_code) + p.project_name.lower() for p in select_projects_where(Project.project_code != project_code) ] new_project_name = edited_project["name"] if new_project_name.lower() in existing_names: @@ -75,17 +147,36 @@ def update_project(self, project_code: str, edited_project: dict) -> None: self.project.use_dq_score_weights = edited_project.get("use_dq_score_weights", True) self.project.observability_api_url = edited_project.get("observability_api_url") self.project.observability_api_key = edited_project.get("observability_api_key") + + retention_enabled = bool(edited_project.get("data_retention_enabled")) + retention_days = edited_project.get("data_retention_days") or DEFAULT_RETENTION_DAYS + self.project.data_retention_enabled = retention_enabled + self.project.data_retention_days = retention_days if retention_enabled else None self.project.save() + get_project.clear() + select_projects_where.clear() + + if retention_enabled: + JobSchedule.upsert_for_retention( + 
project_code=project_code, + retention_days=retention_days, + cron_expr=edited_project.get("retention_cron_expr") or DEFAULT_DATA_CLEANUP_CRON, + cron_tz=edited_project.get("retention_cron_tz") or DEFAULT_RETENTION_CRON_TZ, + ) + else: + JobSchedule.delete_for_retention(project_code) if weights_changed: JobExecution.submit( - job_key="recalculate-project-scores", + job_key=JobKey.recalculate_project_scores, kwargs={"project_code": project_code}, - source="user", + source=JobSource.ui, project_code=project_code, ) st.toast("Scores will be recalculated in the background.") + st.toast("Project settings saved", icon=":material/task_alt:") + def test_observability_connection(self, project_code: str, edited_project: dict) -> "ObservabilityConnectionStatus": try: test_observability_exporter( diff --git a/testgen/ui/views/score_explorer.py b/testgen/ui/views/score_explorer.py index 5eca56ff..e9598ddd 100644 --- a/testgen/ui/views/score_explorer.py +++ b/testgen/ui/views/score_explorer.py @@ -1,19 +1,14 @@ import json import typing -from datetime import datetime from io import BytesIO from typing import ClassVar import pandas as pd import streamlit as st -from testgen.commands.run_refresh_score_cards_results import ( - run_recalculate_score_card, - run_refresh_score_cards_results, -) +from testgen.commands.run_refresh_score_cards_results import save_and_refresh_score_definition from testgen.common.mixpanel_service import MixpanelService from testgen.common.models import with_database_session -from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scores import ( Categories, ScoreCategory, @@ -22,7 +17,6 @@ ScoreTypes, SelectedIssue, ) -from testgen.common.models.test_run import TestRun from testgen.common.pii_masking import get_pii_columns, mask_hygiene_detail, mask_profiling_pii from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import FILE_DATA_TYPE, download_dialog, zip_multi_file_data 
@@ -373,24 +367,11 @@ def save_score_definition(_) -> None: is_new = True score_definition = ScoreDefinition() - refresh_kwargs = {} if definition_id: is_new = False score_definition = ScoreDefinition.get(definition_id) project_code = score_definition.project_code - if is_new: - latest_run = max( - ProfilingRun.get_latest_run(project_code), - TestRun.get_latest_run(project_code), - key=lambda run: getattr(run, "run_time", datetime.min), - ) - - refresh_kwargs = { - "add_history_entry": True, - "refresh_date": latest_run.run_time if latest_run else None, - } - score_definition.project_code = project_code score_definition.name = name score_definition.total_score = total_score and total_score.lower() == "true" @@ -403,13 +384,9 @@ def save_score_definition(_) -> None: ], group_by_field=not filter_by_columns, ) - score_definition.save() - run_refresh_score_cards_results(definition_id=score_definition.id, **refresh_kwargs) + save_and_refresh_score_definition(score_definition, is_new=is_new) get_all_score_cards.clear() - if not is_new: - run_recalculate_score_card(project_code=project_code, definition_id=score_definition.id) - Router().set_query_params({ "name": None, "total_score": None, diff --git a/testgen/ui/views/table_groups.py b/testgen/ui/views/table_groups.py index ef93062c..bb9e186c 100644 --- a/testgen/ui/views/table_groups.py +++ b/testgen/ui/views/table_groups.py @@ -7,6 +7,7 @@ from sqlalchemy.exc import IntegrityError from testgen.commands.test_generation import run_monitor_generation +from testgen.common.enums import JobSource from testgen.common.models import get_current_session, with_database_session from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution @@ -21,7 +22,16 @@ from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router from testgen.ui.queries import table_group_queries -from testgen.ui.services.query_cache import get_profiling_run_summaries, 
get_project_summary, get_table_group_stats +from testgen.ui.services.query_cache import ( + get_connection_minimal, + get_profiling_run_summaries, + get_project_summary, + get_table_group, + get_table_group_minimal, + get_table_group_stats, + select_connections_minimal_where, + select_table_groups_minimal_where, +) from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session, temp_value from testgen.ui.utils import get_cron_sample_handler @@ -73,7 +83,7 @@ def render( if table_group_name: table_group_filters.append(TableGroup.table_groups_name.ilike(f"%{table_group_name}%")) - table_groups = TableGroup.select_minimal_where(*table_group_filters) + table_groups = select_table_groups_minimal_where(*table_group_filters) connections = self._get_connections(project_code) wizard_mode = st.session_state.get("tg_wizard_mode") @@ -136,7 +146,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -281,7 +291,7 @@ def on_save_table_group_clicked(payload: dict): set_run_profiling(run_profiling) def on_close_clicked(_params: dict) -> None: - TableGroup.select_minimal_where.clear() + select_table_groups_minimal_where.clear() for key in ["tg_wizard_mode", "tg_wizard_connection_id", "tg_wizard_table_group_id"]: st.session_state.pop(key, None) @@ -327,7 +337,7 @@ def on_close_clicked(_params: dict) -> None: table_group = TableGroup(project_code=project_code) original_table_group_schema = None if table_group_id: - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) original_table_group_schema = table_group.table_group_schema is_table_group_used = TableGroup.is_in_use([table_group_id]) @@ -403,7 +413,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_TESTS_JOB_KEY, 
cron_expr=standard_test_suite_data["schedule"], cron_tz=standard_test_suite_data["timezone"], - args=[], kwargs={"test_suite_id": str(standard_test_suite.id)}, ).save() @@ -435,7 +444,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_MONITORS_JOB_KEY, cron_expr=monitor_test_suite_data.get("schedule"), cron_tz=monitor_test_suite_data.get("timezone"), - args=[], kwargs={"test_suite_id": str(monitor_test_suite.id)}, ).save() @@ -450,7 +458,7 @@ def on_close_clicked(_params: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group.id)}, - source="ui", + source=JobSource.ui, project_code=table_group.project_code, ) message = f"Profiling run started for table group {table_group.table_groups_name}." @@ -526,7 +534,7 @@ def on_close_edit(_params: dict) -> None: for key in ["tg_wizard_mode", "tg_wizard_table_group_id"]: st.session_state.pop(key, None) - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) original_schema = table_group.table_group_schema is_in_use = TableGroup.is_in_use([table_group_id]) @@ -593,9 +601,9 @@ def on_close_edit(_params: dict) -> None: def _get_connections(self, project_code: str, connection_id: str | None = None) -> list[dict]: if connection_id: - connections = [Connection.get_minimal(connection_id)] + connections = [get_connection_minimal(connection_id)] else: - connections = Connection.select_minimal_where(Connection.project_code == project_code) + connections = select_connections_minimal_where(Connection.project_code == project_code) return [ format_connection(connection) for connection in connections ] def _format_table_group_list( @@ -623,7 +631,7 @@ def _format_table_group_list( @with_database_session def _prepare_delete_dialog(self, table_group_id: str) -> None: - table_group = TableGroup.get_minimal(table_group_id) + table_group = get_table_group_minimal(table_group_id) can_be_deleted = not TableGroup.is_in_use([table_group_id]) 
st.session_state["tg_delete_dialog"] = { "open": True, @@ -636,7 +644,7 @@ def _execute_delete(self, table_group_id: str) -> None: table_group_name = st.session_state.get("tg_delete_dialog", {}).get("table_group", {}).get("table_groups_name", "") if not (ProfilingRun.has_active_job_for(TableGroup, table_group_id) or TestRun.has_active_job_for(TableGroup, table_group_id)): TableGroup.cascade_delete([table_group_id]) - TableGroup.select_minimal_where.clear() + select_table_groups_minimal_where.clear() st.toast(f"Table Group {table_group_name} has been deleted.", icon=":material/check:") else: st.toast("This Table Group is in use by a running process and cannot be deleted.", icon=":material/error:") diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index e0d8d267..f460d873 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -1,4 +1,3 @@ -import json import logging import typing from datetime import UTC, datetime @@ -8,9 +7,10 @@ from sqlalchemy import and_, asc, case, desc, func, or_, tuple_ from testgen.common import date_service -from testgen.common.database.database_service import get_flavor_service, replace_params +from testgen.common.custom_test_validation import validate_custom_query +from testgen.common.database.database_service import get_flavor_service +from testgen.common.enums import JobSource from testgen.common.models import with_database_session -from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution from testgen.common.models.table_group import TableGroup, TableGroupMinimal from testgen.common.models.test_definition import ( @@ -33,8 +33,18 @@ from testgen.ui.navigation.router import Router from testgen.ui.queries import profiling_queries from testgen.ui.services.database_service import fetch_all_from_db, fetch_df_from_db, fetch_from_target_db +from testgen.ui.services.query_cache import ( + get_connection, + 
get_table_group_minimal, + get_test_suite, + select_table_groups_minimal_where, + select_test_definitions_minimal_where, + select_test_definitions_page, + select_test_definitions_where, + select_test_suites_minimal_where, +) from testgen.ui.session import session -from testgen.utils import make_json_safe, to_dataframe +from testgen.utils import dataframe_to_json_records, make_json_safe, to_dataframe LOG = logging.getLogger("testgen") @@ -105,7 +115,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - test_suite = TestSuite.get(test_suite_id) + test_suite = get_test_suite(test_suite_id) if not test_suite: self.router.navigate_with_warning( f"Test suite with ID '{test_suite_id}' does not exist. Redirecting to list of Test Suites ...", @@ -113,7 +123,7 @@ def render( ) return - table_group = TableGroup.get_minimal(test_suite.table_groups_id) + table_group = get_table_group_minimal(test_suite.table_groups_id) project_code = table_group.project_code if not session.auth.user_has_project_access(project_code): @@ -141,16 +151,14 @@ def render( with st.spinner("Loading data ..."): user_can_edit = session.auth.user_has_permission("edit") user_can_disposition = session.auth.user_has_permission("disposition") - df = get_test_definitions(test_suite, table_name, column_name, test_type, sorting_columns, - page=current_page, page_size=current_page_size, - flagged_filter=flagged) - total_count = get_test_definitions_count(test_suite, table_name, column_name, test_type, - flagged_filter=flagged) + df, total_count = get_test_definitions(test_suite, table_name, column_name, test_type, sorting_columns, + page_index=current_page, page_size=current_page_size, + flagged_filter=flagged) test_types = run_test_type_lookup_query().to_dict("records") table_columns = get_columns(str(table_group.id)) filter_columns_df = get_test_suite_columns(test_suite_id) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) - all_test_suites = 
TestSuite.select_minimal_where( + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) + all_test_suites = select_test_suites_minimal_where( TestSuite.table_groups_id.in_([str(tg.id) for tg in table_groups]), TestSuite.is_monitor.isnot(True), ) @@ -182,6 +190,12 @@ def render( # Build dialog states validate_result = st.session_state.pop(TD_VALIDATE_RESULT_KEY, None) + qualifies_table_refs_with_schema = True + if st.session_state.get(TD_ADD_DIALOG_KEY) or st.session_state.get(TD_EDIT_DIALOG_KEY): + connection = get_connection(table_group.connection_id) + if connection: + qualifies_table_refs_with_schema = get_flavor_service(connection.sql_flavor).qualifies_table_refs_with_schema + add_dialog = None if st.session_state.get(TD_ADD_DIALOG_KEY): add_dialog = { @@ -191,6 +205,7 @@ def render( "table_groups_id": str(table_group.id), "table_group_schema": table_group.table_group_schema, "test_suite": test_suite_info, + "qualifies_table_refs_with_schema": qualifies_table_refs_with_schema, } edit_dialog = None @@ -202,6 +217,7 @@ def render( "table_columns": table_columns, "table_group_schema": table_group.table_group_schema, "test_suite": test_suite_info, + "qualifies_table_refs_with_schema": qualifies_table_refs_with_schema, } delete_dialog = None @@ -258,7 +274,7 @@ def on_edit_dialog_opened(payload: dict) -> None: # Fetch fresh row from the current data row_df = df[df["id"] == test_def_id] if not row_df.empty: - test_def = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + test_def = dataframe_to_json_records(row_df)[0] st.session_state[TD_EDIT_DIALOG_KEY] = test_def def on_delete_dialog_opened(selected: list) -> None: @@ -288,7 +304,7 @@ def on_unlock_all_opened(*_) -> None: def on_copy_move_dialog_opened(selected) -> None: if selected == "all": all_ids = get_test_definition_ids(test_suite, table_name, column_name, test_type, flagged_filter=flagged) - results = TestDefinition.select_where(TestDefinition.id.in_(all_ids)) 
+ results = select_test_definitions_where(TestDefinition.id.in_(all_ids)) selected = [ {"id": str(r.id), "table_name": r.table_name, "column_name": r.column_name, "test_type": r.test_type, "lock_refresh": r.lock_refresh} @@ -318,9 +334,27 @@ def on_copy_move_dialog_closed(*_) -> None: st.session_state.pop(TD_COPY_MOVE_COLLISION_KEY, None) st.session_state.pop(TD_COPY_MOVE_OVERWRITE_KEY, None) + match_schema_test_types = { + tt["test_type"] + for tt in test_types + if "match_schema_name" in (tt.get("default_parm_columns") or "").split(",") + } + + def _default_match_schema(test_def: dict) -> None: + # The Match Schema field is hidden in the UI for flavors whose SQL doesn't + # qualify table refs with a schema, but downstream SQL/Python still expects + # match_schema_name populated for tests that support it. Default to the + # test's schema (or table-group schema) when the test type accepts + # match_schema_name and match_table_name is set. + if test_def.get("test_type") not in match_schema_test_types: + return + if test_def.get("match_table_name") and not test_def.get("match_schema_name"): + test_def["match_schema_name"] = test_def.get("schema_name") or table_group.table_group_schema + @with_database_session def on_add_test_saved(test_def: dict) -> None: test_def["last_manual_update"] = datetime.now(UTC) + _default_match_schema(test_def) td_columns = set(TestDefinition.__table__.columns.keys()) TestDefinition(**{k: v for k, v in test_def.items() if k in td_columns}).save() st.cache_data.clear() @@ -330,6 +364,7 @@ def on_add_test_saved(test_def: dict) -> None: @with_database_session def on_edit_test_saved(test_def: dict) -> None: test_def["last_manual_update"] = datetime.now(UTC) + _default_match_schema(test_def) td_columns = set(TestDefinition.__table__.columns.keys()) TestDefinition(**{k: v for k, v in test_def.items() if k in td_columns}).save() st.cache_data.clear() @@ -446,7 +481,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( 
job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -530,7 +565,7 @@ def on_export_filtered(payload: dict) -> None: def on_export_selected(payload: dict) -> None: ids = payload.get("ids", []) if ids: - data = get_test_definitions(test_suite) + data, _ = get_test_definitions(test_suite) data = data[data["id"].isin(ids)] download_dialog( dialog_title="Download Excel Report", @@ -539,11 +574,22 @@ def on_export_selected(payload: dict) -> None: ) def on_filter_changed(filters: dict) -> None: + norm = lambda v: None if v in (None, "None", "") else str(v) + if ( + norm(filters.get("table_name")) == norm(table_name) + and norm(filters.get("column_name")) == norm(column_name) + and norm(filters.get("test_type")) == norm(test_type) + and norm(filters.get("flagged")) == norm(flagged) + and current_page == 0 + ): + return Router().set_query_params({**filters, "page": "0"}) def on_page_changed(payload: dict) -> None: new_page = payload.get("page", 0) new_page_size = payload.get("page_size") + if new_page == current_page and (new_page_size is None or int(new_page_size) == current_page_size): + return params: dict = {"page": str(new_page)} if new_page_size is not None: params["page_size"] = str(int(new_page_size)) @@ -557,6 +603,8 @@ def on_sort_changed(payload: dict) -> None: order = col.get("order", "asc") sort_parts.append(f"{field}:{order}") sort_value = ",".join(sort_parts) if sort_parts else None + if sort_value == sort and current_page == 0: + return Router().set_query_params({"sort": sort_value, "page": "0"}) testgen.test_definitions_widget( @@ -567,7 +615,7 @@ def on_sort_changed(payload: dict) -> None: "test_suite": test_suite.test_suite, "project_code": project_code, }, - "test_definitions": json.loads(df.to_json(orient="records", date_unit="s")), + "test_definitions": dataframe_to_json_records(df), "filter_options": { "tables": table_options, "columns": 
columns_raw, @@ -679,7 +727,7 @@ def get_excel_report_data( if data is not None: data = data.copy() else: - data = get_test_definitions(test_suite) + data, _ = get_test_definitions(test_suite) for key in ["test_active_display", "lock_refresh_display", "flagged_display"]: data[key] = data[key].apply(lambda val: val if val == "Yes" else None) @@ -759,7 +807,7 @@ def run_test_type_lookup_query(test_type: str | None = None) -> pd.DataFrame: @st.cache_data(show_spinner=False) def get_test_suite_columns(test_suite_id: str) -> pd.DataFrame: - results = TestDefinition.select_minimal_where( + results = select_test_definitions_minimal_where( TestDefinition.test_suite_id == test_suite_id, order_by=(asc(func.lower(TestDefinition.table_name)), asc(func.lower(TestDefinition.column_name))), ) @@ -772,10 +820,16 @@ def get_test_definitions( column_name: str | None = None, test_type: str | None = None, sorting_columns: list[tuple] | None = None, - page: int = 0, - page_size: int = 0, + page_index: int | None = None, + page_size: int = 500, flagged_filter: str | None = None, -) -> pd.DataFrame: +) -> tuple[pd.DataFrame, int]: + """Return ``(df, total_count)`` for test definitions matching the given filters. + + When ``page_index`` is provided (0-based), fetches only that page from + the DB using ``select_test_definitions_page()``; otherwise fetches all rows + via ``select_test_definitions_where()``. ``total_count`` is always the full matching count. + """ clauses = [TestDefinition.test_suite_id == test_suite.id] if table_name: clauses.append(TestDefinition.table_name == table_name) @@ -803,16 +857,18 @@ def get_test_definitions( else: order_by.append(sort_funcs[direction](func.lower(getattr(TestDefinition, attribute)))) - # For pagination, we need to bypass the base select_where which doesn't support offset/limit. - # We'll fetch all matching results and slice in Python. 
- test_definitions = TestDefinition.select_where( - *clauses, - order_by=tuple(order_by) if order_by else None, - ) + order_by_tuple = tuple(order_by) if order_by else None - if page_size > 0: - offset = page * page_size - test_definitions = list(test_definitions)[offset:offset + page_size] + if page_index is not None: + test_definitions, total_count = select_test_definitions_page( + *clauses, + order_by=order_by_tuple, + page=page_index + 1, + limit=page_size, + ) + else: + test_definitions = select_test_definitions_where(*clauses, order_by=order_by_tuple) + total_count = len(test_definitions) df = to_dataframe(test_definitions, TestDefinitionSummary.columns()) date_service.accommodate_dataframe_to_timezone(df, st.session_state) @@ -844,37 +900,7 @@ def get_export_to_observability_display(value: str) -> str: for col in df.select_dtypes(include=["datetime"]).columns: df[col] = df[col].astype(str).replace("NaT", "") - return df - - -def get_test_definitions_count( - test_suite: TestSuite, - table_name: str | None = None, - column_name: str | None = None, - test_type: str | None = None, - flagged_filter: str | None = None, -) -> int: - from testgen.ui.services.database_service import fetch_one_from_db - - where_parts = ["test_suite_id = :test_suite_id"] - params: dict = {"test_suite_id": str(test_suite.id)} - if table_name: - where_parts.append("table_name = :table_name") - params["table_name"] = table_name - if column_name: - where_parts.append("column_name ILIKE :column_name") - params["column_name"] = column_name - if test_type: - where_parts.append("test_type = :test_type") - params["test_type"] = test_type - if flagged_filter == "Flagged": - where_parts.append("flagged = true") - elif flagged_filter == "Not Flagged": - where_parts.append("flagged = false") - - query = f"SELECT COUNT(*) as cnt FROM test_definitions WHERE {' AND '.join(where_parts)};" - result = fetch_one_from_db(query, params) - return int(result["cnt"]) if result else 0 + return df, total_count 
def get_test_definition_ids( @@ -895,7 +921,7 @@ def get_test_definition_ids( clauses.append(TestDefinition.flagged == True) elif flagged_filter == "Not Flagged": clauses.append(TestDefinition.flagged == False) - results = TestDefinition.select_where(*clauses) + results = select_test_definitions_where(*clauses) return [str(r.id) for r in results] @@ -916,7 +942,7 @@ def get_test_definitions_collision( for item in test_definitions if item["column_name"] is not None ] - results = TestDefinition.select_minimal_where( + results = select_test_definitions_minimal_where( TestDefinition.table_groups_id == target_table_group_id, TestDefinition.test_suite_id == target_test_suite_id, TestDefinition.last_auto_gen_date.isnot(None), @@ -947,13 +973,13 @@ def get_columns(table_groups_id: str) -> list[dict]: def validate_test(test_definition: dict, table_group: TableGroupMinimal) -> None: schema = test_definition["schema_name"] table_name = test_definition["table_name"] - connection = Connection.get(table_group.connection_id) + connection = get_connection(table_group.connection_id) if test_definition["test_type"] == "Condition_Flag": condition = test_definition["custom_query"] flavor_service = get_flavor_service(connection.sql_flavor) concat_operator = flavor_service.concat_operator - quote = flavor_service.quote_character + table_ref = flavor_service.get_table_ref(schema, table_name) query = f""" SELECT COALESCE( @@ -965,17 +991,8 @@ def validate_test(test_definition: dict, table_group: TableGroupMinimal) -> None {concat_operator} '|', '|' ) - FROM {quote}{schema}{quote}.{quote}{table_name}{quote}; + FROM {table_ref}; """ + fetch_from_target_db(connection, query) else: - query = replace_params( - f""" - SELECT COUNT(*) - FROM ( - {test_definition["custom_query"]} - ) TEST - """, - {"DATA_SCHEMA": schema}, - ) - - fetch_from_target_db(connection, query) + validate_custom_query(connection, schema, test_definition["custom_query"]) diff --git a/testgen/ui/views/test_results.py 
b/testgen/ui/views/test_results.py index 014e4182..fd4d7b2c 100644 --- a/testgen/ui/views/test_results.py +++ b/testgen/ui/views/test_results.py @@ -1,4 +1,3 @@ -import json import typing from io import BytesIO from itertools import zip_longest @@ -11,10 +10,8 @@ from testgen.common import date_service from testgen.common.mixpanel_service import MixpanelService from testgen.common.models import with_database_session -from testgen.common.models.table_group import TableGroup from testgen.common.models.test_definition import TestDefinition, TestDefinitionNote, TestDefinitionSummary -from testgen.common.models.test_run import TestRun -from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal +from testgen.common.models.test_suite import TestSuiteMinimal from testgen.common.pii_masking import get_pii_columns, mask_profiling_pii from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( @@ -35,9 +32,17 @@ get_test_issue_source_query_custom, ) from testgen.ui.services.database_service import execute_db_query, fetch_df_from_db, fetch_one_from_db +from testgen.ui.services.query_cache import ( + get_table_group_minimal, + get_test_definition, + get_test_run_minimal, + get_test_suite, + get_test_suite_minimal, + select_test_definitions_where, +) from testgen.ui.services.string_service import snake_case_to_title_case from testgen.ui.session import session -from testgen.utils import friendly_score, make_json_safe +from testgen.utils import dataframe_to_json_records, friendly_score, make_json_safe PAGE_PATH = "test-runs:results" PAGE_SIZE = 500 @@ -134,7 +139,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - run = TestRun.get_minimal(run_id) + run = get_test_run_minimal(run_id) if not run: self.router.navigate_with_warning( f"Test run with ID '{run_id}' does not exist. 
Redirecting to list of Test Runs ...", @@ -166,7 +171,7 @@ def render( # Handle deferred export/issue report (still use st.dialog for file downloads) export_filters = st.session_state.pop(EXPORT_FILTERS_KEY, None) if export_filters is not None: - test_suite = TestSuite.get_minimal(run.test_suite_id) + test_suite = get_test_suite_minimal(run.test_suite_id) _handle_export(export_filters, run_id, run_date, test_suite) issue_report_data = st.session_state.pop(ISSUE_REPORT_KEY, None) @@ -216,9 +221,9 @@ def render( filter_options = test_result_queries.get_filter_options(run_id) - test_suite = TestSuite.get_minimal(run.test_suite_id) + test_suite = get_test_suite_minimal(run.test_suite_id) - items = json.loads(df.to_json(orient="records", date_unit="s")) + items = dataframe_to_json_records(df) summary = get_test_result_summary(run_id) score = friendly_score(run.dq_score_test_run) or "--" @@ -227,7 +232,7 @@ def render( if selected and (selected_item is None or selected_item.get("test_result_id") != selected): row_df = df[df["test_result_id"] == selected] if not row_df.empty: - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] selected_item = build_selected_item_data(row, test_suite) st.session_state[SELECTED_ITEM_KEY] = selected_item elif not selected: @@ -249,7 +254,7 @@ def on_row_selected(item_id: str) -> None: row_df = df[df["test_result_id"] == item_id] if row_df.empty: return - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] item_data = build_selected_item_data(row, test_suite) st.session_state[SELECTED_ITEM_KEY] = item_data Router().set_query_params({"selected": item_id}) @@ -364,7 +369,7 @@ def on_notes_dialog_closed(*_) -> None: def on_source_data_clicked(item_id: str) -> None: result_df = test_result_queries.get_test_results_by_ids([item_id]) if not result_df.empty: - row = json.loads(result_df.to_json(orient="records", 
date_unit="s"))[0] + row = result_df.where(result_df.notna(), None).to_dict(orient="records")[0] MixpanelService().send_event("view-source-data", page=PAGE_PATH, test_type=row.get("test_name_short")) mask_pii = not session.auth.user_has_permission("view_pii") st.session_state[SOURCE_DATA_KEY] = _build_source_data(row, mask_pii=mask_pii) @@ -412,7 +417,7 @@ def on_edit_test_saved(test_def: dict) -> None: def on_validate_test(test_def: dict) -> None: from testgen.ui.views.test_definitions import validate_test - table_group = TableGroup.get_minimal(test_suite.table_groups_id) + table_group = get_table_group_minimal(test_suite.table_groups_id) try: validate_test(test_def, table_group) st.session_state[VALIDATE_RESULT_KEY] = {"success": True, "message": "Validation is successful."} @@ -434,7 +439,7 @@ def on_issue_report_clicked(payload: dict) -> None: result_df = test_result_queries.get_test_results_by_ids(ids) if result_df.empty: return - rows = json.loads(result_df.to_json(orient="records", date_unit="s")) + rows = result_df.where(result_df.notna(), None).to_dict(orient="records") MixpanelService().send_event("download-issue-report", page=PAGE_PATH, issue_count=len(rows)) st.session_state[ISSUE_REPORT_KEY] = rows @@ -550,12 +555,12 @@ def _build_edit_test_dialog_data(test_definition_id: str | None, test_suite_mini from testgen.ui.views.test_definitions import get_columns, run_test_type_lookup_query - test_def = TestDefinition.select_where(TestDefinition.id == test_definition_id) + test_def = select_test_definitions_where(TestDefinition.id == test_definition_id) if not test_def: return None - full_test_suite = TestSuite.get(test_suite_minimal.id) - table_group = TableGroup.get_minimal(test_suite_minimal.table_groups_id) + full_test_suite = get_test_suite(test_suite_minimal.id) + table_group = get_table_group_minimal(test_suite_minimal.table_groups_id) test_def_row = test_def[0] test_def_dict = {col: getattr(test_def_row, col) for col in 
TestDefinitionSummary.columns()} for key in ["id", "table_groups_id", "profile_run_id", "test_suite_id"]: @@ -654,7 +659,7 @@ def build_selected_item_data(row: dict, test_suite: TestSuiteMinimal) -> dict: dfh = test_result_queries.get_test_result_history(row) time_columns = ["test_date"] date_service.accommodate_dataframe_to_timezone(dfh, st.session_state, time_columns) - history = json.loads(dfh.to_json(orient="records", date_unit="s")) + history = dataframe_to_json_records(dfh) test_definition = _build_test_definition_data(row.get("test_definition_id"), test_suite) @@ -672,7 +677,7 @@ def readable_boolean(v: bool) -> str: if not test_definition_id: return None - test_definition = TestDefinition.get(test_definition_id) + test_definition = get_test_definition(test_definition_id) if not test_definition: return None @@ -730,8 +735,7 @@ def readable_boolean(v: bool) -> str: def _handle_export(export_filters: dict, run_id: str, run_date: str, test_suite: TestSuiteMinimal) -> None: - from testgen.common.models.table_group import TableGroup - table_group = TableGroup.get_minimal(test_suite.table_groups_id) + table_group = get_table_group_minimal(test_suite.table_groups_id) export_type = export_filters.get("type", "all") with st.spinner("Loading data ..."): diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py index b53a0d48..f6563b94 100644 --- a/testgen/ui/views/test_runs.py +++ b/testgen/ui/views/test_runs.py @@ -6,8 +6,9 @@ import streamlit as st import testgen.ui.services.form_service as fm +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, get_current_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import ( TestRunNotificationSettings, TestRunNotificationTrigger, @@ -20,7 +21,13 @@ from testgen.ui.navigation.menu import 
MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_project_summary, get_test_run_summaries +from testgen.ui.services.query_cache import ( + get_project_summary, + get_test_run_summaries, + select_table_groups_minimal_where, + select_test_runs_where, + select_test_suites_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.manage_notifications import NotificationSettingsDialogBase from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog @@ -60,8 +67,8 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit with st.spinner("Loading data ..."): project_summary = get_project_summary(project_code) test_runs, total_count = get_test_run_summaries(project_code, table_group_id, test_suite_id, page=page) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) - test_suites = TestSuite.select_minimal_where(TestSuite.project_code == project_code, TestSuite.is_monitor.isnot(True)) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) + test_suites = select_test_suites_minimal_where(TestSuite.project_code == project_code, TestSuite.is_monitor.isnot(True)) def on_run_tests_clicked(*_) -> None: st.session_state[TR_RUN_TESTS_DIALOG_KEY] = True @@ -108,7 +115,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -239,7 +246,7 @@ def _model_to_item_attrs(self, model: TestRunNotificationSettings) -> dict[str, def _get_component_props(self) -> dict[str, Any]: test_suite_options = [ (str(ts.id), ts.test_suite) - for ts in TestSuite.select_minimal_where( + for ts in select_test_suites_minimal_where( TestSuite.project_code == self.ns_attrs["project_code"], 
TestSuite.is_monitor.isnot(True), ) @@ -267,7 +274,7 @@ class TestRunScheduleDialog(ScheduleDialog): test_suites: Iterable[TestSuiteMinimal] | None = None def init(self) -> None: - self.test_suites = TestSuite.select_minimal_where( + self.test_suites = select_test_suites_minimal_where( TestSuite.project_code == self.project_code, TestSuite.is_monitor.isnot(True), ) @@ -281,8 +288,8 @@ def get_arg_value_options(self) -> list[dict[str, str]]: for test_suite in self.test_suites ] - def get_job_arguments(self, arg_value: str) -> tuple[list[typing.Any], dict[str, typing.Any]]: - return [], {"test_suite_id": str(arg_value)} + def get_job_arguments(self, arg_value: str) -> dict[str, typing.Any]: + return {"test_suite_id": str(arg_value)} @with_database_session @@ -312,7 +319,7 @@ def on_delete_runs(job_execution_ids: list[str]) -> None: continue if job_exec.status in (JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED): job_exec.request_cancel() - test_run = next(iter(TestRun.select_where(TestRun.job_execution_id == je_id)), None) + test_run = next(iter(select_test_runs_where(TestRun.job_execution_id == je_id)), None) if test_run: TestRun.cascade_delete([str(test_run.id)]) get_current_session().delete(job_exec) diff --git a/testgen/ui/views/test_suites.py b/testgen/ui/views/test_suites.py index 605774f6..d1797f02 100644 --- a/testgen/ui/views/test_suites.py +++ b/testgen/ui/views/test_suites.py @@ -4,6 +4,7 @@ from testgen.commands.run_observability_exporter import export_test_results from testgen.commands.test_generation import run_test_generation +from testgen.common.enums import JobSource from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import TestRunNotificationSettings @@ -14,7 +15,14 @@ from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from 
testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_project_summary, get_test_suite_summaries +from testgen.ui.services.query_cache import ( + get_project_summary, + get_table_group, + get_test_suite, + get_test_suite_minimal, + get_test_suite_summaries, + select_table_groups_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.generate_tests_dialog import ( get_generation_set_choices, @@ -56,7 +64,7 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit "manage-test-suites", ) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) user_can_edit = session.auth.user_has_permission("edit") test_suites = get_test_suite_summaries(project_code, table_group_id, test_suite_name) project_summary = get_project_summary(project_code) @@ -76,7 +84,7 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit "result": st.session_state.get("ts_form_dialog:result"), } elif edit_ts_id := st.session_state.get(EDIT_DIALOG_KEY): - selected = TestSuite.get(edit_ts_id) + selected = get_test_suite(edit_ts_id) form_dialog = { "open": True, "mode": "edit", @@ -132,7 +140,7 @@ def on_run_notifications_clicked(*_) -> None: generate_tests_data = None if generate_tests_ts_id := st.session_state.get(GENERATE_TESTS_DIALOG_KEY): - generate_ts = TestSuite.get_minimal(generate_tests_ts_id) + generate_ts = get_test_suite_minimal(generate_tests_ts_id) generation_sets = get_generation_set_choices() default_set = "Standard" if "Standard" in generation_sets else (generation_sets[0] if generation_sets else "") test_ct, unlocked_test_ct, unlocked_edits_ct = get_test_suite_refresh_warning(str(generate_ts.id)) @@ -173,7 +181,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": 
str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -314,7 +322,7 @@ def save_test_suite_form(data: dict) -> None: if mode == "edit": test_suite_id = data.get("test_suite_id") - test_suite = TestSuite.get(test_suite_id) + test_suite = get_test_suite(test_suite_id) test_suite.test_suite_description = data.get("test_suite_description", "") test_suite.severity = data.get("severity") test_suite.export_to_observability = data.get("export_to_observability", False) @@ -328,7 +336,7 @@ def save_test_suite_form(data: dict) -> None: get_test_suite_summaries.clear() st.session_state[PAGE_RESULT_KEY] = {"success": True, "message": "Changes have been saved successfully."} else: - table_group = TableGroup.get(data.get("table_groups_id")) + table_group = get_table_group(data.get("table_groups_id")) test_suite = TestSuite() test_suite.project_code = table_group.project_code test_suite.test_suite = data.get("test_suite") @@ -350,7 +358,7 @@ def save_test_suite_form(data: dict) -> None: @with_database_session def prepare_ts_delete_dialog(test_suite_id: str) -> None: - selected = TestSuite.get_minimal(test_suite_id) + selected = get_test_suite_minimal(test_suite_id) is_in_use = TestSuite.is_in_use([selected.id]) st.session_state["ts_delete_dialog"] = { "open": True, @@ -374,7 +382,7 @@ def execute_ts_delete(test_suite_id: str) -> None: @with_database_session def observability_export_action(test_suite_id: str) -> None: - selected_test_suite = TestSuite.get_minimal(test_suite_id) + selected_test_suite = get_test_suite_minimal(test_suite_id) try: qty_of_exported_events = export_test_results(selected_test_suite.id) st.session_state[PAGE_RESULT_KEY] = {"success": True, "message": f"Export finished: {qty_of_exported_events} events exported."} diff --git a/testgen/utils/__init__.py b/testgen/utils/__init__.py index 7f3b71d5..5218d3d0 100644 --- a/testgen/utils/__init__.py +++ b/testgen/utils/__init__.py @@ -101,6 +101,8 @@ def 
make_json_safe(value: Any) -> str | bool | int | float | None: elif isinstance(value, UUID): return str(value) elif isinstance(value, datetime): + if value != value: # NaT (and other nan-like datetimes) are never equal to themselves + return None return int(value.replace(tzinfo=UTC).timestamp()) elif isinstance(value, date): return value.isoformat() @@ -115,6 +117,17 @@ def make_json_safe(value: Any) -> str | bool | int | float | None: return value +def dataframe_to_json_records(df: pd.DataFrame) -> list[dict]: + """Convert a DataFrame to JSON-safe records, one dict per row. + + Routes every cell through make_json_safe rather than DataFrame.to_json. to_json forces datetime values + through pandas' nanosecond Timestamp, which raises OverflowError on dates outside 1677-09-21..2262-04-11 + (e.g. the year-9999 / year-1 sentinel dates that SQL Server date/datetime2 columns commonly carry). + make_json_safe handles native datetimes via timedelta arithmetic, so any in-range datetime is unaffected. 
+ """ + return [{key: make_json_safe(value) for key, value in record.items()} for record in df.to_dict(orient="records")] + + def chunk_queries(queries: list[str], join_string: str, max_query_length: int) -> list[str]: full_query = join_string.join(queries) if len(full_query) <= max_query_length: diff --git a/tests/unit/api/test_jobs.py b/tests/unit/api/test_jobs.py index 18d260a4..6af04eb4 100644 --- a/tests/unit/api/test_jobs.py +++ b/tests/unit/api/test_jobs.py @@ -205,7 +205,7 @@ def test_list_jobs_empty_project(mock_je_cls): def _client_with_overrides() -> TestClient: """Build a TestClient that bypasses auth and db_session so query validation runs unimpeded.""" app = FastAPI() - app.include_router(router) + app.include_router(router, prefix="/api/v1") app.dependency_overrides[db_session] = lambda: iter([None]) app.dependency_overrides[get_authorized_user] = lambda: MagicMock(id=uuid4()) return app diff --git a/tests/unit/commands/queries/test_execute_tests_query.py b/tests/unit/commands/queries/test_execute_tests_query.py index 71fb66dd..4839f99b 100644 --- a/tests/unit/commands/queries/test_execute_tests_query.py +++ b/tests/unit/commands/queries/test_execute_tests_query.py @@ -1,10 +1,12 @@ from datetime import UTC, datetime +from unittest.mock import patch from uuid import uuid4 import pytest from testgen.commands.queries.execute_tests_query import ( TestExecutionDef, + TestExecutionSQL, build_cat_expressions, group_cat_tests, parse_cat_results, @@ -359,3 +361,120 @@ def test_parse_result_code_negative_one(): rows = parse_cat_results(results, test_defs, uuid4(), uuid4(), datetime.now(UTC), _make_input_params_fn()) assert rows[0][10] == "-1" + + +# --- TestExecutionSQL freshness-gating helpers --- + + +def _make_execution_sql() -> TestExecutionSQL: + """Build a minimal TestExecutionSQL instance for testing instance methods. + + Bypasses __init__ (which hits the database) and sets only the attributes the + freshness-gating methods touch. 
+ """ + instance = TestExecutionSQL.__new__(TestExecutionSQL) + instance._freshness_changed_cache = {} + return instance + + +FRESHNESS_FETCH_TARGET = "testgen.commands.queries.execute_tests_query.fetch_dict_from_db" + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_true_when_result_signal_is_zero(mock_fetch, _mock_query): + mock_fetch.return_value = [{"result_signal": "0"}] + instance = _make_execution_sql() + assert instance._freshness_changed_for_table(_make_td()) is True + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_false_when_result_signal_is_interval(mock_fetch, _mock_query): + mock_fetch.return_value = [{"result_signal": "1440"}] + instance = _make_execution_sql() + assert instance._freshness_changed_for_table(_make_td()) is False + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_none_when_no_result(mock_fetch, _mock_query): + mock_fetch.return_value = [] + instance = _make_execution_sql() + assert instance._freshness_changed_for_table(_make_td()) is None + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_cached_per_table(mock_fetch, _mock_query): + """Multiple Volume/Metric defs on the same table should not re-query.""" + mock_fetch.return_value = [{"result_signal": "0"}] + instance = _make_execution_sql() + instance._freshness_changed_for_table(_make_td(schema_name="s", table_name="t")) + instance._freshness_changed_for_table(_make_td(schema_name="s", table_name="t")) + assert mock_fetch.call_count == 1 + + +def test_resolve_cat_returns_definition_default_for_non_monitor_types(): + instance = _make_execution_sql() + td = _make_td(test_type="Alpha_Trunc", test_operator=">=", 
test_condition="50") + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert (operator, condition) == (">=", "50") + + +def test_resolve_cat_returns_definition_default_when_no_gating(): + """Volume_Trend / Metric_Trend with no freshness_gated flag in prediction → band check.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Volume_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"mean": {"123": 220.0}}, # no freshness_gated + ) + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert operator == "NOT BETWEEN" + assert condition == "{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}" + + +@patch.object(TestExecutionSQL, "_freshness_changed_for_table", return_value=False) +def test_resolve_cat_stale_period_overrides_to_baseline_equality(_mock_changed): + """When freshness-gated and Freshness signal != '0' (no change), override to <> baseline.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Volume_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"freshness_gated": True, "baseline_value": 220.0}, + ) + assert instance._resolve_cat_operator_and_condition(td) == ("<>", "220.0") + + +@patch.object(TestExecutionSQL, "_freshness_changed_for_table", return_value=True) +def test_resolve_cat_refresh_period_uses_band_check(_mock_changed): + """When freshness-gated and Freshness fired this run, fall through to band check.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Volume_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"freshness_gated": True, "baseline_value": 220.0}, + ) + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert operator == "NOT BETWEEN" + assert condition == "{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}" + + +@patch.object(TestExecutionSQL, 
"_freshness_changed_for_table", return_value=None) +def test_resolve_cat_no_freshness_result_uses_band_check(_mock_changed): + """When no Freshness_Trend has run for this table this run, fall back to band check.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Metric_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"freshness_gated": True, "baseline_value": 5.5}, + ) + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert operator == "NOT BETWEEN" + + diff --git a/tests/unit/commands/queries/test_refresh_data_chars_query.py b/tests/unit/commands/queries/test_refresh_data_chars_query.py index 9118d586..5dc3ff16 100644 --- a/tests/unit/commands/queries/test_refresh_data_chars_query.py +++ b/tests/unit/commands/queries/test_refresh_data_chars_query.py @@ -1,12 +1,40 @@ import pytest from testgen.commands.queries.refresh_data_chars_query import RefreshDataCharsSQL +from testgen.common.database.column_chars import ColumnChars from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup pytestmark = pytest.mark.unit +def _make_columns(*table_names: str) -> list[ColumnChars]: + return [ + ColumnChars(schema_name="default", table_name=name, column_name="id") + for name in table_names + ] + + +@pytest.mark.parametrize( + "flavor,expected_sql", + [ + ("postgresql", 'SELECT 1 FROM "test_schema"."orders" LIMIT 1'), + ("mssql", 'SELECT TOP 1 1 FROM "test_schema"."orders"'), + ("oracle", 'SELECT 1 FROM "test_schema"."orders" FETCH FIRST 1 ROWS ONLY'), + ], +) +def test_verify_access_uses_literal_1_projection(flavor, expected_sql): + """Access check uses literal ``1`` (not ``*``) — projection doesn't matter for an + existence/permission probe, and ``1`` avoids materialising columns on wide tables.""" + connection = Connection(sql_flavor=flavor) + table_group = TableGroup(table_group_schema="test_schema") + sql_generator = 
RefreshDataCharsSQL(connection, table_group) + + query, _ = sql_generator.verify_access("orders") + + assert query == expected_sql + + def test_include_exclude_mask_basic(): connection = Connection(sql_flavor="postgresql") table_group = TableGroup( @@ -107,3 +135,85 @@ def test_table_set_with_include_exclude(): assert "LIKE 'important%'" in criteria assert "AND NOT" in criteria assert "LIKE 'temp%'" in criteria + + +def test_filter_schema_columns_table_set(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="users, orders", + profiling_include_mask="", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("users", "orders", "products", "logs") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"users", "orders"} + + +def test_filter_schema_columns_include_mask(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="party_%, summary", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("party_planners", "party_transactions", "summary", "audit_log") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"party_planners", "party_transactions", "summary"} + + +def test_filter_schema_columns_exclude_mask(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="", + profiling_exclude_mask="tmp_%, raw_log", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("users", "tmp_x", "tmp_y", "raw_log", "orders") + + filtered = sql_generator.filter_schema_columns(columns) + + 
assert {c.table_name for c in filtered} == {"users", "orders"} + + +def test_filter_schema_columns_underscore_is_literal(): + # SQL LIKE _ wildcard semantics: the existing SQL path escapes user `_` to `\_`, + # treating `_` as a literal. The Python filter must match that behavior. + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="a_b", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("a_b", "axb", "axxb") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"a_b"} + + +def test_filter_schema_columns_no_filters_returns_all(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("users", "orders") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"users", "orders"} diff --git a/tests/unit/commands/test_exec_job.py b/tests/unit/commands/test_exec_job.py index 21ae898a..56df805f 100644 --- a/tests/unit/commands/test_exec_job.py +++ b/tests/unit/commands/test_exec_job.py @@ -4,7 +4,8 @@ import pytest from testgen.commands.exec_job import exec_job -from testgen.commands.job_registry import JOB_DISPATCH, JOB_FINAL_CALLBACKS, run_final_callbacks +from testgen.commands.job_registry import JOB_DISPATCH, JOB_FINAL_CALLBACKS, JobConfig, run_final_callbacks +from testgen.common.enums import JobKey from testgen.common.models.job_execution import JobExecution pytestmark = pytest.mark.unit @@ -19,7 +20,7 @@ def mock_session(): yield session -def _make_job_exec(job_key="run-tests", status="claimed", **kwargs): +def 
_make_job_exec(job_key=JobKey.run_tests, status="claimed", **kwargs): job = MagicMock(spec=JobExecution) job.id = uuid4() job.job_key = job_key @@ -31,13 +32,13 @@ def _make_job_exec(job_key="run-tests", status="claimed", **kwargs): def test_exec_job_dispatches_run_tests(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -47,14 +48,14 @@ def test_exec_job_dispatches_run_tests(mock_session): def test_exec_job_dispatches_run_profile(mock_session): - job = _make_job_exec(job_key="run-profile") + job = _make_job_exec(job_key=JobKey.run_profile) job.kwargs = {"table_group_id": "tg-123"} job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-profile": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_profile: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -63,13 +64,13 @@ def test_exec_job_dispatches_run_profile(mock_session): def test_exec_job_dispatches_run_monitors(mock_session): - job = _make_job_exec(job_key="run-monitors") + job = _make_job_exec(job_key=JobKey.run_monitors) job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-monitors": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_monitors: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -77,14 +78,14 @@ def test_exec_job_dispatches_run_monitors(mock_session): def test_exec_job_dispatches_run_test_generation(mock_session): - job = _make_job_exec(job_key="run-test-generation") + job = 
_make_job_exec(job_key=JobKey.run_test_generation) job.kwargs = {"test_suite_id": "suite-123", "generation_set": "Standard"} job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-test-generation": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_test_generation: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -103,7 +104,7 @@ def test_exec_job_marks_interrupted_on_unknown_key(mock_session): def test_exec_job_skips_when_mark_running_fails(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = False with patch.object(JobExecution, "get", return_value=job): @@ -113,12 +114,12 @@ def test_exec_job_skips_when_mark_running_fails(mock_session): def test_exec_job_marks_interrupted_on_dispatch_error(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(side_effect=RuntimeError("boom"))}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(side_effect=RuntimeError("boom")))}), ): exec_job(job.id) @@ -136,24 +137,24 @@ def test_exec_job_exits_on_missing_record(mock_session): def test_job_dispatch_has_all_job_keys(): - assert "run-profile" in JOB_DISPATCH - assert "run-tests" in JOB_DISPATCH - assert "run-monitors" in JOB_DISPATCH - assert "run-test-generation" in JOB_DISPATCH - assert "run-score-update" in JOB_DISPATCH - assert "recalculate-project-scores" in JOB_DISPATCH + assert JobKey.run_profile in JOB_DISPATCH + assert JobKey.run_tests in JOB_DISPATCH + assert JobKey.run_monitors in JOB_DISPATCH + assert JobKey.run_test_generation in JOB_DISPATCH + assert JobKey.run_score_update in JOB_DISPATCH + assert JobKey.recalculate_project_scores in 
JOB_DISPATCH def test_exec_job_fires_final_callbacks_on_success(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_completed.return_value = True cb1, cb2 = Mock(), Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(return_value="ok")}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb1, cb2]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(return_value="ok"))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb1, cb2]}), ): exec_job(job.id) @@ -162,7 +163,7 @@ def test_exec_job_fires_final_callbacks_on_success(mock_session): def test_exec_job_runs_callbacks_in_registered_order(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_completed.return_value = True order = [] @@ -171,8 +172,8 @@ def test_exec_job_runs_callbacks_in_registered_order(mock_session): with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(return_value="ok")}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb1, cb2]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(return_value="ok"))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb1, cb2]}), ): exec_job(job.id) @@ -180,15 +181,15 @@ def test_exec_job_runs_callbacks_in_registered_order(mock_session): def test_exec_job_skips_callbacks_when_mark_completed_fails(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_completed.return_value = False cb = Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(return_value="ok")}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb]}), + patch.dict(JOB_DISPATCH, 
{JobKey.run_tests: JobConfig(handler=Mock(return_value="ok"))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb]}), ): exec_job(job.id) @@ -196,15 +197,15 @@ def test_exec_job_skips_callbacks_when_mark_completed_fails(mock_session): def test_exec_job_fires_callbacks_on_interrupted(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_interrupted.return_value = True cb = Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(side_effect=RuntimeError("boom"))}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(side_effect=RuntimeError("boom")))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb]}), ): exec_job(job.id) @@ -212,15 +213,15 @@ def test_exec_job_fires_callbacks_on_interrupted(mock_session): def test_exec_job_skips_callbacks_when_mark_interrupted_fails(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_interrupted.return_value = False cb = Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(side_effect=RuntimeError("boom"))}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(side_effect=RuntimeError("boom")))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb]}), ): exec_job(job.id) @@ -228,11 +229,11 @@ def test_exec_job_skips_callbacks_when_mark_interrupted_fails(mock_session): def test_run_final_callbacks_isolates_failures(): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) failing = Mock(side_effect=RuntimeError("boom"), __name__="failing_cb") succeeding = Mock(__name__="succeeding_cb") - with 
patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [failing, succeeding]}): + with patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [failing, succeeding]}): run_final_callbacks(job) failing.assert_called_once_with(job) @@ -247,6 +248,6 @@ def test_run_final_callbacks_noop_for_unknown_job_key(): def test_registered_callbacks_cover_notification_job_keys(): - assert "run-profile" in JOB_FINAL_CALLBACKS - assert "run-tests" in JOB_FINAL_CALLBACKS - assert "run-monitors" in JOB_FINAL_CALLBACKS + assert JobKey.run_profile in JOB_FINAL_CALLBACKS + assert JobKey.run_tests in JOB_FINAL_CALLBACKS + assert JobKey.run_monitors in JOB_FINAL_CALLBACKS diff --git a/tests/unit/commands/test_job_runner.py b/tests/unit/commands/test_job_runner.py index 3ac4ffa5..dadc66d2 100644 --- a/tests/unit/commands/test_job_runner.py +++ b/tests/unit/commands/test_job_runner.py @@ -4,7 +4,8 @@ import pytest from testgen.commands.job_runner import submit_and_wait -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.enums import JobStatus +from testgen.common.models.job_execution import JobExecution pytestmark = pytest.mark.unit diff --git a/tests/unit/commands/test_run_data_cleanup.py b/tests/unit/commands/test_run_data_cleanup.py new file mode 100644 index 00000000..6c6e4304 --- /dev/null +++ b/tests/unit/commands/test_run_data_cleanup.py @@ -0,0 +1,265 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, Mock, patch +from uuid import uuid4 + +import pytest + +from testgen.commands.run_data_cleanup import BATCH_SIZE, run_data_cleanup + +pytestmark = pytest.mark.unit + +MODULE = "testgen.commands.run_data_cleanup" + + +def _db_ctx(): + """Mock database_session() that yields nothing useful — the orchestrator's + nested with-blocks just need the context manager to enter/exit cleanly.""" + ctx = MagicMock() + ctx.__enter__ = Mock(return_value=MagicMock()) + ctx.__exit__ = Mock(return_value=False) + return ctx + + +def 
_patch_orchestrator( + protected_profiling: set | None = None, + protected_tests: set | None = None, + protected_profiling_jes: set | None = None, + protected_test_jes: set | None = None, + protected_history_keys: set | None = None, + deleted_profiling: int = 0, + deleted_tests: int = 0, + deleted_job_executions: int = 0, + deleted_score_history: int = 0, + deleted_score_latest: int = 0, + deleted_stg: tuple[int, int, int, int] = (0, 0, 0, 0), +): + """One-stop helper: patches every collaborator the orchestrator touches. + + Returns a dict of the patch mocks so individual tests can assert call shape. + """ + patches = { + "database_session": patch(f"{MODULE}.database_session", side_effect=lambda: _db_ctx()), + "ProfilingRun": patch(f"{MODULE}.ProfilingRun"), + "TestRun": patch(f"{MODULE}.TestRun"), + "JobExecution": patch(f"{MODULE}.JobExecution"), + "ScoreHistoryLatestRun": patch(f"{MODULE}.ScoreHistoryLatestRun"), + "ScoreDefinitionResultHistoryEntry": patch(f"{MODULE}.ScoreDefinitionResultHistoryEntry"), + "StgSecondaryProfileUpdate": patch(f"{MODULE}.StgSecondaryProfileUpdate"), + "StgFunctionalTableUpdate": patch(f"{MODULE}.StgFunctionalTableUpdate"), + "StgDataCharsUpdate": patch(f"{MODULE}.StgDataCharsUpdate"), + "StgTestDefinitionUpdate": patch(f"{MODULE}.StgTestDefinitionUpdate"), + } + started = {name: p.start() for name, p in patches.items()} + + started["ProfilingRun"].find_latest_per_table_group.return_value = protected_profiling or set() + # get_job_execution_ids returns dict[run_id, je_id]; orchestrator filters nulls. 
+ started["ProfilingRun"].get_job_execution_ids.return_value = { + uuid4(): je_id for je_id in (protected_profiling_jes or set()) + } + started["ProfilingRun"].delete_older_than.return_value = deleted_profiling + + started["TestRun"].find_latest_per_test_suite.return_value = protected_tests or set() + started["TestRun"].get_job_execution_ids.return_value = { + uuid4(): je_id for je_id in (protected_test_jes or set()) + } + started["TestRun"].delete_older_than.return_value = deleted_tests + + started["JobExecution"].delete_older_than.return_value = deleted_job_executions + + started["ScoreHistoryLatestRun"].find_protected_keys.return_value = protected_history_keys or set() + started["ScoreHistoryLatestRun"].delete_older_than.return_value = deleted_score_latest + started["ScoreDefinitionResultHistoryEntry"].delete_older_than.return_value = deleted_score_history + + started["StgSecondaryProfileUpdate"].delete_older_than.return_value = deleted_stg[0] + started["StgFunctionalTableUpdate"].delete_older_than.return_value = deleted_stg[1] + started["StgDataCharsUpdate"].delete_older_than.return_value = deleted_stg[2] + started["StgTestDefinitionUpdate"].delete_older_than.return_value = deleted_stg[3] + + return started, patches + + +def _stop(patches): + for p in patches.values(): + p.stop() + + +def test_computes_cutoff_from_retention_days(): + """Cutoff passed to delete_older_than is `now - retention_days` (UTC).""" + started, patches = _patch_orchestrator() + try: + before = datetime.now(UTC) + run_data_cleanup(project_code="proj", retention_days=30) + after = datetime.now(UTC) + finally: + _stop(patches) + + cutoff = started["ProfilingRun"].delete_older_than.call_args.kwargs["cutoff"] + expected_low = before - timedelta(days=30) + expected_high = after - timedelta(days=30) + assert expected_low <= cutoff <= expected_high + # Same cutoff threads through every sweep + assert started["TestRun"].delete_older_than.call_args.kwargs["cutoff"] == cutoff + assert 
started["JobExecution"].delete_older_than.call_args.kwargs["cutoff"] == cutoff + + +def test_passes_protected_profiling_ids_to_delete(): + """Latest-run-per-table-group set is computed once and threaded through to + ProfilingRun.delete_older_than as the carve-out.""" + protected = {uuid4(), uuid4(), uuid4()} + started, patches = _patch_orchestrator(protected_profiling=protected) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + started["ProfilingRun"].find_latest_per_table_group.assert_called_once_with("proj") + assert started["ProfilingRun"].delete_older_than.call_args.kwargs["protected_ids"] == protected + + +def test_passes_protected_test_run_ids_to_delete(): + """Latest-run-per-test-suite (incl. monitor suites) threads through to TestRun.delete_older_than.""" + protected = {uuid4(), uuid4()} + started, patches = _patch_orchestrator(protected_tests=protected) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + started["TestRun"].find_latest_per_test_suite.assert_called_once_with("proj") + assert started["TestRun"].delete_older_than.call_args.kwargs["protected_ids"] == protected + + +def test_protected_job_execution_ids_is_union_of_run_je_ids(): + """JobExecution sweep carve-out = union of protected profiling + test run JE ids.""" + profiling_jes = {uuid4(), uuid4()} + test_jes = {uuid4()} + started, patches = _patch_orchestrator( + protected_profiling_jes=profiling_jes, + protected_test_jes=test_jes, + ) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + passed = started["JobExecution"].delete_older_than.call_args.kwargs["protected_ids"] + assert passed == profiling_jes | test_jes + + +def test_score_history_uses_protected_keys_from_latest_runs(): + """find_protected_keys runs once with both run-id sets, and its result feeds + BOTH score-history sweeps (history entries + latest-runs mapping).""" + keys = {(uuid4(), 
datetime(2026, 1, 1)), (uuid4(), datetime(2026, 2, 1))} + profiling_ids = {uuid4()} + test_ids = {uuid4()} + started, patches = _patch_orchestrator( + protected_profiling=profiling_ids, + protected_tests=test_ids, + protected_history_keys=keys, + ) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + started["ScoreHistoryLatestRun"].find_protected_keys.assert_called_once_with( + protected_profiling_ids=profiling_ids, + protected_test_run_ids=test_ids, + ) + assert started["ScoreDefinitionResultHistoryEntry"].delete_older_than.call_args.kwargs["protected_keys"] == keys + assert started["ScoreHistoryLatestRun"].delete_older_than.call_args.kwargs["protected_keys"] == keys + + +def test_staging_sweeps_get_no_carve_out(): + """All 4 staging models receive only cutoff + project_code — no protected_ids + arg (these tables have no per-run linkage).""" + started, patches = _patch_orchestrator() + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + for stg_name in [ + "StgSecondaryProfileUpdate", + "StgFunctionalTableUpdate", + "StgDataCharsUpdate", + "StgTestDefinitionUpdate", + ]: + call = started[stg_name].delete_older_than.call_args + # Positional args only: (cutoff, project_code) + assert len(call.args) == 2 + assert call.args[1] == "proj" + assert "protected_ids" not in call.kwargs + assert "protected_keys" not in call.kwargs + + +def test_batch_size_threaded_through(): + """The orchestrator's BATCH_SIZE constant is passed to every batch-capable sweep.""" + started, patches = _patch_orchestrator() + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + for collaborator, method in [ + ("ProfilingRun", "delete_older_than"), + ("TestRun", "delete_older_than"), + ("JobExecution", "delete_older_than"), + ("ScoreDefinitionResultHistoryEntry", "delete_older_than"), + ("ScoreHistoryLatestRun", "delete_older_than"), + ]: + kwargs = 
getattr(started[collaborator], method).call_args.kwargs + assert kwargs["batch_size"] == BATCH_SIZE, f"{collaborator}.{method} missing batch_size" + + +def test_summary_log_has_all_counts(caplog): + """The trailing summary log line includes the count from every sweep so the + operator can correlate what was deleted in a single grep.""" + import logging + caplog.set_level(logging.INFO, logger="testgen") + + started, patches = _patch_orchestrator( + deleted_profiling=10, + deleted_tests=20, + deleted_job_executions=30, + deleted_score_history=40, + deleted_score_latest=50, + deleted_stg=(1, 2, 3, 4), # sums to 10 + ) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + summary = [r for r in caplog.records if "Data retention cleanup complete" in r.getMessage()] + assert len(summary) == 1 + msg = summary[0].getMessage() + assert "deleted_profiling=10" in msg + assert "deleted_tests=20" in msg + assert "deleted_job_executions=30" in msg + assert "deleted_score_history=40" in msg + assert "deleted_score_latest=50" in msg + assert "deleted_staging=10" in msg # sum of staging counts + + +def test_no_data_to_delete_runs_clean(): + """Empty everywhere: handler completes without error, all sweeps still invoked.""" + started, patches = _patch_orchestrator() + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + # Every sweep was still called (cleanup is unconditional once the schedule fires) + started["ProfilingRun"].delete_older_than.assert_called_once() + started["TestRun"].delete_older_than.assert_called_once() + started["JobExecution"].delete_older_than.assert_called_once() + started["ScoreDefinitionResultHistoryEntry"].delete_older_than.assert_called_once() + started["ScoreHistoryLatestRun"].delete_older_than.assert_called_once() + for stg in [ + "StgSecondaryProfileUpdate", + "StgFunctionalTableUpdate", + "StgDataCharsUpdate", + "StgTestDefinitionUpdate", + ]: + 
started[stg].delete_older_than.assert_called_once() diff --git a/tests/unit/commands/test_score_cards.py b/tests/unit/commands/test_score_cards.py index a537eee3..584d185c 100644 --- a/tests/unit/commands/test_score_cards.py +++ b/tests/unit/commands/test_score_cards.py @@ -1,8 +1,14 @@ +from datetime import UTC, datetime +from unittest.mock import patch from uuid import uuid4 import pytest -from testgen.commands.run_refresh_score_cards_results import _score_card_to_results +from testgen.commands.run_refresh_score_cards_results import ( + _score_card_to_results, + save_and_refresh_score_definition, +) +from testgen.common.models.scores import ScoreDefinition, ScoreDefinitionCriteria pytestmark = pytest.mark.unit @@ -80,3 +86,128 @@ def test_none_score_values(): results = _score_card_to_results(card) for result in results: assert result.score is None + + +# --- save_and_refresh_score_definition --- + + +def _fake_definition(project_code: str = "demo") -> ScoreDefinition: + sd = ScoreDefinition() + sd.id = uuid4() + sd.project_code = project_code + sd.name = "Card" + sd.total_score = True + sd.cde_score = False + sd.category = None + sd.criteria = ScoreDefinitionCriteria.from_filters( + [{"field": "table_groups_name", "value": "tg1"}], + group_by_field=True, + ) + return sd + + +def test_save_and_refresh_score_definition_for_existing_card_calls_save_refresh_and_recalculate(): + """is_new=False path: save → refresh → recalculate, all in that order.""" + sd = _fake_definition() + call_order: list[str] = [] + + def record(name): + def _called(*_a, **_kw): + call_order.append(name) + return _called + + with ( + patch.object(ScoreDefinition, "save", autospec=True, side_effect=record("save")), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + side_effect=record("refresh"), + ), + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + side_effect=record("recalculate"), + ), + ): + 
save_and_refresh_score_definition(sd, is_new=False) + + assert call_order == ["save", "refresh", "recalculate"] + + +def test_save_and_refresh_score_definition_for_existing_card_passes_refresh_kwargs_for_update(): + """Updates (is_new=False) do NOT pass add_history_entry / refresh_date.""" + sd = _fake_definition() + + with ( + patch.object(ScoreDefinition, "save", autospec=True), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + ) as mock_refresh, + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + ), + ): + save_and_refresh_score_definition(sd, is_new=False) + + mock_refresh.assert_called_once_with(definition_id=sd.id) + + +def test_save_and_refresh_score_definition_for_new_card_skips_recalculate(): + """is_new=True path: save → refresh with history kwargs; no recalculate.""" + sd = _fake_definition() + + fake_latest = type("Run", (), {"run_time": datetime(2026, 5, 1, tzinfo=UTC)})() + + with ( + patch.object(ScoreDefinition, "save", autospec=True), + patch( + "testgen.commands.run_refresh_score_cards_results.ProfilingRun.get_latest_run", + return_value=fake_latest, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.TestRun.get_latest_run", + return_value=None, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + ) as mock_refresh, + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + ) as mock_recalc, + ): + save_and_refresh_score_definition(sd, is_new=True) + + mock_refresh.assert_called_once_with( + definition_id=sd.id, + add_history_entry=True, + refresh_date=fake_latest.run_time, + ) + mock_recalc.assert_not_called() + + +def test_save_and_refresh_score_definition_for_new_card_handles_no_runs(): + """When there are no profiling/test runs for the project, refresh_date is None.""" + sd = _fake_definition() + + with ( + patch.object(ScoreDefinition, "save", autospec=True), + 
patch( + "testgen.commands.run_refresh_score_cards_results.ProfilingRun.get_latest_run", + return_value=None, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.TestRun.get_latest_run", + return_value=None, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + ) as mock_refresh, + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + ), + ): + save_and_refresh_score_definition(sd, is_new=True) + + mock_refresh.assert_called_once_with( + definition_id=sd.id, + add_history_entry=True, + refresh_date=None, + ) diff --git a/tests/unit/commands/test_thresholds_prediction.py b/tests/unit/commands/test_thresholds_prediction.py index f9df4592..37891a8c 100644 --- a/tests/unit/commands/test_thresholds_prediction.py +++ b/tests/unit/commands/test_thresholds_prediction.py @@ -1,5 +1,6 @@ import json -from unittest.mock import patch +from datetime import datetime +from unittest.mock import MagicMock, patch import pandas as pd import pytest @@ -8,7 +9,9 @@ from testgen.commands.test_thresholds_prediction import ( T_DISTRIBUTION_THRESHOLD, Z_SCORE_MAP, + TestThresholdsPrediction, compute_sarimax_threshold, + compute_volume_or_metric_threshold, ) from testgen.common.models.test_suite import PredictSensitivity from testgen.common.time_series_service import NotEnoughData @@ -16,6 +19,19 @@ pytestmark = pytest.mark.unit +def _make_prediction_instance(suite_id: str = "suite-xyz") -> TestThresholdsPrediction: + """Build a minimal TestThresholdsPrediction instance for testing instance methods. + + Bypasses __init__ (which queries the database) and sets just the attributes that + _get_query and methods under test rely on. 
+ """ + instance = TestThresholdsPrediction.__new__(TestThresholdsPrediction) + instance.test_suite = MagicMock(id=suite_id) + instance.run_date = datetime(2026, 1, 1) + instance.tz = None + return instance + + def _make_history(n: int, value: float = 100.0) -> pd.DataFrame: """Build a minimal history DataFrame with n data points.""" dates = pd.date_range("2025-01-01", periods=n, freq="D") @@ -31,17 +47,6 @@ def _make_forecast(mean_values: list[float], se_values: list[float]) -> pd.DataF MOCK_TARGET = "testgen.commands.test_thresholds_prediction.get_sarimax_forecast" -# --- min_lookback guard --- - - -def test_below_min_lookback_returns_none(): - history = _make_history(3) - lower, upper, prediction = compute_sarimax_threshold(history, PredictSensitivity.medium, min_lookback=5) - assert lower is None - assert upper is None - assert prediction is None - - # --- Normal tolerance calculation (large sample, z-scores used directly) --- @@ -196,3 +201,199 @@ def test_all_z_score_columns_added_to_forecast(mock_forecast): for key in Z_SCORE_MAP: col = f"{key[0]}|{key[1].value}" assert col in forecast.columns + + +# --- TestThresholdsPrediction._fetch_freshness_updates_by_table --- +# +# Method fetches via _get_query → get_freshness_fingerprint_events.sql, which returns +# rows pre-filtered to fingerprint-change events and ordered by (schema, table, time). +# Tests mock the fetch and verify the indexing. 
+ +FETCH_TARGET = "testgen.commands.test_thresholds_prediction.fetch_dict_from_db" + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_groups_by_table(mock_fetch): + mock_fetch.return_value = [ + {"schema_name": "s", "table_name": "t1", "test_run_id": "run_1"}, + {"schema_name": "s", "table_name": "t1", "test_run_id": "run_2"}, + {"schema_name": "s", "table_name": "t2", "test_run_id": "run_3"}, + ] + instance = _make_prediction_instance() + events = instance._fetch_freshness_updates_by_table() + assert set(events.keys()) == {("s", "t1"), ("s", "t2")} + assert events[("s", "t1")] == ["run_1", "run_2"] + assert events[("s", "t2")] == ["run_3"] + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_preserves_input_order(mock_fetch): + """SQL returns rows ordered by (schema, table, test_time); the method trusts that + order rather than re-sorting.""" + mock_fetch.return_value = [ + {"schema_name": "s", "table_name": "t", "test_run_id": "run_a"}, + {"schema_name": "s", "table_name": "t", "test_run_id": "run_b"}, + {"schema_name": "s", "table_name": "t", "test_run_id": "run_c"}, + ] + instance = _make_prediction_instance() + events = instance._fetch_freshness_updates_by_table() + assert events[("s", "t")] == ["run_a", "run_b", "run_c"] + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_coerces_run_id_to_str(mock_fetch): + """test_run_id can come back as a UUID object — must be cast to str for downstream + .isin() matching against the str-cast Volume/Metric test_run_id column.""" + from uuid import UUID as _UUID + rid = _UUID("12345678-1234-5678-1234-567812345678") + mock_fetch.return_value = [ + {"schema_name": "s", "table_name": "t", "test_run_id": rid}, + ] + instance = _make_prediction_instance() + events = instance._fetch_freshness_updates_by_table() + assert events[("s", "t")] == [str(rid)] + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_empty_result(mock_fetch): + mock_fetch.return_value = [] + instance = _make_prediction_instance() + 
assert instance._fetch_freshness_updates_by_table() == {} + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_passes_suite_id_through_get_query(mock_fetch): + """Reuses self._get_query, which substitutes TEST_SUITE_ID from self.test_suite.id.""" + mock_fetch.return_value = [] + instance = _make_prediction_instance(suite_id="suite-xyz") + instance._fetch_freshness_updates_by_table() + _query, params = mock_fetch.call_args.args + assert params["TEST_SUITE_ID"] == "suite-xyz" + + +# --- compute_volume_or_metric_threshold --- + + +def _history_with_run_ids(timestamps: list[str], run_ids: list[str], value: float = 100.0) -> pd.DataFrame: + """Build a Volume/Metric-shaped history: indexed by test_time, with a test_run_id + column matching how `run()` slices the historical-results dataframe per definition.""" + assert len(timestamps) == len(run_ids) + return pd.DataFrame( + {"result_signal": [value] * len(timestamps), "test_run_id": run_ids}, + index=pd.to_datetime(timestamps), + ) + + +@patch(MOCK_TARGET) +def test_freshness_gating_engages_when_filtered_fit_succeeds(mock_forecast): + mock_forecast.return_value = _make_forecast([220.0], [1.0]) + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids, value=220.0) + freshness_updates = run_ids[:8] + + lower, upper, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + assert lower is not None and upper is not None + assert baseline == 220.0 + assert prediction is not None + parsed = json.loads(prediction) + assert parsed["freshness_gated"] is True + assert parsed["baseline_value"] == 220.0 + + +@patch(MOCK_TARGET) +def test_freshness_gating_falls_back_when_filtered_fit_raises(mock_forecast): + """If SARIMAX fails on the freshness-filtered series (NotEnoughData after resample, + convergence), fall back to fitting on the raw value 
series and emit a prediction + without the freshness-gating markers.""" + raw_forecast = _make_forecast([220.0], [1.0]) + mock_forecast.side_effect = [NotEnoughData("not enough"), raw_forecast] + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids, value=220.0) + freshness_updates = run_ids[:5] # any selection — first call is forced to raise + + _, _, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + assert mock_forecast.call_count == 2 # filtered failed, raw retried + assert baseline is None + assert prediction is not None + parsed = json.loads(prediction) + assert "freshness_gated" not in parsed + assert "baseline_value" not in parsed + + +@patch(MOCK_TARGET) +def test_freshness_gating_falls_back_when_no_freshness_events(mock_forecast): + """Empty freshness_updates → filtered history is empty → filtered fit fails → + fall back to fitting on the raw series.""" + # The mock makes the first (filtered-series) forecast call raise NotEnoughData, standing in + # for the failed fit on the empty filtered history; the second call (raw) succeeds. 
+ raw_forecast = _make_forecast([220.0], [1.0]) + mock_forecast.side_effect = [NotEnoughData("not enough"), raw_forecast] + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids) + + _, _, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates=[], sensitivity=PredictSensitivity.medium, + ) + + assert baseline is None + assert prediction is not None + parsed = json.loads(prediction) + assert "freshness_gated" not in parsed + + +@patch(MOCK_TARGET) +def test_freshness_gating_fits_on_filtered_series(mock_forecast): + """SARIMAX should be fit on the filtered series (one row per freshness change), + not on the raw plateau-laden series. Verified via the length of the dataframe + passed to get_sarimax_forecast on the engaging call.""" + mock_forecast.return_value = _make_forecast([220.0], [1.0]) + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids, value=220.0) + freshness_updates = run_ids[:8] + + compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + fitted_history = mock_forecast.call_args.args[0] + assert len(fitted_history) == len(freshness_updates) + + +@patch(MOCK_TARGET) +def test_freshness_gating_baseline_from_filtered_when_events_extend_past_history(mock_forecast): + """When freshness_updates includes runs beyond the (retention-trimmed) history window, + baseline_value must come from the most recent filtered row — not from a run that's no + longer in history.""" + mock_forecast.return_value = _make_forecast([220.0], [1.0]) + # History only covers the first 8 days (run_00..run_07) + history_timestamps = [f"2026-01-{day:02d}" for day in range(1, 9)] + history_run_ids = [f"run_{i:02d}" for i in range(8)] + values = [float(i) for i in range(1, 9)] 
# distinct values so baseline is identifiable + history = pd.DataFrame( + {"result_signal": values, "test_run_id": history_run_ids}, + index=pd.to_datetime(history_timestamps), + ) + # Freshness events for those 8 runs PLUS 3 more that aren't in history (trimmed) + freshness_updates = history_run_ids + [f"run_{i}" for i in range(20, 23)] + + _, _, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + assert baseline == 8.0 + assert prediction is not None + parsed = json.loads(prediction) + assert parsed["freshness_gated"] is True + # baseline_value must be the value at the LAST timestamp present in BOTH history and + # freshness_updates (not freshness_updates[-1] which points past the history window) + assert parsed["baseline_value"] == 8.0 diff --git a/tests/unit/common/conftest.py b/tests/unit/common/conftest.py index 8646de5c..875d147b 100644 --- a/tests/unit/common/conftest.py +++ b/tests/unit/common/conftest.py @@ -134,18 +134,23 @@ def _run_scenario( sensitivity: PredictSensitivity, exclude_weekends: bool = False, tz: str | None = None, + min_lookback: int = 30, ) -> list[ScenarioPoint]: - """Iterate through csv_rows calling compute_freshness_threshold at each step.""" + """Iterate through csv_rows, mirroring the call shape of TestThresholdsPrediction.run(): + a min_lookback guard against the raw history, then compute_freshness_threshold.""" results: list[ScenarioPoint] = [] freshness_last_update: pd.Timestamp | None = None for i, (timestamp, value) in enumerate(csv_rows): history_df = _to_history_df(csv_rows[:i]) - lower, upper, staleness, prediction_json = compute_freshness_threshold( - history_df, sensitivity, min_lookback=30, - exclude_weekends=exclude_weekends, schedule_tz=tz, - ) + if len(history_df) < min_lookback: + lower = upper = staleness = prediction_json = None + else: + lower, upper, staleness, prediction_json = compute_freshness_threshold( + history_df, sensitivity, + 
exclude_weekends=exclude_weekends, schedule_tz=tz, + ) result_code, result_status = _evaluate_freshness_point( timestamp, value, lower, upper, staleness, prediction_json, diff --git a/tests/unit/common/models/test_job_execution.py b/tests/unit/common/models/test_job_execution.py index 2ff28c18..7ee152b7 100644 --- a/tests/unit/common/models/test_job_execution.py +++ b/tests/unit/common/models/test_job_execution.py @@ -1,9 +1,10 @@ +from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest -from testgen.common.models.job_execution import JobExecution +from testgen.common.models.job_execution import JobExecution, JobStatus pytestmark = pytest.mark.unit @@ -21,7 +22,13 @@ def _returning_row(job, **overrides): @pytest.fixture def mock_session(): session = MagicMock() - with patch(f"{MODULE}.get_current_session", return_value=session): + ctx = MagicMock() + ctx.__enter__ = Mock(return_value=session) + ctx.__exit__ = Mock(return_value=False) + with ( + patch(f"{MODULE}.get_current_session", return_value=session), + patch(f"{MODULE}.database_session", return_value=ctx), + ): yield session @@ -185,3 +192,97 @@ def test_request_cancel_terminal_state_returns_false(mock_session): assert job.request_cancel() is False assert job.status == "completed" + + +# ─── delete_older_than (data retention) ───────────────────────────── + + +def _capture_clauses_used_in_select(mock_session): + """Returns the WHERE clauses passed to the candidate-id select query. + + The cleanup loop does select(id).where(*clauses).limit(...). 
We capture + those clauses to assert which filters were applied.""" + select_call = mock_session.scalars.call_args + select_stmt = select_call.args[0] + return list(select_stmt.whereclause.clauses) if select_stmt.whereclause is not None else [] + + +def test_delete_older_than_filters_only_terminal_statuses(mock_session): + """The status filter is `IN ('completed', 'error', 'canceled')` — non-terminal + rows (pending/claimed/running/cancel_requested) are skipped regardless of age. + This is the key safety guarantee: live work must never be deleted.""" + mock_session.scalars.return_value.all.return_value = [] # no candidates → loop exits + + cutoff = datetime.now(UTC) - timedelta(days=180) + JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=set()) + + clauses = _capture_clauses_used_in_select(mock_session) + status_clause = next( + (c for c in clauses if "status" in str(c).lower()), + None, + ) + assert status_clause is not None + rendered = str(status_clause.compile(compile_kwargs={"literal_binds": True})) + # Must include all three terminal states + for state in (JobStatus.COMPLETED.value, JobStatus.ERROR.value, JobStatus.CANCELED.value): + assert state in rendered + # Must not include any non-terminal state + for state in (JobStatus.PENDING.value, JobStatus.CLAIMED.value, + JobStatus.RUNNING.value, JobStatus.CANCEL_REQUESTED.value): + assert state not in rendered + + +def test_delete_older_than_returns_zero_when_no_candidates(mock_session): + """No-op when nothing is old enough to delete — returns 0, no DELETE executed.""" + mock_session.scalars.return_value.all.return_value = [] + + cutoff = datetime.now(UTC) - timedelta(days=180) + result = JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=set()) + + assert result == 0 + # Only the candidate-select ran; no DELETE statement was issued. 
+ mock_session.execute.assert_not_called() + + +def test_delete_older_than_batches_and_deletes(mock_session): + """Two-batch path: scalars returns one batch, then empty. Both should result + in a DELETE on the first batch, and the total count returned.""" + first_batch = [uuid4(), uuid4(), uuid4()] + mock_session.scalars.return_value.all.side_effect = [first_batch, []] + + cutoff = datetime.now(UTC) - timedelta(days=180) + result = JobExecution.delete_older_than( + cutoff=cutoff, project_code="proj", protected_ids=set(), batch_size=1000, + ) + + assert result == 3 + mock_session.execute.assert_called_once() # one DELETE for one non-empty batch + + +def test_delete_older_than_applies_protected_ids_exclusion(mock_session): + """The protected_ids carve-out — job_executions of protected runs — adds a + NOT IN clause so they survive even when older than the cutoff.""" + protected = {uuid4(), uuid4()} + mock_session.scalars.return_value.all.return_value = [] + + cutoff = datetime.now(UTC) - timedelta(days=180) + JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=protected) + + clauses = _capture_clauses_used_in_select(mock_session) + rendered = " ".join(str(c) for c in clauses).lower() + assert "not in" in rendered or "!= all" in rendered or "in (" in rendered # NOT IN expression present + + +def test_delete_older_than_skips_protected_filter_when_empty(mock_session): + """Empty protected_ids → no NOT IN clause emitted, avoiding the SQL warning + that `IN ()` triggers in postgres.""" + mock_session.scalars.return_value.all.return_value = [] + + cutoff = datetime.now(UTC) - timedelta(days=180) + JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=set()) + + clauses = _capture_clauses_used_in_select(mock_session) + rendered = " ".join(str(c) for c in clauses).lower() + # Three expected clauses: project_code, completed_at, status IN + # Absence of "not in" confirms the protected-ids clause was skipped. 
+ assert "not in" not in rendered diff --git a/tests/unit/common/models/test_notification_settings.py b/tests/unit/common/models/test_notification_settings.py new file mode 100644 index 00000000..1a46817d --- /dev/null +++ b/tests/unit/common/models/test_notification_settings.py @@ -0,0 +1,139 @@ +"""Tests for ``NotificationSettings`` query semantics. + +The listing surface (``list_for_test_suite`` / ``list_for_table_group`` / +``list_for_score_definition``) must use strict equality on the scope column — +no ``IS NULL`` wildcard. The firing-pipeline surface (``_base_select_query``) +must keep the ``IS NULL`` wildcard so a project-wide notification matches +events on any child entity. +""" + +from unittest.mock import patch +from uuid import UUID, uuid4 + +import pytest + +from testgen.common.models.notification_settings import NotificationSettings, is_valid_email + +pytestmark = pytest.mark.unit + + +# ─── Shared email validation helper ─────────────────────────────────── + + +@pytest.mark.parametrize("addr", [ + "alice@example.com", + "a.b+tag@sub.domain.co", + "x_y%z@host-name.io", +]) +def test_is_valid_email_accepts_well_formed(addr): + assert is_valid_email(addr) is True + + +@pytest.mark.parametrize("addr", [ + "no-at-sign", + "spaces in@here.com", + "nodot@nope", + "@nodomain.com", + "trailing@dot.", + "", +]) +def test_is_valid_email_rejects_malformed(addr): + assert is_valid_email(addr) is False + + +def _captured_list_sql(method_name: str, *args, **kwargs) -> str: + """Invoke a ``list_for_*`` classmethod and compile the query it passes to ``_paginate``.""" + with patch.object(NotificationSettings, "_paginate", return_value=([], 0)) as mock_paginate: + getattr(NotificationSettings, method_name)(*args, **kwargs) + query = mock_paginate.call_args.args[0] + return str(query.compile(compile_kwargs={"literal_binds": True})) + + +def _uuid_in_sql(value: UUID, sql: str) -> bool: + """SQLAlchemy literal_binds compiles UUIDs as 32-char hex (no dashes); accept 
either.""" + return str(value) in sql or value.hex in sql + + +# ─── Listing surface — strict equality, no IS NULL ──────────────────── + + +def test_list_for_test_suite_filters_by_strict_equality_only(): + suite_id = uuid4() + sql = _captured_list_sql("list_for_test_suite", suite_id) + + assert "IS NULL" not in sql.upper(), ( + "list_for_test_suite must not surface rows where test_suite_id IS NULL — " + "they may be unrelated event types whose scope column happens to be null." + ) + assert "test_suite_id" in sql + assert _uuid_in_sql(suite_id, sql) + + +def test_list_for_table_group_filters_by_strict_equality_only(): + table_group_id = uuid4() + sql = _captured_list_sql("list_for_table_group", table_group_id) + + assert "IS NULL" not in sql.upper(), ( + "list_for_table_group must not surface rows where table_group_id IS NULL — " + "they may be unrelated event types whose scope column happens to be null." + ) + assert "table_group_id" in sql + assert _uuid_in_sql(table_group_id, sql) + + +def test_list_for_score_definition_filters_by_strict_equality_only(): + score_definition_id = uuid4() + sql = _captured_list_sql("list_for_score_definition", score_definition_id) + + assert "IS NULL" not in sql.upper(), ( + "list_for_score_definition must not surface rows where score_definition_id IS NULL — " + "they may be unrelated event types whose scope column happens to be null." + ) + assert "score_definition_id" in sql + assert _uuid_in_sql(score_definition_id, sql) + + +# ─── Firing pipeline — IS NULL preserved (regression guard) ─────────── +# +# `_base_select_query` is consumed by the notification firing pipeline, where +# a notification with `_id IS NULL` legitimately means "fires for any +# child of that type in the same project." Leaving this branch alone is the +# whole reason the listing-side fix is scoped to the `list_for_*` helpers. 
+ + +def _firing_query_sql(**kwargs) -> str: + query = NotificationSettings._base_select_query(**kwargs) + return str(query.compile(compile_kwargs={"literal_binds": True})) + + +def test_base_select_query_test_suite_keeps_null_wildcard(): + suite_id = uuid4() + sql = _firing_query_sql(test_suite_id=suite_id) + + assert "IS NULL" in sql.upper(), ( + "_base_select_query is used by the firing pipeline, which needs " + "test_suite_id IS NULL to mean 'fires for any suite in the project'." + ) + assert _uuid_in_sql(suite_id, sql) + + +def test_base_select_query_table_group_keeps_null_wildcard(): + table_group_id = uuid4() + sql = _firing_query_sql(table_group_id=table_group_id) + + assert "IS NULL" in sql.upper(), ( + "_base_select_query is used by the firing pipeline, which needs " + "table_group_id IS NULL to mean 'fires for any table group in the project'." + ) + assert _uuid_in_sql(table_group_id, sql) + + +def test_base_select_query_score_definition_keeps_null_wildcard(): + score_definition_id = uuid4() + sql = _firing_query_sql(score_definition_id=score_definition_id) + + assert "IS NULL" in sql.upper(), ( + "_base_select_query is used by the firing pipeline, which needs " + "score_definition_id IS NULL to mean 'fires for any scorecard in the project'." 
+ ) + assert _uuid_in_sql(score_definition_id, sql) diff --git a/tests/unit/common/models/test_scheduler.py b/tests/unit/common/models/test_scheduler.py new file mode 100644 index 00000000..f06f1292 --- /dev/null +++ b/tests/unit/common/models/test_scheduler.py @@ -0,0 +1,137 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.enums import JobKey +from testgen.common.models.scheduler import ( + DEFAULT_DATA_CLEANUP_CRON, + JobSchedule, +) + +pytestmark = pytest.mark.unit + +MODULE = "testgen.common.models.scheduler" + + +@pytest.fixture +def mock_session(): + session = MagicMock() + with patch(f"{MODULE}.get_current_session", return_value=session): + yield session + + +# ─── upsert_for_retention ─────────────────────────────────────────── + + +def test_upsert_for_retention_inserts_when_missing(mock_session): + """No existing schedule for (project, JobKey.run_data_cleanup) → INSERT path: + creates a fresh JobSchedule and adds it to the session.""" + mock_session.scalars.return_value.first.return_value = None + + schedule = JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=90, + cron_expr="0 1 * * *", + cron_tz="UTC", + ) + + mock_session.add.assert_called_once() + added = mock_session.add.call_args[0][0] + assert added is schedule + assert schedule.project_code == "proj" + assert schedule.key == JobKey.run_data_cleanup + assert schedule.kwargs == {"project_code": "proj", "retention_days": 90} + assert schedule.cron_expr == "0 1 * * *" + assert schedule.cron_tz == "UTC" + assert schedule.active is True + + +def test_upsert_for_retention_updates_when_present(mock_session): + """Existing schedule for the same (project, key) → UPDATE path: mutates in + place; does NOT add a new row (would otherwise violate the table's + UNIQUE constraint and duplicate schedules per project).""" + existing = JobSchedule( + project_code="proj", + key=JobKey.run_data_cleanup, + kwargs={"project_code": "proj", "retention_days": 
180}, + cron_expr="0 1 * * *", + cron_tz="UTC", + active=False, + ) + mock_session.scalars.return_value.first.return_value = existing + + result = JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=30, + cron_expr="0 2 * * *", + cron_tz="America/New_York", + ) + + mock_session.add.assert_not_called() + assert result is existing + assert existing.kwargs == {"project_code": "proj", "retention_days": 30} + assert existing.cron_expr == "0 2 * * *" + assert existing.cron_tz == "America/New_York" + # Re-activated even when the previous schedule had been deactivated + assert existing.active is True + + +def test_upsert_for_retention_reactivates_inactive_schedule(mock_session): + """A specific guard: if a project's retention schedule was disabled (active=False) + and the user re-enables retention, the upsert flips active back to True.""" + existing = JobSchedule( + project_code="proj", + key=JobKey.run_data_cleanup, + kwargs={}, + cron_expr="0 1 * * *", + cron_tz="UTC", + active=False, + ) + mock_session.scalars.return_value.first.return_value = existing + + JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=180, + cron_expr=DEFAULT_DATA_CLEANUP_CRON, + cron_tz="UTC", + ) + + assert existing.active is True + + +def test_upsert_for_retention_does_not_commit(mock_session): + """Like other model methods: the helper participates in the caller's + transaction; it must not commit on its own. 
The save() path is owned by + the request scope (database_session or safe_rerun).""" + mock_session.scalars.return_value.first.return_value = None + + JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=180, + cron_expr=DEFAULT_DATA_CLEANUP_CRON, + cron_tz="UTC", + ) + + mock_session.commit.assert_not_called() + + +# ─── delete_for_retention ─────────────────────────────────────────── + + +def test_delete_for_retention_executes_scoped_delete(mock_session): + """Issues a single DELETE filtered to (project_code, JobKey.run_data_cleanup). + Idempotent — safe to call when no schedule exists (mock_session.execute + is a no-op).""" + JobSchedule.delete_for_retention("proj") + + mock_session.execute.assert_called_once() + stmt = mock_session.execute.call_args.args[0] + rendered = str(stmt.compile(compile_kwargs={"literal_binds": True})) + assert "DELETE FROM job_schedules" in rendered + assert "proj" in rendered + assert JobKey.run_data_cleanup in rendered + + +def test_delete_for_retention_does_not_commit(mock_session): + JobSchedule.delete_for_retention("proj") + mock_session.commit.assert_not_called() diff --git a/tests/unit/common/models/test_score_definition.py b/tests/unit/common/models/test_score_definition.py new file mode 100644 index 00000000..7978448b --- /dev/null +++ b/tests/unit/common/models/test_score_definition.py @@ -0,0 +1,473 @@ +"""Tests for ScoreDefinition.as_score_card() filter behavior across toggle combinations. + +Covers TG-1078: in CDE-only mode (total_score OFF, cde_score ON) the per-category +scores must be computed over CDE columns only. In all other modes the per-category +scores must be computed over the full column universe (no CDE filter). 
+""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.models.scores import ( + ScoreCategory, + ScoreDefinition, + ScoreDefinitionCriteria, + ScoreDefinitionFilter, +) + +pytestmark = pytest.mark.unit + + +CDE_FILTER_FRAGMENT = "critical_data_element = true" + + +def _make_definition( + *, + total_score: bool, + cde_score: bool, + category: ScoreCategory = ScoreCategory.dq_dimension, +) -> ScoreDefinition: + definition = ScoreDefinition( + project_code="demo", + name="Test card", + total_score=total_score, + cde_score=cde_score, + category=category, + ) + definition.criteria = ScoreDefinitionCriteria( + operand="AND", + group_by_field=True, + filters=[ScoreDefinitionFilter(field="table_groups_name", value="my_group")], + ) + return definition + + +def _capture_executed_sql(definition: ScoreDefinition) -> list[str]: + """Run as_score_card() against a mocked session and return the SQL of each execute call.""" + session = MagicMock() + mappings_result = MagicMock() + mappings_result.first.return_value = {} + mappings_result.all.return_value = [] + session.execute.return_value.mappings.return_value = mappings_result + + with patch("testgen.common.models.scores.get_current_session", return_value=session): + definition.as_score_card() + + return [str(call.args[0]) for call in session.execute.call_args_list] + + +@pytest.mark.parametrize( + "category", + [ScoreCategory.dq_dimension, ScoreCategory.impact_dimension, ScoreCategory.business_domain], +) +def test_categories_query_omits_cde_filter_in_total_only_mode(category): + definition = _make_definition(total_score=True, cde_score=False, category=category) + sql_calls = _capture_executed_sql(definition) + + assert len(sql_calls) == 2, "expected one overall and one categories query" + overall_sql, categories_sql = sql_calls + assert CDE_FILTER_FRAGMENT not in categories_sql + assert CDE_FILTER_FRAGMENT not in overall_sql + + +@pytest.mark.parametrize( + 
"category", + [ScoreCategory.dq_dimension, ScoreCategory.impact_dimension, ScoreCategory.business_domain], +) +def test_categories_query_omits_cde_filter_in_total_and_cde_mode(category): + definition = _make_definition(total_score=True, cde_score=True, category=category) + sql_calls = _capture_executed_sql(definition) + + assert len(sql_calls) == 2 + overall_sql, categories_sql = sql_calls + assert CDE_FILTER_FRAGMENT not in categories_sql + assert CDE_FILTER_FRAGMENT not in overall_sql + + +@pytest.mark.parametrize( + "category", + [ScoreCategory.dq_dimension, ScoreCategory.impact_dimension, ScoreCategory.business_domain], +) +def test_categories_query_includes_cde_filter_in_cde_only_mode(category): + definition = _make_definition(total_score=False, cde_score=True, category=category) + sql_calls = _capture_executed_sql(definition) + + assert len(sql_calls) == 2 + overall_sql, categories_sql = sql_calls + assert CDE_FILTER_FRAGMENT in categories_sql, ( + "Categories query must filter by CDE columns when the card is in CDE-only mode" + ) + # Overall query must stay un-filtered by CDE — it selects score and cde_score as + # separate columns, so adding the filter would zero out the non-CDE total. + assert CDE_FILTER_FRAGMENT not in overall_sql + + +def test_categories_query_uses_column_template_for_column_category(): + definition = _make_definition(total_score=False, cde_score=True, category=ScoreCategory.business_domain) + sql_calls = _capture_executed_sql(definition) + + categories_sql = sql_calls[1] + # Column-grouped template aggregates by a placeholder substituted into the SELECT. 
+ assert "business_domain" in categories_sql + assert CDE_FILTER_FRAGMENT in categories_sql +# --- list_with_table_group_targets --- + + +def _row(definition_id, name, tg_names): + """Simulate a row returned by the recursive-CTE aggregate query.""" + row = MagicMock() + row.id = definition_id + row.name = name + row.tg_names = tg_names + return row + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_single_name_filter(mock_session_fn): + """A scorecard with one table_groups_name filter yields (id, name, [tg_name]).""" + def_id = uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(def_id, "orders-sc", ["orders"])] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "orders-sc", ["orders"])] + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_multiple_name_filters(mock_session_fn): + """A scorecard with multiple table_groups_name filters yields all names.""" + def_id = uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(def_id, "multi-sc", ["orders", "customers"])] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "multi-sc", ["orders", "customers"])] + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_no_name_filter(mock_session_fn): + """A scorecard with no table_groups_name filter yields an empty list of targets.""" + def_id = uuid4() + mock_result = MagicMock() + # Postgres array_agg with FILTER returns NULL when no rows match — the method + # must normalize this to []. 
+ mock_result.all.return_value = [_row(def_id, "metadata-only-sc", None)] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "metadata-only-sc", [])] + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_filters_by_project_code(mock_session_fn): + """The query filters on project_code via the WHERE clause.""" + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + ScoreDefinition.list_with_table_group_targets("my-project") + + args, _ = mock_session_fn.return_value.execute.call_args + compiled = args[0].compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + assert "project_code" in sql + assert "'my-project'" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_uses_recursive_cte_on_filter_chain(mock_session_fn): + """The query SQL walks score_definition_filters via next_filter_id (recursive CTE).""" + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + ScoreDefinition.list_with_table_group_targets("proj") + + args, _ = mock_session_fn.return_value.execute.call_args + sql = str(args[0].compile(compile_kwargs={"literal_binds": True})) + assert "RECURSIVE" in sql.upper() + assert "next_filter_id" in sql + assert "table_groups_name" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_empty_project(mock_session_fn): + """When the project has no scorecards, returns an empty list.""" + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + assert ScoreDefinition.list_with_table_group_targets("proj") == [] + + 
@patch("testgen.common.models.scores.get_current_session")
def test_list_with_table_group_targets_dedupes_repeated_names(mock_session_fn):
    """Repeated table-group names in one aggregate row are surfaced only once.

    A mode-2 scorecard with N chains all rooted at the same table_groups_name
    would otherwise be listed once per chain under the same table group by
    the inventory tool.
    """
    scorecard_id = uuid4()
    aggregate_row = _row(scorecard_id, "redbox-tables", ["redbox"] * 4)
    mock_session_fn.return_value.execute.return_value.all.return_value = [aggregate_row]

    targets = ScoreDefinition.list_with_table_group_targets("proj")

    assert targets == [(scorecard_id, "redbox-tables", ["redbox"])]


# --- get_overall_issue_ct ---


def _definition_with_filter(project_code="demo", field="business_domain", value="Finance"):
    """Return an unsaved ScoreDefinition whose criteria hold exactly one filter."""
    scorecard = ScoreDefinition()
    for attribute, setting in (
        ("project_code", project_code),
        ("name", "test"),
        ("total_score", True),
        ("cde_score", False),
    ):
        setattr(scorecard, attribute, setting)
    scorecard.criteria = ScoreDefinitionCriteria(
        operand="AND",
        group_by_field=True,
        filters=[ScoreDefinitionFilter(field=field, value=value)],
    )
    return scorecard


@patch("testgen.common.models.scores.get_current_session")
def test_get_overall_issue_ct_sums_profile_and_test(mock_session_fn):
    """The result is the profile issue count plus the test issue count."""
    scorecard = _definition_with_filter()
    # Two execute() calls are expected: the first yields the profile sum,
    # the second yields the test sum.
    mock_session_fn.return_value.execute.side_effect = [
        MagicMock(scalar=MagicMock(return_value=issue_ct)) for issue_ct in (7, 3)
    ]

    assert scorecard.get_overall_issue_ct() == 10


@patch("testgen.common.models.scores.get_current_session")
def test_get_overall_issue_ct_queries_both_views(mock_session_fn):
    """Two queries run: the profile scoring view first, then the test scoring view."""
    scorecard = _definition_with_filter()
    mock_session_fn.return_value.execute.side_effect = [
        MagicMock(scalar=MagicMock(return_value=issue_ct)) for issue_ct in (0, 0)
    ]

    scorecard.get_overall_issue_ct()

    executed = mock_session_fn.return_value.execute.call_args_list
    assert len(executed) == 2
    profile_sql, test_sql = (str(call.args[0]) for call in executed)
    assert "v_dq_profile_scoring_latest_by_column" in profile_sql
    assert "v_dq_test_scoring_latest_by_column" in test_sql
    # Both queries must use the same filters as as_score_card (project_code + criteria).
    for rendered in (profile_sql, test_sql):
        assert "project_code = 'demo'" in rendered
        assert "business_domain = 'Finance'" in rendered


@patch("testgen.common.models.scores.get_current_session")
def test_get_overall_issue_ct_handles_null_scalars(mock_session_fn):
    """NULL sums (no matching rows) add up to 0 rather than None."""
    scorecard = _definition_with_filter()
    mock_session_fn.return_value.execute.side_effect = [
        MagicMock(scalar=MagicMock(return_value=None)) for _ in range(2)
    ]

    assert scorecard.get_overall_issue_ct() == 0


def test_get_overall_issue_ct_no_filters_returns_zero():
    """With an empty filter list the count is 0 and no query is executed."""
    scorecard = ScoreDefinition()
    for attribute, setting in (
        ("project_code", "demo"),
        ("name", "test"),
        ("total_score", True),
        ("cde_score", False),
    ):
        setattr(scorecard, attribute, setting)
    scorecard.criteria = ScoreDefinitionCriteria(
        operand="AND",
        group_by_field=True,
        filters=[],
    )

    with patch("testgen.common.models.scores.get_current_session") as mock_session_fn:
        assert scorecard.get_overall_issue_ct() == 0
        mock_session_fn.return_value.execute.assert_not_called()


# --- list_for_project ---


def _make_scorecard_orm(name: str, project_code: str = "demo") -> ScoreDefinition:
    """Build an unsaved ScoreDefinition with a random id and total-score-only flags."""
    scorecard = ScoreDefinition()
    for attribute, setting in (
        ("id", uuid4()),
        ("project_code", project_code),
        ("name", name),
        ("total_score", True),
        ("cde_score", False),
    ):
        setattr(scorecard, attribute, setting)
    return scorecard


@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_returns_items_and_total(mock_session_fn):
    """Rows come from scalars().unique().all(); the total comes from the count scalar."""
    apple = _make_scorecard_orm("Apple")
    mango = _make_scorecard_orm("Mango")

    db = mock_session_fn.return_value
    db.scalar.return_value = 2
    db.scalars.return_value.unique.return_value.all.return_value = [apple, mango]

    items, total = ScoreDefinition.list_for_project("demo", page=1, limit=20)

    assert items == [apple, mango]
    assert total == 2


@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_filters_by_project_code(mock_session_fn):
    """The compiled page query restricts rows to the requested project_code."""
    db = mock_session_fn.return_value
    db.scalar.return_value = 0
    db.scalars.return_value.unique.return_value.all.return_value = []

    ScoreDefinition.list_for_project("my-proj")

    page_statement = db.scalars.call_args.args[0]
    rendered = str(page_statement.compile(compile_kwargs={"literal_binds": True}))
    assert "project_code" in rendered
    assert "'my-proj'" in rendered


@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_orders_by_name(mock_session_fn):
    """The page query carries an ORDER BY on name so pagination is stable."""
    db = mock_session_fn.return_value
    db.scalar.return_value = 0
    db.scalars.return_value.unique.return_value.all.return_value = []

    ScoreDefinition.list_for_project("demo")

    page_statement = db.scalars.call_args.args[0]
    rendered = str(page_statement.compile(compile_kwargs={"literal_binds": True}))
    assert "ORDER BY" in rendered.upper()
    assert "score_definitions.name" in rendered.lower()


@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_applies_offset_and_limit(mock_session_fn):
    """page=3, limit=10 → OFFSET 20 LIMIT 10."""
    db = mock_session_fn.return_value
    db.scalar.return_value = 100
    db.scalars.return_value.unique.return_value.all.return_value = []

    ScoreDefinition.list_for_project("demo", page=3, limit=10)

    page_statement = db.scalars.call_args.args[0]
    rendered = str(page_statement.compile(compile_kwargs={"literal_binds": True}))
    assert "LIMIT 10" in rendered
    assert "OFFSET 20" in rendered
@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_eager_loads_criteria(mock_session_fn):
    """The page query pulls in the criteria table so the rendering loop doesn't fire N+1."""
    db = mock_session_fn.return_value
    db.scalar.return_value = 0
    db.scalars.return_value.unique.return_value.all.return_value = []

    ScoreDefinition.list_for_project("demo")

    page_statement = db.scalars.call_args.args[0]
    rendered = str(page_statement.compile(compile_kwargs={"literal_binds": True}))
    # joinedload emits a LEFT OUTER JOIN against the criteria table.
    assert "score_definition_criteria" in rendered.lower()


@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_count_is_separate_query(mock_session_fn):
    """Exactly one scalar count query is issued in addition to the paged scalars query."""
    db = mock_session_fn.return_value
    db.scalar.return_value = 7
    db.scalars.return_value.unique.return_value.all.return_value = []

    _items, total = ScoreDefinition.list_for_project("demo")

    assert total == 7
    assert db.scalar.call_count == 1
    count_statement = db.scalar.call_args.args[0]
    count_sql = str(count_statement.compile(compile_kwargs={"literal_binds": True}))
    assert "count(" in count_sql.lower()
    assert "'demo'" in count_sql


@patch("testgen.common.models.scores.get_current_session")
def test_list_for_project_count_null_returns_zero(mock_session_fn):
    """A NULL count from the scalar query is reported as a total of 0."""
    db = mock_session_fn.return_value
    db.scalar.return_value = None
    db.scalars.return_value.unique.return_value.all.return_value = []

    items, total = ScoreDefinition.list_for_project("demo")

    assert (items, total) == ([], 0)


# ─── names_by_id — single batched lookup, no N+1 ──────────────────────


+@patch("testgen.common.models.scores.get_current_session") +def test_names_by_id_returns_id_to_name_mapping(mock_session_fn): + id_a, id_b = uuid4(), uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(id_a, "Card A", None), _row(id_b, "Card B", None)] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.names_by_id([id_a, id_b]) + + assert out == {id_a: "Card A", id_b: "Card B"} + + +@patch("testgen.common.models.scores.get_current_session") +def test_names_by_id_empty_input_skips_query(mock_session_fn): + out = ScoreDefinition.names_by_id([]) + + assert out == {} + mock_session_fn.return_value.execute.assert_not_called() + + +@patch("testgen.common.models.scores.get_current_session") +def test_names_by_id_uses_single_in_query(mock_session_fn): + """One IN query for all IDs — not a per-ID lookup (N+1).""" + ids = [uuid4(), uuid4(), uuid4()] + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + ScoreDefinition.names_by_id(ids) + + assert mock_session_fn.return_value.execute.call_count == 1 + args, _ = mock_session_fn.return_value.execute.call_args + sql = str(args[0].compile(compile_kwargs={"literal_binds": True})) + assert " IN (" in sql.upper() diff --git a/tests/unit/common/models/test_test_definition.py b/tests/unit/common/models/test_test_definition.py new file mode 100644 index 00000000..f733d1b6 --- /dev/null +++ b/tests/unit/common/models/test_test_definition.py @@ -0,0 +1,372 @@ +"""Tests for TestDefinition model methods.""" + +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.models.test_definition import ( + InvalidTestDefinitionFields, + Severity, + TestDefinition, + TestDefinitionSummary, + _required_fields_for, +) + + +def make_test_type( + code: str = "Alpha_Trunc", + scope: str = "column", + param_columns: set[str] | None 
= None, + default_parm_columns: str | None = "threshold_value", + default_parm_required: str | None = None, +) -> MagicMock: + """Build a TestType-shaped mock with the attributes the validator reads.""" + tt = MagicMock() + tt.test_type = code + tt.test_scope = scope + tt.param_columns = param_columns if param_columns is not None else {"threshold_value"} + tt.default_parm_columns = default_parm_columns + tt.default_parm_required = default_parm_required + return tt + + +def make_td(**fields) -> TestDefinition: + """Build a TestDefinition with the given fields set, nothing else.""" + td = TestDefinition() + for key, value in fields.items(): + setattr(td, key, value) + return td + + +# -- _required_fields_for ----------------------------------------------------- + + +def test_required_fields_column_scope_adds_column_name(): + tt = make_test_type(scope="column") + assert "column_name" in _required_fields_for(tt) + + +def test_required_fields_table_scope_no_column_name(): + tt = make_test_type(code="Row_Ct", scope="table", param_columns=set(), default_parm_columns=None) + assert "column_name" not in _required_fields_for(tt) + + +def test_required_fields_custom_query_when_in_param_columns(): + tt = make_test_type( + code="CUSTOM", + scope="custom", + param_columns={"custom_query", "match_column_names"}, + default_parm_columns="custom_query,match_column_names", + ) + assert "custom_query" in _required_fields_for(tt) + + +def test_required_fields_parses_default_parm_required(): + tt = make_test_type( + code="Metric_Trend", + scope="custom", + param_columns={"custom_query", "threshold_value", "baseline_value"}, + default_parm_columns="custom_query,threshold_value,baseline_value", + default_parm_required="Y,Y,N", + ) + required = _required_fields_for(tt) + assert "custom_query" in required + assert "threshold_value" in required + assert "baseline_value" not in required + + +def test_required_fields_null_required_means_no_extras(): + tt = make_test_type(scope="column", 
default_parm_required=None) + assert _required_fields_for(tt) == {"column_name"} + + +# -- TestDefinition.editable_fields ------------------------------------------- + + +def test_editable_fields_includes_base_set(): + tt = make_test_type(param_columns=set(), default_parm_columns=None) + td = make_td() + accepted = td.editable_fields(tt) + assert {"test_active", "severity", "lock_refresh", "flagged", "test_description"} <= accepted + + +def test_editable_fields_includes_param_columns(): + tt = make_test_type(param_columns={"threshold_value", "baseline_value"}) + td = make_td() + accepted = td.editable_fields(tt) + assert {"threshold_value", "baseline_value"} <= accepted + + +def test_editable_fields_includes_impact_dimension_only_for_custom_or_referential_scope(): + """impact_dimension is overridable only for user-defined-semantic scopes.""" + td = make_td() + + custom_tt = make_test_type(scope="custom", param_columns={"custom_query"}) + assert "impact_dimension" in td.editable_fields(custom_tt) + + referential_tt = make_test_type(scope="referential", param_columns={"match_column_names"}) + assert "impact_dimension" in td.editable_fields(referential_tt) + + column_tt = make_test_type(scope="column", param_columns={"threshold_value"}) + assert "impact_dimension" not in td.editable_fields(column_tt) + + table_tt = make_test_type(scope="table", param_columns=set()) + assert "impact_dimension" not in td.editable_fields(table_tt) + + +def test_editable_fields_includes_column_name_only_for_column_or_custom_scope(): + """column_name is meaningful for column-scope (column under test) and custom-scope (label).""" + td = make_td() + + column_tt = make_test_type(scope="column", param_columns={"threshold_value"}) + assert "column_name" in td.editable_fields(column_tt) + + custom_tt = make_test_type(scope="custom", param_columns={"custom_query"}) + assert "column_name" in td.editable_fields(custom_tt) + + table_tt = make_test_type(scope="table", param_columns=set()) + assert 
"column_name" not in td.editable_fields(table_tt) + + referential_tt = make_test_type(scope="referential", param_columns={"match_column_names"}) + assert "column_name" not in td.editable_fields(referential_tt) + + +def test_editable_fields_does_not_leak_identity_or_internal_columns(): + tt = make_test_type(param_columns={"threshold_value"}) + td = make_td() + accepted = td.editable_fields(tt) + # Identity fields — callers must never set these via fields/extra_params + for forbidden in ("test_suite_id", "table_groups_id", "test_type", "schema_name"): + assert forbidden not in accepted + # Internal/system-managed columns + for forbidden in ("profile_run_id", "external_id", "prediction", "last_auto_gen_date"): + assert forbidden not in accepted + + +# -- TestDefinition.validate -------------------------------------------------- + + +def test_validate_happy_path(): + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10") + td.validate(tt) # no raise + + +def test_validate_missing_required_column_name(): + tt = make_test_type(scope="column") + td = make_td(threshold_value="10") # no column_name + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "column_name" in exc_info.value.errors + + +def test_validate_wrong_scope_column_name_rejected(): + tt = make_test_type(code="Row_Ct", scope="table", param_columns=set()) + td = make_td(column_name="email") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "column_name" in exc_info.value.errors + + +def test_validate_custom_scope_accepts_column_name_as_label(): + # CUSTOM uses column_name as a "Test Focus" label — must be accepted. 
+ tt = make_test_type( + code="CUSTOM", + scope="custom", + param_columns={"custom_query"}, + default_parm_columns="custom_query", + ) + td = make_td(column_name="Negative Total Check", custom_query="SELECT 1") + td.validate(tt) # no raise + + +def test_validate_custom_query_not_accepted(): + tt = make_test_type() # param_columns = {threshold_value}; no custom_query allowed + td = make_td(column_name="email", threshold_value="10", custom_query="SELECT 1") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "custom_query" in exc_info.value.errors + + +def test_validate_severity_accepts_valid_strenum_values(): + tt = make_test_type() + for value in ("Fail", "Warning"): + td = make_td(column_name="email", threshold_value="10", severity=value) + td.validate(tt) + + +def test_validate_severity_rejects_invalid(): + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity="critical") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "severity" in exc_info.value.errors + + +def test_validate_severity_case_sensitive(): + # Per CLAUDE.md, case-sensitive — "fail" must be rejected. 
+ tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity="fail") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "severity" in exc_info.value.errors + + +def test_validate_severity_empty_string_treated_as_unset(): + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity="") + td.validate(tt) # empty severity is OK — falls back to test type default + + +def test_validate_aggregates_errors(): + tt = make_test_type(scope="column") + td = make_td(severity="critical", custom_query="SELECT 1") # no column_name + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + errors = exc_info.value.errors + assert {"column_name", "severity", "custom_query"} <= errors.keys() + + +def test_validate_empty_string_treats_required_field_as_cleared(): + tt = make_test_type(scope="column") + td = make_td(column_name="", threshold_value="10") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "column_name" in exc_info.value.errors + + +def test_severity_enum_value_accepted(): + # StrEnum subclasses str, so setting severity to the enum should pass validate. 
+ tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity=Severity.FAIL) + td.validate(tt) + + +# --- select_page --- + +def _make_summary_row(table_name: str = "my_table") -> dict: + return { + "id": uuid4(), + "table_groups_id": uuid4(), + "profile_run_id": uuid4(), + "test_type": "CUSTOM", + "test_suite_id": uuid4(), + "test_description": None, + "schema_name": "public", + "table_name": table_name, + "column_name": "col1", + "skip_errors": 0, + "baseline_ct": None, + "baseline_unique_ct": None, + "baseline_value": None, + "baseline_value_ct": None, + "threshold_value": None, + "baseline_sum": None, + "baseline_avg": None, + "baseline_sd": None, + "lower_tolerance": None, + "upper_tolerance": None, + "subset_condition": None, + "groupby_names": None, + "having_condition": None, + "window_date_column": None, + "window_days": None, + "match_schema_name": None, + "match_table_name": None, + "match_column_names": None, + "match_subset_condition": None, + "match_groupby_names": None, + "match_having_condition": None, + "custom_query": None, + "history_calculation": None, + "history_calculation_upper": None, + "history_lookback": None, + "test_active": True, + "test_definition_status": None, + "severity": None, + "lock_refresh": False, + "last_auto_gen_date": None, + "profiling_as_of_date": None, + "last_manual_update": datetime.now(), + "export_to_observability": False, + "prediction": None, + "flagged": False, + "impact_dimension": None, + "test_name_short": "Custom", + "default_test_description": "A test", + "measure_uom": "", + "measure_uom_description": "", + "default_parm_columns": "", + "default_parm_prompts": "", + "default_parm_help": "", + "default_parm_required": "", + "default_severity": "Warning", + "test_scope": "column", + "dq_dimension": "", + "default_impact_dimension": "", + "usage_notes": "", + } + + +@patch("testgen.common.models.entity.get_current_session") +def 
test_select_page_returns_items_and_total(mock_get_session): + rows = [_make_summary_row("table_a"), _make_summary_row("table_b"), _make_summary_row("table_c")] + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 3 + mock_session.execute.return_value.mappings.return_value.all.return_value = rows + + items, total = TestDefinition.select_page() + + assert total == 3 + assert len(items) == 3 + assert all(isinstance(item, TestDefinitionSummary) for item in items) + assert items[0].table_name == "table_a" + assert items[2].table_name == "table_c" + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_empty_result_returns_zero_total(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] + + items, total = TestDefinition.select_page() + + assert items == [] + assert total == 0 + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_uses_correct_offset_and_limit(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] + + TestDefinition.select_page(page=3, limit=100) + + call_args = mock_session.execute.call_args + query = call_args[0][0] + compiled = query.compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + + assert "LIMIT 100" in sql + assert "OFFSET 200" in sql + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_first_page_has_no_offset(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] + + TestDefinition.select_page(page=1, limit=500) + + call_args = mock_session.execute.call_args + query = call_args[0][0] + compiled = 
query.compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + + assert "LIMIT 500" in sql + assert "OFFSET 0" in sql diff --git a/tests/unit/common/test_custom_test_validation.py b/tests/unit/common/test_custom_test_validation.py new file mode 100644 index 00000000..0e007067 --- /dev/null +++ b/tests/unit/common/test_custom_test_validation.py @@ -0,0 +1,173 @@ +"""Tests for testgen.common.custom_test_validation.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.custom_test_validation import ( + CustomQueryResult, + validate_custom_query, +) +from testgen.common.database.flavor.flavor_service import FlavorService + + +def _flavor_service(row_limiting: str = "limit") -> FlavorService: + svc = FlavorService() + svc.row_limiting_clause = row_limiting # type: ignore[assignment] + return svc + + +def _connection(flavor: str = "postgresql") -> MagicMock: + conn = MagicMock() + conn.sql_flavor = flavor + return conn + + +# -- validate_custom_query ---------------------------------------------------- + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_count_only(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [{"row_count": 0}] + + result = validate_custom_query( + _connection(), "demo", "SELECT * FROM orders WHERE total < 0", + ) + + assert isinstance(result, CustomQueryResult) + assert result.row_count == 0 + assert result.preview_rows == [] + # Only one fetch call: the count query + assert mock_fetch.call_count == 1 + # Verify the count query is wrapped with ERR_TABLE + count_sql = mock_fetch.call_args_list[0].args[1] + assert "SELECT COUNT(*)" in count_sql + assert "ERR_TABLE" in count_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") 
+def test_validate_custom_query_with_preview(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + preview_row = MagicMock() + preview_row.keys.return_value = ["order_id", "amount"] + mock_fetch.side_effect = [ + [{"row_count": 3}], # count query result + [preview_row], # preview query result + ] + + result = validate_custom_query( + _connection(), "demo", "SELECT * FROM orders WHERE total < 0", preview_limit=1, + ) + + assert result.row_count == 3 + assert result.preview_rows == [preview_row] + assert mock_fetch.call_count == 2 + preview_sql = mock_fetch.call_args_list[1].args[1] + assert "SELECT" in preview_sql + assert "ERR_TABLE" in preview_sql + assert "LIMIT 1" in preview_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_preview_skipped_when_no_rows(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [{"row_count": 0}] + + result = validate_custom_query( + _connection(), "demo", "SELECT 1 WHERE 1=0", preview_limit=5, + ) + + assert result.row_count == 0 + assert result.preview_rows == [] + # Preview query should NOT run when count is 0 + assert mock_fetch.call_count == 1 + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_substitutes_data_schema(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [{"row_count": 0}] + + validate_custom_query( + _connection(), + "production_schema", + "SELECT * FROM {DATA_SCHEMA}.orders", + ) + + count_sql = mock_fetch.call_args_list[0].args[1] + # {DATA_SCHEMA} was substituted with the actual schema name + assert "production_schema.orders" in count_sql + assert "{DATA_SCHEMA}" not in count_sql + + 
+@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_strips_trailing_semicolon(mock_get_flavor, mock_fetch): + """Trailing semicolons break the subquery wrap — must be stripped.""" + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [{"row_count": 0}] + + validate_custom_query( + _connection(), "demo", "SELECT 1; ", + ) + + count_sql = mock_fetch.call_args_list[0].args[1] + # The subquery should not contain a trailing semicolon + assert "SELECT 1)" in count_sql or "SELECT 1 )" in count_sql + # Specifically, the inner SELECT 1 should not be followed by ; inside the wrap + assert "SELECT 1;" not in count_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_uses_flavor_specific_limit(mock_get_flavor, mock_fetch): + """Oracle uses FETCH FIRST; MSSQL uses TOP — preview SQL must respect the flavor.""" + mock_get_flavor.return_value = _flavor_service("fetch") + preview_row = MagicMock() + mock_fetch.side_effect = [ + [{"row_count": 5}], + [preview_row], + ] + + validate_custom_query( + _connection("oracle"), "demo", "SELECT * FROM t", preview_limit=1, + ) + + preview_sql = mock_fetch.call_args_list[1].args[1] + assert "FETCH FIRST 1 ROWS ONLY" in preview_sql + assert "LIMIT" not in preview_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_top_flavor_uses_prefix(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("top") + preview_row = MagicMock() + mock_fetch.side_effect = [ + [{"row_count": 5}], + [preview_row], + ] + + validate_custom_query( + _connection("mssql"), "demo", "SELECT * FROM t", preview_limit=1, + ) + + preview_sql = 
mock_fetch.call_args_list[1].args[1] + assert "TOP 1" in preview_sql + assert "LIMIT" not in preview_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_propagates_db_errors(mock_get_flavor, mock_fetch): + """DB errors propagate as-is — caller decides how to surface them.""" + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.side_effect = Exception("syntax error at or near 'DROP'") + + with pytest.raises(Exception, match="syntax error"): + validate_custom_query(_connection(), "demo", "DROP TABLE orders") diff --git a/tests/unit/common/test_date_service.py b/tests/unit/common/test_date_service.py index d9f8af96..174f8a44 100644 --- a/tests/unit/common/test_date_service.py +++ b/tests/unit/common/test_date_service.py @@ -39,6 +39,16 @@ def test_parses_string_date(self): result = parse_fuzzy_date("2024-03-15 10:30:45") assert result == datetime(2024, 3, 15, 10, 30, 45) + def test_parses_string_date_with_microseconds(self): + # DB timestamp strings carry fractional seconds; the source-data lookups + # (PROFILE_RUN_DATE / TEST_DATE) feed these through parse_fuzzy_date. 
+ result = parse_fuzzy_date("2026-06-02 06:54:30.105548") + assert result == datetime(2026, 6, 2, 6, 54, 30, 105548) + + def test_parses_iso_t_separator(self): + result = parse_fuzzy_date("2026-06-02T06:54:30") + assert result == datetime(2026, 6, 2, 6, 54, 30) + def test_parses_unix_timestamp_seconds(self): result = parse_fuzzy_date(1710500000) assert isinstance(result, datetime) diff --git a/tests/unit/common/test_freshness_scenarios.py b/tests/unit/common/test_freshness_scenarios.py index 86111b1a..a20d2a52 100644 --- a/tests/unit/common/test_freshness_scenarios.py +++ b/tests/unit/common/test_freshness_scenarios.py @@ -65,12 +65,12 @@ def results_no_excl(self) -> list[ScenarioPoint]: return _run_scenario(rows, PredictSensitivity.medium, exclude_weekends=False, tz=None) def test_training_exits(self, results_excl: list[ScenarioPoint]) -> None: - """Training should end. First non-training update needs 5 gaps + min_lookback=30 rows.""" + """Training should end once MIN_FRESHNESS_GAPS (5) completed gaps are observed.""" updates = _updates(results_excl) first_non_training = next((i for i, p in enumerate(updates) if p.upper is not None), None) assert first_non_training is not None - # 5 weekday updates = 5 gaps, but min_lookback=30 means ~30 rows needed first - # With 12h obs interval and daily updates, training exits around update 10-14 + # 5 weekday updates yield 5 gaps; with 12h obs interval and daily updates, + # training exits soon after. 
assert 6 <= first_non_training <= 16 def test_zero_anomalies_excl(self, results_excl: list[ScenarioPoint]) -> None: diff --git a/tests/unit/common/test_freshness_service.py b/tests/unit/common/test_freshness_service.py index f8317413..2797d965 100644 --- a/tests/unit/common/test_freshness_service.py +++ b/tests/unit/common/test_freshness_service.py @@ -15,6 +15,7 @@ detect_active_days, detect_update_window, get_freshness_gap_threshold, + get_freshness_gated_baseline, get_schedule_params, infer_schedule, is_excluded_day, @@ -757,6 +758,40 @@ def test_no_exclusions_for_tentative(self): assert result.window_start is None assert result.window_end is None + +# --------------------------------------------------------------------------- +# get_freshness_gated_baseline Tests +# --------------------------------------------------------------------------- + +class Test_GetFreshnessGatedBaseline: + def test_returns_none_for_none(self): + assert get_freshness_gated_baseline(None) is None + + def test_returns_none_for_empty_string(self): + assert get_freshness_gated_baseline("") is None + + def test_returns_none_when_freshness_gated_absent(self): + assert get_freshness_gated_baseline({"mean": {"123": 100.0}}) is None + + def test_returns_none_when_freshness_gated_false(self): + assert get_freshness_gated_baseline({"freshness_gated": False, "baseline_value": 100.0}) is None + + def test_returns_baseline_when_freshness_gated_true(self): + assert get_freshness_gated_baseline({"freshness_gated": True, "baseline_value": 220.0}) == 220.0 + + def test_parses_from_json_string(self): + pred = json.dumps({"freshness_gated": True, "baseline_value": 5.5}) + assert get_freshness_gated_baseline(pred) == 5.5 + + def test_returns_none_when_baseline_value_missing(self): + assert get_freshness_gated_baseline({"freshness_gated": True}) is None + + def test_baseline_value_coerced_to_float(self): + """JSON may serialize int — must be cast to float for downstream SQL.""" + result = 
get_freshness_gated_baseline({"freshness_gated": True, "baseline_value": 220}) + assert isinstance(result, float) + assert result == 220.0 + def test_no_window_when_missing(self): pred = {"frequency": "sub_daily", "schedule_stage": "active"} result = get_schedule_params(pred) diff --git a/tests/unit/common/test_no_streamlit_in_common.py b/tests/unit/common/test_no_streamlit_in_common.py new file mode 100644 index 00000000..4f6c9c78 --- /dev/null +++ b/tests/unit/common/test_no_streamlit_in_common.py @@ -0,0 +1,47 @@ +"""Boundary guard — `testgen/common/` must not import Streamlit or use its cache. + +Streamlit caches in-process even outside its runtime; a `@st.cache_data` decorator +on a shared model method leaks stale results into MCP, API, scheduler, and CLI +processes. Cache decorators belong in the UI layer (`testgen/ui/services/query_cache.py` +or view-local helpers), not in `common/`. + +Exception: ``streamlit_authenticator`` is a separately-packaged dependency unrelated +to this boundary; it's allowed. 
+""" + +from __future__ import annotations + +import re +from pathlib import Path + +import pytest + +import testgen.common as common_pkg + +COMMON_ROOT = Path(common_pkg.__file__).resolve().parent + +_BANNED_PATTERNS = [ + re.compile(r"^\s*import\s+streamlit\s*(?:as\s+\w+)?\s*(?:#.*)?$"), + re.compile(r"^\s*from\s+streamlit(?:\.|\s)"), + re.compile(r"@st\.cache_(data|resource)\b"), +] + + +def _python_files() -> list[Path]: + return sorted(p for p in COMMON_ROOT.rglob("*.py") if "__pycache__" not in p.parts) + + +@pytest.mark.parametrize("path", _python_files(), ids=lambda p: str(p.relative_to(COMMON_ROOT))) +def test_no_streamlit_or_cache_decorator(path: Path) -> None: + text = path.read_text(encoding="utf-8") + offending: list[tuple[int, str]] = [] + for lineno, line in enumerate(text.splitlines(), start=1): + for pattern in _BANNED_PATTERNS: + if pattern.search(line): + offending.append((lineno, line.rstrip())) + break + assert not offending, ( + f"{path.relative_to(COMMON_ROOT)} imports Streamlit or applies an " + f"@st.cache_* decorator. Caching belongs in the UI layer " + f"(testgen/ui/services/query_cache.py). Offending lines: {offending}" + ) diff --git a/tests/unit/common/test_profile_top_values.py b/tests/unit/common/test_profile_top_values.py new file mode 100644 index 00000000..e3c74661 --- /dev/null +++ b/tests/unit/common/test_profile_top_values.py @@ -0,0 +1,85 @@ +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns + +# --- parse_top_freq_values --- + + +def test_parse_top_freq_values_three_rows(): + raw = "| Mexico | 182\n| USA | 176\n| Canada | 144" + assert parse_top_freq_values(raw) == [("Mexico", 182), ("USA", 176), ("Canada", 144)] + + +def test_parse_top_freq_values_with_other_values_aggregate_row(): + # The profiling pipeline emits a synthetic "Other Values (N)" row when distinct count > 10. 
+ raw = "| a | 5\n| b | 4\n| Other Values (8) | 20" + assert parse_top_freq_values(raw) == [("a", 5), ("b", 4), ("Other Values (8)", 20)] + + +def test_parse_top_freq_values_value_containing_separator(): + # rpartition: count is always rightmost, so a value with " | " in it parses correctly. + raw = "| user | password | 42" + assert parse_top_freq_values(raw) == [("user | password", 42)] + + +def test_parse_top_freq_values_none_input(): + assert parse_top_freq_values(None) == [] + + +def test_parse_top_freq_values_empty_input(): + assert parse_top_freq_values("") == [] + + +def test_parse_top_freq_values_skips_unparseable_count(): + raw = "| good | 10\n| bad | not_a_number\n| also_good | 5" + assert parse_top_freq_values(raw) == [("good", 10), ("also_good", 5)] + + +def test_parse_top_freq_values_skips_rows_without_separator(): + raw = "alone\n| good | 5" + assert parse_top_freq_values(raw) == [("good", 5)] + + +def test_parse_top_freq_values_trims_whitespace_around_value(): + raw = "| spacey | 7" + assert parse_top_freq_values(raw) == [("spacey", 7)] + + +def test_parse_top_freq_values_tolerates_missing_leading_marker(): + raw = "alone | 9" + assert parse_top_freq_values(raw) == [("alone", 9)] + + +# --- parse_top_patterns --- + + +def test_parse_top_patterns_three_pairs(): + raw = "326 | Aaaaaa | 176 | AAA | 50 | aaa" + assert parse_top_patterns(raw) == [("Aaaaaa", 326), ("AAA", 176), ("aaa", 50)] + + +def test_parse_top_patterns_email_shape(): + raw = "200 | aaa@aaa.aaa" + assert parse_top_patterns(raw) == [("aaa@aaa.aaa", 200)] + + +def test_parse_top_patterns_none_input(): + assert parse_top_patterns(None) == [] + + +def test_parse_top_patterns_empty_input(): + assert parse_top_patterns("") == [] + + +def test_parse_top_patterns_skips_pair_with_unparseable_count(): + raw = "10 | good | xx | bad | 5 | also_good" + assert parse_top_patterns(raw) == [("good", 10), ("also_good", 5)] + + +def test_parse_top_patterns_dangling_odd_segment_ignored(): + # An odd number 
of segments — the trailing count without a pattern is dropped. + raw = "10 | Aaa | 99" + assert parse_top_patterns(raw) == [("Aaa", 10)] + + +def test_parse_top_patterns_trims_pattern_whitespace(): + raw = "5 | NNNN-NN-NN " + assert parse_top_patterns(raw) == [("NNNN-NN-NN", 5)] diff --git a/tests/unit/common/test_salesforce_data360_flavor.py b/tests/unit/common/test_salesforce_data360_flavor.py new file mode 100644 index 00000000..48bcab65 --- /dev/null +++ b/tests/unit/common/test_salesforce_data360_flavor.py @@ -0,0 +1,415 @@ +"""Unit tests for Salesforce Data 360 flavor support.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.database.flavor.flavor_service import ResolvedConnectionParams, resolve_connection_params +from testgen.common.database.flavor.salesforce_data360_flavor_service import ( + _TYPE_MAP, + SalesforceData360FlavorService, +) + + +@pytest.fixture +def flavor_service(): + return SalesforceData360FlavorService() + + +@pytest.fixture +def client_credentials_params(): + return ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + password="consumer_secret_456", # noqa: S106 + dbname="", + connect_by_key=False, + sql_flavor="salesforce_data360", + ) + + +@pytest.fixture +def jwt_bearer_params(): + return ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + dbname="admin@myorg.com", + connect_by_key=True, + private_key="-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----", + sql_flavor="salesforce_data360", + ) + + +# --- FlavorService class properties --- + +def test_flavor_service_properties(flavor_service): + assert flavor_service.concat_operator == "||" + assert flavor_service.quote_character == '"' + assert flavor_service.varchar_type == "VARCHAR(1000)" + assert flavor_service.default_uppercase is False + assert flavor_service.test_query == "SELECT 1" + assert 
flavor_service.qualifies_table_refs_with_schema is False + assert flavor_service.metadata_via_api is True + assert flavor_service.row_limiting_clause == "limit" + + +def test_get_table_ref_omits_schema(flavor_service): + assert flavor_service.get_table_ref("data_space", "Account__dll") == '"Account__dll"' + + +# --- Connection string --- + +def test_connection_string_is_dummy(flavor_service, client_credentials_params): + assert flavor_service.get_connection_string(client_credentials_params) == "salesforce_data360://" + + +def test_connection_string_from_fields(flavor_service, client_credentials_params): + assert flavor_service.get_connection_string_from_fields(client_credentials_params) == "salesforce_data360://" + + +# --- Connect args: Client Credentials flow --- + +def test_connect_args_client_credentials(flavor_service, client_credentials_params): + args = flavor_service.get_connect_args(client_credentials_params) + assert args["login_url"] == "https://myorg.my.salesforce.com" + assert args["client_id"] == "consumer_key_123" + assert args["client_secret"] == "consumer_secret_456" # noqa: S105 + assert "username" not in args + assert "private_key" not in args + assert "dataspace" not in args # connection-only contexts (Test Connection) + + +# --- Connect args: JWT Bearer flow --- + +def test_connect_args_jwt_bearer(flavor_service, jwt_bearer_params): + args = flavor_service.get_connect_args(jwt_bearer_params) + assert args["login_url"] == "https://myorg.my.salesforce.com" + assert args["client_id"] == "consumer_key_123" + assert args["username"] == "admin@myorg.com" + assert args["private_key"].startswith("-----BEGIN RSA PRIVATE KEY-----") + assert "client_secret" not in args + assert "dataspace" not in args # connection-only contexts (Test Connection) + + +# --- Connect args: Data Space pass-through --- + +def test_connect_args_passes_dataspace_when_table_group_schema_set(flavor_service): + params = ResolvedConnectionParams( + 
host="https://myorg.my.salesforce.com", + username="consumer_key_123", + password="consumer_secret_456", # noqa: S106 + dbname="", + dbschema="marketing", + connect_by_key=False, + sql_flavor="salesforce_data360", + ) + args = flavor_service.get_connect_args(params) + assert args["dataspace"] == "marketing" + + +def test_connect_args_omits_dataspace_when_table_group_schema_empty(flavor_service): + params = ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + dbname="admin@myorg.com", + dbschema="", + connect_by_key=True, + private_key="-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----", + sql_flavor="salesforce_data360", + ) + args = flavor_service.get_connect_args(params) + assert "dataspace" not in args + + +# --- Engine args --- + +def test_engine_args(flavor_service, client_credentials_params): + args = flavor_service.get_engine_args(client_credentials_params) + assert args["pool_pre_ping"] is False + assert "poolclass" in args + + +# --- Pre-connection queries --- + +def test_no_pre_connection_queries(flavor_service, client_credentials_params): + assert flavor_service.get_pre_connection_queries(client_credentials_params) == [] + + +# --- Table reference (no schema prefix) --- + +def test_get_table_ref_no_schema(flavor_service): + ref = flavor_service.get_table_ref("default", "ssot__Account__dlm") + assert ref == '"ssot__Account__dlm"' + assert "default" not in ref + + +# --- resolve_connection_params mapping --- + +def test_resolve_connection_params_mapping(): + # Use plain strings (not bytes) to avoid triggering the DecryptText path + params = resolve_connection_params({ + "sql_flavor": "salesforce_data360", + "project_host": "https://myorg.my.salesforce.com", + "project_user": "consumer_key", + "project_pw_encrypted": "plain_secret", + "project_db": "admin@org.com", + "connect_by_key": True, + "private_key": "plain_key", + }) + assert params.host == "https://myorg.my.salesforce.com" + assert 
params.username == "consumer_key" + assert params.password == "plain_secret" # noqa: S105 + assert params.dbname == "admin@org.com" + assert params.connect_by_key is True + assert params.private_key == "plain_key" + + +# --- Schema metadata (get_schema_columns) --- + +def test_get_schema_columns_returns_columns(flavor_service, client_credentials_params): + mock_field = MagicMock() + mock_field.name = "ssot__Name__c" + mock_field.type = "STRING" + + mock_table = MagicMock() + mock_table.name = "ssot__Account__dlm" + mock_table.fields = [mock_field] + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = [mock_table] + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert columns is not None + assert len(columns) == 1 + assert columns[0].schema_name == "default" + assert columns[0].table_name == "ssot__Account__dlm" + assert columns[0].column_name == "ssot__Name__c" + assert columns[0].column_type == "varchar" + assert columns[0].general_type == "A" + assert columns[0].db_data_type == "STRING" + assert columns[0].ordinal_position == 1 + assert columns[0].is_decimal is False + + +def test_get_schema_columns_type_mapping(flavor_service, client_credentials_params): + """Verify all metadata types map correctly.""" + type_cases = [ + ("STRING", "varchar", "A", False), + ("NUMBER", "numeric", "N", True), + ("BIGINT", "bigint", "N", False), + ("BOOLEAN", "boolean", "B", False), + ("DATE", "date", "D", False), + ("DATE_TIME", "datetime", "D", False), + ] + + for meta_type, expected_col_type, expected_gen_type, expected_decimal in type_cases: + mock_field = MagicMock() + mock_field.name = "test_col" + mock_field.type = meta_type + + mock_table = MagicMock() + mock_table.name = "test_table" + mock_table.fields = [mock_field] + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = [mock_table] + + with patch( 
+ "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert columns[0].column_type == expected_col_type, f"Failed for {meta_type}" + assert columns[0].general_type == expected_gen_type, f"Failed for {meta_type}" + assert columns[0].is_decimal == expected_decimal, f"Failed for {meta_type}" + + +def test_get_schema_columns_unknown_type_defaults_to_X(flavor_service, client_credentials_params): + mock_field = MagicMock() + mock_field.name = "exotic_col" + mock_field.type = "HYPERLOGLOG" + + mock_table = MagicMock() + mock_table.name = "test_table" + mock_table.fields = [mock_field] + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = [mock_table] + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert columns[0].general_type == "X" + # Unknown metadata types are preserved as a lowercased column_type so that + # downstream views still surface the raw SF type instead of coercing to varchar. 
+ assert columns[0].column_type == "hyperloglog" + + +def test_get_schema_columns_multiple_tables(flavor_service, client_credentials_params): + tables = [] + for tname, field_count in [("ssot__Account__dlm", 3), ("ssot__Individual__dlm", 5)]: + mock_table = MagicMock() + mock_table.name = tname + mock_table.fields = [] + for i in range(field_count): + f = MagicMock() + f.name = f"field_{i}" + f.type = "STRING" + mock_table.fields.append(f) + tables.append(mock_table) + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = tables + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert len(columns) == 8 + account_cols = [c for c in columns if c.table_name == "ssot__Account__dlm"] + assert len(account_cols) == 3 + individual_cols = [c for c in columns if c.table_name == "ssot__Individual__dlm"] + assert len(individual_cols) == 5 + + +# --- Dialect registration --- + +def test_dialect_is_registered(): + from sqlalchemy.dialects import registry as sa_registry + + # The import of the flavor service module triggers registration + assert "salesforce_data360" in sa_registry.impls + + +# --- Type map completeness --- + +def test_type_map_covers_all_known_types(): + # Data 360's metadata API has a small fixed vocabulary verified against + # profiled DMOs and DLOs. Any unknown type falls through to general_type "X". 
+ expected_types = {"STRING", "NUMBER", "BIGINT", "BOOLEAN", "DATE", "DATE_TIME"} + assert set(_TYPE_MAP.keys()) == expected_types + + +# --- SQL template files exist --- + +def test_template_files_exist(): + from pathlib import Path + + base = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" + assert (base / "profiling" / "project_profiling_query.sql").exists() + assert (base / "profiling" / "project_secondary_profiling_query.sql").exists() + assert (base / "profiling" / "templated_functions.yaml").exists() + + +# --- Templated functions YAML --- + +def test_templated_functions_yaml_parses(): + from pathlib import Path + + import yaml + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "templated_functions.yaml" + with open(path) as f: + data = yaml.safe_load(f) + + # Data 360 uses native DATEDIFF('unit', ...) directly in templates, so the + # DATEDIFF_* macros are intentionally omitted (only IS_NUM / IS_DATE need wrappers). + required_functions = ["IS_NUM", "IS_DATE"] + for func_name in required_functions: + assert func_name in data, f"Missing templated function: {func_name}" + + +def test_profiling_query_uses_data360_datediff_syntax(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_profiling_query.sql" + sql = path.read_text() + + # Data 360 uses inline DATEDIFF('unit', start, end) — string units, not bare identifiers. 
+ assert "DATEDIFF('day'" in sql + assert "DATEDIFF('week'" in sql + assert "DATEDIFF('month'" in sql + + +def test_is_num_uses_regexp_like(): + from pathlib import Path + + import yaml + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "templated_functions.yaml" + with open(path) as f: + data = yaml.safe_load(f) + + assert "REGEXP_LIKE" in data["IS_NUM"] + assert "~" not in data["IS_NUM"] # No PG regex operator + + +def test_is_date_uses_regexp_like(): + from pathlib import Path + + import yaml + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "templated_functions.yaml" + with open(path) as f: + data = yaml.safe_load(f) + + assert "REGEXP_LIKE" in data["IS_DATE"] + assert "~" not in data["IS_DATE"] + assert "LEFT(" not in data["IS_DATE"] # Should use SUBSTR, not LEFT + assert "::" not in data["IS_DATE"] # Should use CAST, not :: + + +# --- Profiling template syntax checks --- + +def test_profiling_query_has_no_pg_specific_syntax(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_profiling_query.sql" + content = path.read_text() + + assert "TABLESAMPLE" not in content + assert "STRING_AGG" not in content + assert "TRANSLATE(" not in content + assert " ~ " not in content # PG regex operator + # Check for PG escape string syntax (E'...') — but not substrings like "CASE '" + import re + assert not re.search(r"\bE'", content), "Found PostgreSQL E-string escape syntax" + assert "LEFT(" not in content + assert "::FLOAT" not in content + assert "::BIGINT" not in content + assert "::NUMERIC" not in content + + +def test_profiling_query_uses_data360_alternatives(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_profiling_query.sql" + content = 
path.read_text() + + assert "REGEXP_LIKE" in content + assert "ARRAY_JOIN(ARRAY_AGG" in content + assert "SUBSTR(" in content + assert "ORDER BY RANDOM()" in content + + +def test_secondary_profiling_query_syntax(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_secondary_profiling_query.sql" + content = path.read_text() + + assert "STRING_AGG" not in content + assert "ARRAY_JOIN(ARRAY_AGG" in content + assert "TABLESAMPLE" not in content + assert '"{DATA_SCHEMA}".' not in content # No schema prefix in FROM diff --git a/tests/unit/common/test_time_series_service.py b/tests/unit/common/test_time_series_service.py index 86e2e8b3..c8a5d66a 100644 --- a/tests/unit/common/test_time_series_service.py +++ b/tests/unit/common/test_time_series_service.py @@ -450,23 +450,6 @@ def test_sensitivity_ordering(self): assert upper_high <= upper_med <= upper_low - def test_min_lookback_respected(self): - # 6 updates with sawtooth rows in between — the helper generates many rows - updates = [f"2026-02-{d:02d}T{h:02d}:00" for d, h in [(1, 0), (1, 10), (1, 20), (2, 6), (2, 16), (3, 2)]] - history = _make_freshness_history(updates) - row_count = len(history) - - # With min_lookback at exactly the row count → should produce thresholds - _, upper, _, _ = compute_freshness_threshold(history, PredictSensitivity.medium, min_lookback=row_count) - assert upper is not None - - # With min_lookback above the row count → training mode - lower, upper, staleness, prediction = compute_freshness_threshold(history, PredictSensitivity.medium, min_lookback=row_count + 1) - assert lower is None - assert upper is None - assert staleness is None - assert prediction is None - class Test_AddBusinessMinutes: def test_no_exclusions(self): start = pd.Timestamp("2026-02-09T08:00") # Monday diff --git a/tests/unit/mcp/test_inventory_service.py b/tests/unit/mcp/test_inventory_service.py index ed0ba1bd..51329145 
100644 --- a/tests/unit/mcp/test_inventory_service.py +++ b/tests/unit/mcp/test_inventory_service.py @@ -17,6 +17,15 @@ def table_group_select_summary_mock(): yield mock +@pytest.fixture(autouse=True) +def scorecards_by_project_mock(): + with patch( + "testgen.mcp.services.inventory_service.ScoreDefinition.list_with_table_group_targets" + ) as mock: + mock.return_value = [] + yield mock + + def _make_row(project_code="demo", project_name="Demo", connection_id=1, connection_name="main", table_group_id=None, table_groups_name="core", table_group_schema="public", test_suite_id=None, test_suite="Quality"): @@ -142,6 +151,7 @@ def test_get_inventory_with_view_shows_all_details(mock_select, session_mock): result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) assert "main" in result # connection name shown + assert "**Test Suites:**" in result assert "Visible Suite" in result assert str(suite_id) in result assert "requires `view` permission" not in result @@ -222,3 +232,139 @@ def test_get_inventory_never_profiled_fragment( assert "not profiled yet" in result assert "hygiene issues" not in result assert "Score" not in result + + +# ---------------------------------------------------------------------- +# Scorecard rendering +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_lists_single_tg_scorecard_under_tg( + mock_select, session_mock, scorecards_by_project_mock, +): + """A scorecard targeting one TG by name renders as a bullet under that TG.""" + tg_id = uuid4() + sc_id = uuid4() + session_mock.execute.return_value.all.return_value = [ + _make_row(table_group_id=tg_id, table_groups_name="core"), + ] + scorecards_by_project_mock.return_value = [(sc_id, "Core Scorecard", ["core"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert 
"**Scorecards:**" in result + assert f"- **Core Scorecard** (id: `{sc_id}`)" in result + # No spanning section when every scorecard targets exactly one TG. + assert "spanning multiple table groups" not in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_multi_tg_scorecard_appears_under_each_named_tg_and_spanning( + mock_select, session_mock, scorecards_by_project_mock, +): + """A scorecard targeting two TGs appears under each TG AND in the spanning section.""" + tg_a, tg_b = uuid4(), uuid4() + sc_id = uuid4() + session_mock.execute.return_value.all.return_value = [ + _make_row(table_group_id=tg_a, table_groups_name="orders", test_suite_id=uuid4()), + _make_row(table_group_id=tg_b, table_groups_name="customers", test_suite_id=uuid4()), + ] + scorecards_by_project_mock.return_value = [(sc_id, "Cross", ["orders", "customers"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert result.count(f"- **Cross** (id: `{sc_id}`)") == 3 + assert "### Scorecards spanning multiple table groups" in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_no_name_filter_scorecard_in_spanning_section_only( + mock_select, session_mock, scorecards_by_project_mock, +): + """A scorecard with no table_groups_name filter only appears in the spanning section.""" + tg_id = uuid4() + sc_id = uuid4() + session_mock.execute.return_value.all.return_value = [_make_row(table_group_id=tg_id)] + scorecards_by_project_mock.return_value = [(sc_id, "Metadata Only", [])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "### Scorecards spanning multiple table groups" in result + assert f"- **Metadata Only** (id: `{sc_id}`)" in result + # The TG block should not have a Scorecards: line. 
+ assert "**Scorecards:**" not in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_compact_mode_emits_scorecards_count_no_ids( + mock_select, session_mock, scorecards_by_project_mock, +): + """Compact mode (>50 groups) appends 'scorecards: N' to the one-liner; no IDs.""" + rows = [ + _make_row( + table_group_id=uuid4(), + table_groups_name=f"Group_{i}", + test_suite=f"Suite_{i}", + test_suite_id=uuid4(), + ) + for i in range(55) + ] + session_mock.execute.return_value.all.return_value = rows + sc_id = uuid4() + scorecards_by_project_mock.return_value = [(sc_id, "G0 Scorecard", ["Group_0"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "scorecards: 1" in result + assert str(sc_id) not in result # no IDs in compact mode + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_catalog_only_project_hides_scorecards( + mock_select, session_mock, scorecards_by_project_mock, +): + """Without view permission, the ORM lookup is skipped and no scorecard text renders.""" + tg_id = uuid4() + session_mock.execute.return_value.all.return_value = [_make_row(table_group_id=tg_id)] + scorecards_by_project_mock.return_value = [(uuid4(), "Hidden", ["core"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=[]) + + scorecards_by_project_mock.assert_not_called() + assert "Scorecards" not in result + assert "Hidden" not in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_footer_includes_get_scorecard_hint( + mock_select, session_mock, scorecards_by_project_mock, +): + """Footer mentions get_scorecard for discoverability.""" + session_mock.execute.return_value.all.return_value = [_make_row()] + + from testgen.mcp.services.inventory_service import get_inventory + result = 
get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "get_scorecard(scorecard_id=" in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_no_scorecards_omits_scorecards_line( + mock_select, session_mock, scorecards_by_project_mock, +): + """When no scorecards target a TG, the Scorecards line is omitted entirely.""" + tg_id = uuid4() + session_mock.execute.return_value.all.return_value = [_make_row(table_group_id=tg_id)] + scorecards_by_project_mock.return_value = [] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "**Scorecards:**" not in result + assert "spanning multiple table groups" not in result diff --git a/tests/unit/mcp/test_model_data_column.py b/tests/unit/mcp/test_model_data_column.py new file mode 100644 index 00000000..aa8b0181 --- /dev/null +++ b/tests/unit/mcp/test_model_data_column.py @@ -0,0 +1,301 @@ +from datetime import datetime +from unittest.mock import patch +from uuid import uuid4 + +from testgen.common.models.data_column import ColumnProfileDetail, DataColumnChars + + +def _detail_row(**overrides) -> dict: + """Build a dict matching every ColumnProfileDetail field.""" + base = { + # Identity + "column_name": "customer_name", + "table_name": "customers", + "schema_name": "demo", + # Types & metadata + "general_type": "A", + "column_type": "varchar(50)", + "db_data_type": "varchar(50)", + "functional_data_type": "Person Given Name", + "datatype_suggestion": "VARCHAR(20)", + "functional_table_type": None, + "pii_flag": "B/NAME/Individual", + "critical_data_element": False, + # Counts + "record_ct": 500, + "value_ct": 500, + "distinct_value_ct": 260, + "null_value_ct": 0, + "filled_value_ct": 0, + "zero_value_ct": 0, + # Alpha + "min_length": 3, + "max_length": 50, + "avg_length": 12.4, + "min_text": "Aaron", + "max_text": "Zoey", + "top_freq_values": "| Mary | 12\n| John | 
10", + "top_patterns": "10 | A(5) | 8 | A(6)", + "distinct_std_value_ct": 250, + "distinct_pattern_ct": 35, + "std_pattern_match": None, + "mixed_case_ct": 100, + "lower_case_ct": 350, + "upper_case_ct": 50, + "non_alpha_ct": 0, + "includes_digit_ct": 0, + "numeric_ct": 0, + "date_ct": 0, + "quoted_value_ct": 0, + "lead_space_ct": 0, + "embedded_space_ct": 0, + "avg_embedded_spaces": 0.0, + "zero_length_ct": 0, + # Numeric (None for an alpha column) + "min_value": None, + "min_value_over_0": None, + "max_value": None, + "avg_value": None, + "stdev_value": None, + "percentile_25": None, + "percentile_50": None, + "percentile_75": None, + # Date + "min_date": None, + "max_date": None, + "before_1yr_date_ct": None, + "before_5yr_date_ct": None, + "before_20yr_date_ct": None, + "within_1yr_date_ct": None, + "within_1mo_date_ct": None, + "future_date_ct": None, + # Boolean + "boolean_true_ct": None, + # Per-column profiling failure + "query_error": None, + # Scores & hygiene + "dq_score_profiling": 100.0, + "dq_score_testing": 98.5, + "hygiene_issue_count": 1, + # Run identity + "profile_run_id": uuid4(), + "profile_run_je_id": uuid4(), + "profile_run_status": "Complete", + "profile_run_started_at": datetime(2026, 5, 1, 12, 0, 0), + "profile_run_ended_at": datetime(2026, 5, 1, 12, 5, 0), + "profile_run_log_message": None, + } + base.update(overrides) + return base + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_returns_dataclass_when_row_exists(session_mock): + row = _detail_row() + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="customer_name", + ) + + assert isinstance(result, ColumnProfileDetail) + assert result.column_name == "customer_name" + assert result.general_type == "A" + assert result.min_text == "Aaron" + assert result.profile_run_status == 
"Complete" + assert result.hygiene_issue_count == 1 + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_returns_none_when_missing(session_mock): + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = None + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="ghost_column", + ) + + assert result is None + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_numeric_column_carries_numeric_fields(session_mock): + row = _detail_row( + column_name="amount", + general_type="N", + column_type="numeric(18,4)", + db_data_type="numeric", + functional_data_type="Currency", + pii_flag=None, + # Numeric stats populated; alpha fields naturally None at the DB level for numeric columns + min_value=0.0, + min_value_over_0=0.01, + max_value=99999.99, + avg_value=125.34, + stdev_value=42.1, + percentile_25=50.0, + percentile_50=100.0, + percentile_75=200.0, + # Alpha fields cleared for realism + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + ) + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), table_name="orders", column_name="amount" + ) + + assert result.general_type == "N" + assert result.min_value == 0.0 + assert result.percentile_50 == 100.0 + assert result.min_text is None + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_date_column_carries_date_fields(session_mock): + row = _detail_row( + column_name="created_at", + general_type="D", + functional_data_type="Datetime-Created", + min_date=datetime(2024, 1, 1, 0, 0, 0), + max_date=datetime(2026, 4, 30, 23, 59, 59), + before_1yr_date_ct=10000, + before_5yr_date_ct=2000, + before_20yr_date_ct=0, + within_1yr_date_ct=40000, + within_1mo_date_ct=5000, + 
future_date_ct=0, + ) + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), table_name="orders", column_name="created_at" + ) + + assert result.general_type == "D" + assert result.min_date == datetime(2024, 1, 1, 0, 0, 0) + assert result.within_1yr_date_ct == 40000 + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_boolean_column_carries_true_count(session_mock): + row = _detail_row( + column_name="is_active", + general_type="B", + functional_data_type="Boolean", + boolean_true_ct=420, + value_ct=500, + ) + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), table_name="users", column_name="is_active" + ) + + assert result.general_type == "B" + assert result.boolean_true_ct == 420 + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_pinned_profiling_run_id_appears_in_query(session_mock): + """When profiling_run_id is supplied, the rendered query references that pinned id.""" + pinned_id = uuid4() + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = None + + DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="customer_name", + profiling_run_id=pinned_id, + ) + + # The query passed to execute() should reference the pinned id literally. + call_args = session_mock.return_value.execute.call_args + query = call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + # SQLAlchemy renders UUID literal binds without dashes (.hex form). 
+ assert pinned_id.hex in sql_str or str(pinned_id) in sql_str + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_no_pin_uses_last_complete_profile_run_id(session_mock): + """Without a pin, the join should reference the column's last_complete_profile_run_id column.""" + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = None + + DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="customer_name", + ) + + call_args = session_mock.return_value.execute.call_args + query = call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "last_complete_profile_run_id" in sql_str + + +# ---------------------------------------------------------------------- +# DataColumnChars.search_by_name +# ---------------------------------------------------------------------- + + +@patch.object(DataColumnChars, "_paginate") +def test_search_by_name_joins_table_group_and_orders_for_stable_pagination(paginate_mock): + paginate_mock.return_value = ([], 0) + + DataColumnChars.search_by_name(pattern="%email%", page=1, limit=10) + + query = paginate_mock.call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + # Join to table_groups + ILIKE on column_name + the expected ordering for stable paging. 
+ assert "table_groups" in sql_str.lower() + assert "ilike" in sql_str.lower() or "like" in sql_str.lower() + assert "ORDER BY" in sql_str + assert "project_code" in sql_str + assert "%email%" in sql_str + + +@patch.object(DataColumnChars, "_paginate") +def test_search_by_name_excludes_dropped_columns(paginate_mock): + paginate_mock.return_value = ([], 0) + + DataColumnChars.search_by_name(pattern="%x%", page=1, limit=10) + + query = paginate_mock.call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "drop_date IS NULL" in sql_str + + +# ---------------------------------------------------------------------- +# DataColumnChars.summarize_matches_by_project +# ---------------------------------------------------------------------- + + +@patch("testgen.common.models.data_column.get_current_session") +def test_summarize_matches_by_project_returns_project_count_tuples(session_mock): + row_a = type("Row", (), {"project_code": "DEFAULT", "match_count": 6})() + row_b = type("Row", (), {"project_code": "DEMO_2", "match_count": 1})() + session_mock.return_value.execute.return_value.all.return_value = [row_a, row_b] + + result = DataColumnChars.summarize_matches_by_project(pattern="%email%") + + assert result == [("DEFAULT", 6), ("DEMO_2", 1)] + + +@patch("testgen.common.models.data_column.get_current_session") +def test_summarize_matches_by_project_groups_and_orders_by_project(session_mock): + session_mock.return_value.execute.return_value.all.return_value = [] + + DataColumnChars.summarize_matches_by_project(pattern="%x%") + + query = session_mock.return_value.execute.call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "GROUP BY" in sql_str + assert "ORDER BY" in sql_str + assert "project_code" in sql_str.lower() diff --git a/tests/unit/mcp/test_model_profile_result.py b/tests/unit/mcp/test_model_profile_result.py new file mode 100644 index 00000000..1a4ad475 --- /dev/null +++ 
b/tests/unit/mcp/test_model_profile_result.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from testgen.common.models.profile_result import ProfileResult + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_row_when_run_pinned(dcc_select, pr_select): + pinned_run_id = uuid4() + profile = MagicMock(spec=ProfileResult) + pr_select.return_value = [profile] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + profiling_run_id=pinned_run_id, + ) + + assert result is profile + # When a profile run is explicitly pinned, we should not fall back to data_column_chars. + dcc_select.assert_not_called() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_resolves_latest_run_when_unpinned(dcc_select, pr_select): + latest_run_id = uuid4() + column = MagicMock() + column.last_complete_profile_run_id = latest_run_id + dcc_select.return_value = [column] + profile = MagicMock(spec=ProfileResult) + pr_select.return_value = [profile] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + ) + + assert result is profile + dcc_select.assert_called_once() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_none_when_column_unknown(dcc_select, pr_select): + dcc_select.return_value = [] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="ghost", + ) + + assert result is None + pr_select.assert_not_called() + + 
+@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_none_when_column_never_profiled(dcc_select, pr_select): + column = MagicMock() + column.last_complete_profile_run_id = None + dcc_select.return_value = [column] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + ) + + assert result is None + pr_select.assert_not_called() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_none_when_pinned_run_has_no_row(dcc_select, pr_select): + pr_select.return_value = [] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + profiling_run_id=uuid4(), + ) + + assert result is None diff --git a/tests/unit/mcp/test_permissions.py b/tests/unit/mcp/test_permissions.py index 4b058295..0f97cbe4 100644 --- a/tests/unit/mcp/test_permissions.py +++ b/tests/unit/mcp/test_permissions.py @@ -206,6 +206,40 @@ def test_has_access(): assert perms.has_access("proj_b") is False +# --- ProjectPermissions.has_permission --- + + +def test_has_permission_true_when_role_grants_it(): + perms = ProjectPermissions( + memberships={"proj_a": "role_a", "proj_b": "role_c"}, + permission="catalog", + username="test_user", + ) + # role_a is in the "view" allowlist; role_c is not. 
+ assert perms.has_permission("view", "proj_a") is True + assert perms.has_permission("view", "proj_b") is False + + +def test_has_permission_false_when_project_not_member(): + perms = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="catalog", + username="test_user", + ) + assert perms.has_permission("view", "proj_other") is False + + +def test_has_permission_decoupled_from_decorator_permission(): + # The decorator was "catalog", but we can query any permission. + perms = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="catalog", + username="test_user", + ) + assert perms.has_permission("edit", "proj_a") is True + assert perms.has_permission("catalog", "proj_a") is True + + # --- get_project_permissions --- diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index 0e5e1b03..3b5e7f32 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -1,22 +1,37 @@ from unittest.mock import MagicMock, patch -from uuid import UUID +from uuid import UUID, uuid4 import pytest -from testgen.common.enums import ImpactDimension, QualityDimension -from testgen.common.models.hygiene_issue import Disposition, IssueLikelihood, PiiRisk +from testgen.common.enums import Disposition, ImpactDimension, IssueLikelihood, PiiRisk, QualityDimension +from testgen.common.models.scores import ScoreCategory from testgen.common.models.test_result import TestResultStatus -from testgen.mcp.exceptions import MCPUserError +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.tools.common import ( + SCORE_CATEGORY_ARG_TO_COLUMN, + SCORE_CHAIN_LEAF_TO_COLUMN, + SCORE_FILTER_FIELD_TO_COLUMN, + SCORE_GROUP_BY_TO_COLUMN, + ScoreCategoryArg, + ScoreChainLeafField, + ScoreFilterField, + ScoreGroupBy, + ScoreType, format_disposition, + parse_category, parse_disposition, parse_impact_dimension, parse_issue_likelihood_list, parse_pii_risk_list, 
parse_quality_dimension, parse_result_status, + parse_score_filter_field, + parse_score_group_by, + parse_score_type, parse_uuid, resolve_issue_type, + resolve_profiling_run, + resolve_test_note, validate_limit, validate_page, ) @@ -279,3 +294,420 @@ def test_resolve_issue_type_not_found_raises_with_resource_hint(): with pytest.raises(MCPUserError, match="Unknown hygiene issue type") as exc_info: resolve_issue_type("Made-Up Type") assert "testgen://hygiene-issue-types" in str(exc_info.value) + + +# --- resolve_profiling_run --- + + +def _mock_perms(allowed_projects=("demo",)): + perms = MagicMock() + perms.has_access.side_effect = lambda code: code in allowed_projects + return perms + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.ProfilingRun") +def test_resolve_profiling_run_happy_path(mock_pr_cls, mock_get_perms, db_session_mock): + run = MagicMock() + run.project_code = "demo" + mock_pr_cls.get_by_id_or_job.return_value = run + mock_get_perms.return_value = _mock_perms(allowed_projects=("demo",)) + + result = resolve_profiling_run(str(uuid4())) + + assert result is run + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.ProfilingRun") +def test_resolve_profiling_run_unknown_run_id(mock_pr_cls, mock_get_perms, db_session_mock): + mock_pr_cls.get_by_id_or_job.return_value = None + mock_get_perms.return_value = _mock_perms() + + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + resolve_profiling_run(str(uuid4())) + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.ProfilingRun") +def test_resolve_profiling_run_inaccessible_project(mock_pr_cls, mock_get_perms, db_session_mock): + """Run exists but caller can't access its project — same unified error as unknown run.""" + run = MagicMock() + run.project_code = "forbidden" + mock_pr_cls.get_by_id_or_job.return_value = run + 
mock_get_perms.return_value = _mock_perms(allowed_projects=("demo",)) + + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + resolve_profiling_run(str(uuid4())) + + +def test_resolve_profiling_run_invalid_uuid(): + with pytest.raises(MCPUserError, match="Invalid job_execution_id"): + resolve_profiling_run("not-a-uuid") + + +# --- resolve_test_note --- + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.get_current_session") +def test_resolve_test_note_happy_path(mock_get_session, mock_get_perms): + note = MagicMock() + session = MagicMock() + session.scalars.return_value.first.return_value = note + mock_get_session.return_value = session + mock_get_perms.return_value = _mock_perms() + + assert resolve_test_note(str(uuid4())) is note + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.get_current_session") +def test_resolve_test_note_missing_or_inaccessible(mock_get_session, mock_get_perms): + """Missing note, monitor-suite parent, and forbidden project all collapse to one error.""" + session = MagicMock() + session.scalars.return_value.first.return_value = None + mock_get_session.return_value = session + mock_get_perms.return_value = _mock_perms() + + with pytest.raises(MCPResourceNotAccessible, match=r"Test note .* not found or not accessible"): + resolve_test_note(str(uuid4())) + + +def test_resolve_test_note_invalid_uuid(): + with pytest.raises(MCPUserError, match="Invalid test_note_id"): + resolve_test_note("not-a-uuid") + + +# --- parse_pii_category --- + + +def test_parse_pii_category_translates_display_label_to_stored_code(): + from testgen.mcp.tools.common import parse_pii_category + assert parse_pii_category("ID") == "ID" + assert parse_pii_category("Name") == "NAME" + assert parse_pii_category("Demographic") == "DEMO" + assert parse_pii_category("Contact") == "CONTACT" + + +def 
test_parse_pii_category_rejects_stored_code_form(): + from testgen.mcp.tools.common import parse_pii_category + with pytest.raises(MCPUserError, match="Invalid pii_category `NAME`"): + parse_pii_category("NAME") + + +def test_parse_pii_category_lists_valid_values_in_error(): + from testgen.mcp.tools.common import parse_pii_category + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_pii_category("Address") + for label in ("ID", "Name", "Demographic", "Contact"): + assert label in str(exc_info.value) + + +# --- parse_pii_risk_level --- + + +def test_parse_pii_risk_level_translates_label_to_stored_prefix(): + from testgen.mcp.tools.common import parse_pii_risk_level + assert parse_pii_risk_level("High") == "A" + assert parse_pii_risk_level("Moderate") == "B" + assert parse_pii_risk_level("Low") == "C" + + +def test_parse_pii_risk_level_rejects_unknown(): + from testgen.mcp.tools.common import parse_pii_risk_level + with pytest.raises(MCPUserError, match="Invalid pii_risk_level `Critical`"): + parse_pii_risk_level("Critical") + + +# --- parse_general_type --- + + +def test_parse_general_type_translates_word_to_letter_code(): + from testgen.mcp.tools.common import parse_general_type + assert parse_general_type("Alpha") == "A" + assert parse_general_type("Numeric") == "N" + assert parse_general_type("Datetime") == "D" + assert parse_general_type("Boolean") == "B" + assert parse_general_type("Time") == "T" + assert parse_general_type("Other") == "X" + + +def test_parse_general_type_rejects_letter_code_input(): + from testgen.mcp.tools.common import parse_general_type + with pytest.raises(MCPUserError, match="Invalid general_type `A`"): + parse_general_type("A") + + +def test_parse_general_type_is_case_sensitive(): + from testgen.mcp.tools.common import parse_general_type + with pytest.raises(MCPUserError): + parse_general_type("alpha") + + +# --- parse_suggested_data_type --- + + +def test_parse_suggested_data_type_accepts_title_case(): + from 
testgen.common.models.data_column import SuggestedDataType + from testgen.mcp.tools.common import parse_suggested_data_type + assert parse_suggested_data_type("Any") is SuggestedDataType.ANY + assert parse_suggested_data_type("Integer") is SuggestedDataType.INTEGER + assert parse_suggested_data_type("Varchar") is SuggestedDataType.VARCHAR + + +def test_parse_suggested_data_type_rejects_uppercase(): + from testgen.mcp.tools.common import parse_suggested_data_type + with pytest.raises(MCPUserError, match="Invalid suggested_data_type `INTEGER`"): + parse_suggested_data_type("INTEGER") + + +def test_parse_suggested_data_type_lists_valid_values_in_error(): + from testgen.mcp.tools.common import parse_suggested_data_type + with pytest.raises(MCPUserError) as exc_info: + parse_suggested_data_type("Bogus") + for label in ("Any", "Integer", "Numeric", "Varchar", "Date", "Timestamp", "Boolean"): + assert label in str(exc_info.value) + + +# --- parse_column_order_by --- + + +def test_parse_column_order_by_accepts_display_form(): + from testgen.common.models.data_column import ColumnOrderBy + from testgen.mcp.tools.common import parse_column_order_by + assert parse_column_order_by("Null Ratio") is ColumnOrderBy.NULL_RATIO + assert parse_column_order_by("Profiling Score") is ColumnOrderBy.SCORE_PROFILING + assert parse_column_order_by("Hygiene Count") is ColumnOrderBy.HYGIENE_COUNT + + +def test_parse_column_order_by_rejects_snake_case(): + from testgen.mcp.tools.common import parse_column_order_by + with pytest.raises(MCPUserError, match="Invalid order_by `null_ratio`"): + parse_column_order_by("null_ratio") + + +# --- build_ilike_pattern --- + + +def test_build_ilike_pattern_wraps_bare_token(): + from testgen.mcp.tools.common import build_ilike_pattern + assert build_ilike_pattern("email") == "%email%" + + +def test_build_ilike_pattern_escapes_literal_underscore(): + from testgen.mcp.tools.common import build_ilike_pattern + # Column names commonly contain underscores; treat 
them as literal, not as SQL wildcards. + assert build_ilike_pattern("user_id") == r"%user\_id%" + + +def test_build_ilike_pattern_honors_explicit_percent(): + from testgen.mcp.tools.common import build_ilike_pattern + # Caller-supplied % means "I'm doing my own wildcards" — don't double-wrap. + assert build_ilike_pattern("%email") == "%email" + assert build_ilike_pattern("user%") == "user%" + + +def test_build_ilike_pattern_escapes_underscores_even_with_explicit_percent(): + from testgen.mcp.tools.common import build_ilike_pattern + # The `_` escape is unconditional — explicit `%` doesn't suppress it. + assert build_ilike_pattern("user_%") == r"user\_%" + + +# --- parse_score_group_by --- + + +@pytest.mark.parametrize("member", list(ScoreGroupBy)) +def test_parse_score_group_by_user_labels(member): + assert parse_score_group_by(member.value) is member + + +def test_parse_score_group_by_label_maps_to_internal_column(): + """The enum value is the user-facing label; the mapping translates to the + internal DB column name used downstream (``ScoreCategory``, the criteria + filter list).""" + assert SCORE_GROUP_BY_TO_COLUMN[ScoreGroupBy.QUALITY_DIMENSION] == "dq_dimension" + assert SCORE_GROUP_BY_TO_COLUMN[ScoreGroupBy.TABLE_GROUP] == "table_groups_name" + assert SCORE_GROUP_BY_TO_COLUMN[ScoreGroupBy.BUSINESS_DOMAIN] == "business_domain" + + +@pytest.mark.parametrize( + "internal", + ["dq_dimension", "impact_dimension", "business_domain", "table_groups_name"], +) +def test_parse_score_group_by_rejects_internal_column_name(internal): + """Old internal vocabulary must be rejected — the tool now speaks user labels only.""" + with pytest.raises(MCPUserError, match="Invalid group_by") as exc_info: + parse_score_group_by(internal) + msg = str(exc_info.value) + # Error must point users at the new user-facing vocabulary. 
+ assert "Quality Dimension" in msg + assert "Business Domain" in msg + + +def test_parse_score_group_by_invalid_lists_valid_values(): + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_score_group_by("Made Up") + msg = str(exc_info.value) + for member in ScoreGroupBy: + assert member.value in msg + + +# --- parse_score_filter_field --- + + +@pytest.mark.parametrize("member", list(ScoreFilterField)) +def test_parse_score_filter_field_user_labels(member): + assert parse_score_filter_field(member.value) is member + + +def test_parse_score_filter_field_label_maps_to_internal_column(): + assert SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField.BUSINESS_DOMAIN] == "business_domain" + assert SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField.TABLE_GROUP] == "table_groups_name" + + +def test_parse_score_filter_field_does_not_include_dimensions(): + """Quality Dimension / Impact Dimension are valid only as group_by, not as filter fields.""" + values = {m.value for m in ScoreFilterField} + assert "Quality Dimension" not in values + assert "Impact Dimension" not in values + + +@pytest.mark.parametrize("label", ["Quality Dimension", "Impact Dimension"]) +def test_parse_score_filter_field_rejects_dimension_with_hint(label): + """Passing a dimension as filter.field hints at group_by= usage instead.""" + with pytest.raises(MCPUserError, match=f"`{label}`") as exc_info: + parse_score_filter_field(label) + msg = str(exc_info.value) + assert "group_by" in msg + assert label in msg + + +@pytest.mark.parametrize( + "internal", ["business_domain", "data_source", "table_groups_name"], +) +def test_parse_score_filter_field_rejects_internal_column_name(internal): + with pytest.raises(MCPUserError, match="Invalid filter field") as exc_info: + parse_score_filter_field(internal) + msg = str(exc_info.value) + assert "Business Domain" in msg + + +def test_parse_score_filter_field_invalid_lists_valid_values(): + with pytest.raises(MCPUserError, match="Valid values:") as 
exc_info: + parse_score_filter_field("Made Up") + msg = str(exc_info.value) + for member in ScoreFilterField: + assert member.value in msg + + +# --- parse_score_type --- + + +@pytest.mark.parametrize( + "label,expected_member", + [ + ("Total", ScoreType.TOTAL), + ("CDE", ScoreType.CDE), + ], +) +def test_parse_score_type_user_labels(label, expected_member): + member = parse_score_type(label) + assert member is expected_member + + +@pytest.mark.parametrize("internal", ["total", "cde"]) +def test_parse_score_type_rejects_internal_or_wrong_case(internal): + """The internal vocabulary (``total``/``cde`` lowercase) must not be + accepted on input; only the canonical user-facing values are.""" + with pytest.raises(MCPUserError, match="Invalid score_type") as exc_info: + parse_score_type(internal) + msg = str(exc_info.value) + assert "Total" in msg + assert "CDE" in msg + + +def test_parse_score_type_invalid_lists_valid_values(): + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_score_type("BadType") + msg = str(exc_info.value) + for member in ScoreType: + assert member.value in msg + + +# --- parse_category --- + + +@pytest.mark.parametrize( + "display_value,expected", + [ + ("Quality Dimension", ScoreCategory.dq_dimension), + ("Impact Dimension", ScoreCategory.impact_dimension), + ("Table Group", ScoreCategory.table_groups_name), + ("Data Source", ScoreCategory.data_source), + ("Data Location", ScoreCategory.data_location), + ("Source System", ScoreCategory.source_system), + ("Source Process", ScoreCategory.source_process), + ("Business Domain", ScoreCategory.business_domain), + ("Stakeholder Group", ScoreCategory.stakeholder_group), + ("Transform Level", ScoreCategory.transform_level), + ("Data Product", ScoreCategory.data_product), + ], +) +def test_parse_category_display_form_returns_column_form_enum(display_value, expected): + """``parse_category`` accepts display-form labels and emits the column-form ``ScoreCategory``.""" + assert 
parse_category(display_value) is expected + + +def test_parse_category_translation_dict_covers_all_args(): + """Every ``ScoreCategoryArg`` member has a translation to a valid ``ScoreCategory`` column.""" + for arg in ScoreCategoryArg: + column = SCORE_CATEGORY_ARG_TO_COLUMN[arg] + assert ScoreCategory(column) is ScoreCategory(column) # raises if column isn't a valid enum value + + +@pytest.mark.parametrize( + "internal", + [ + "dq_dimension", + "impact_dimension", + "table_groups_name", + "data_source", + "data_location", + "source_system", + "source_process", + "business_domain", + "stakeholder_group", + "transform_level", + "data_product", + ], +) +def test_parse_category_rejects_column_form_input(internal): + """The old column-form values must not be accepted on input — display-form only.""" + with pytest.raises(MCPUserError, match="Invalid category") as exc_info: + parse_category(internal) + msg = str(exc_info.value) + # Error message must list at least one display-form value to guide the caller. 
+ assert "Quality Dimension" in msg + + +def test_parse_category_invalid_lists_display_form_values(): + """An unrelated bad value lists every display-form value in the error message.""" + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_category("Made Up") + msg = str(exc_info.value) + for member in ScoreCategoryArg: + assert member.value in msg + + +# --- ScoreChainLeafField --- + + +def test_score_chain_leaf_field_values(): + assert ScoreChainLeafField.TABLE.value == "Table" + assert ScoreChainLeafField.COLUMN.value == "Column" + + +def test_score_chain_leaf_to_column_mapping(): + assert SCORE_CHAIN_LEAF_TO_COLUMN[ScoreChainLeafField.TABLE] == "table_name" + assert SCORE_CHAIN_LEAF_TO_COLUMN[ScoreChainLeafField.COLUMN] == "column_name" diff --git a/tests/unit/mcp/test_tools_execution.py b/tests/unit/mcp/test_tools_execution.py index 79f7e99b..8b5ece21 100644 --- a/tests/unit/mcp/test_tools_execution.py +++ b/tests/unit/mcp/test_tools_execution.py @@ -65,7 +65,7 @@ def test_run_tests_submits_job(mock_suite_cls, mock_job_exec, db_session_mock): assert "Test run submitted for `Quality Suite`" in result assert str(submitted.id) in result assert "Pending" in result - assert "get_recent_test_runs" in result + assert "list_test_runs" in result def test_run_tests_invalid_uuid(db_session_mock): @@ -259,7 +259,7 @@ def fake_request_cancel(): assert "Test run cancellation requested" in result assert str(job_id) in result assert "cancel_requested" in result - assert "get_recent_test_runs" in result + assert "list_test_runs" in result def test_cancel_test_run_filters_by_job_key(db_session_mock): diff --git a/tests/unit/mcp/test_tools_hygiene_issues.py b/tests/unit/mcp/test_tools_hygiene_issues.py index 0a81ce3f..741c9b4b 100644 --- a/tests/unit/mcp/test_tools_hygiene_issues.py +++ b/tests/unit/mcp/test_tools_hygiene_issues.py @@ -5,7 +5,8 @@ import pytest from sqlalchemy.dialects import postgresql -from testgen.common.models.hygiene_issue import 
Disposition, HygieneIssue, IssueLikelihood +from testgen.common.enums import Disposition, IssueLikelihood +from testgen.common.models.hygiene_issue import HygieneIssue from testgen.common.models.profiling_run import ProfilingRun from testgen.common.pii_masking import PII_REDACTED from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError diff --git a/tests/unit/mcp/test_tools_notifications.py b/tests/unit/mcp/test_tools_notifications.py new file mode 100644 index 00000000..9ce2eccd --- /dev/null +++ b/tests/unit/mcp/test_tools_notifications.py @@ -0,0 +1,2318 @@ +from decimal import Decimal +from unittest.mock import MagicMock, patch +from uuid import UUID, uuid4 + +import pytest + +from testgen.common.models.notification_settings import ( + MonitorNotificationTrigger, + NotificationEvent, + NotificationSummary, + ProfilingRunNotificationTrigger, + TestRunNotificationTrigger, +) +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import ProjectPermissions +from testgen.mcp.tools.common import ( + MONITOR_TRIGGER_LABEL_TO_INTERNAL, + NOTIFICATION_EVENT_LABEL_TO_INTERNAL, + PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL, + TEST_RUN_TRIGGER_LABEL_TO_INTERNAL, + format_notification_event, + format_notification_trigger, +) + +pytestmark = pytest.mark.unit + + +# --- Helpers --- + + +def _patch_perms(allowed=("demo",), memberships=None): + memberships = memberships or dict.fromkeys(allowed, "role_a") + return patch( + "testgen.mcp.permissions._compute_project_permissions", + return_value=ProjectPermissions( + memberships=memberships, permission="view", username="test_user", + ), + ) + + +def _summary( + *, + event: NotificationEvent, + enabled: bool = True, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + settings: dict | None = None, +) -> NotificationSummary: + return 
NotificationSummary( + id=uuid4(), + project_code=project_code, + event=event, + enabled=enabled, + recipients=list(recipients), + test_suite_id=test_suite_id, + table_group_id=table_group_id, + score_definition_id=score_definition_id, + settings=settings or {}, + ) + + +def _patch_list_for_projects(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_projects", + return_value=(rows, total), + ) + + +def _patch_list_for_test_suite(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_test_suite", + return_value=(rows, total), + ) + + +def _patch_list_for_table_group(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_table_group", + return_value=(rows, total), + ) + + +def _patch_list_for_score_definition(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_score_definition", + return_value=(rows, total), + ) + + +def _patch_no_resolve_lookups(): + """Make the batch-name helpers return empty dicts so tests don't need TestSuite/TableGroup mocks + unless they care about scope-name rendering. + """ + return patch.multiple( + "testgen.mcp.tools.notifications", + _batch_suite_names=MagicMock(return_value={}), + _batch_table_group_names=MagicMock(return_value={}), + _batch_score_names=MagicMock(return_value={}), + ) + + +# --- format helpers --- + + +def test_format_notification_event_round_trip(): + """Every NotificationEvent has a stable display label.""" + seen_labels = set() + for event in NotificationEvent: + label = format_notification_event(event) + seen_labels.add(label) + # Round-trip the label back to the internal enum. 
+ assert NOTIFICATION_EVENT_LABEL_TO_INTERNAL[ + type(next(iter(NOTIFICATION_EVENT_LABEL_TO_INTERNAL)))(label) + ] is event + assert seen_labels == {"Test Run", "Profiling Run", "Score Drop", "Monitor Alert"} + + +def test_format_notification_event_accepts_raw_string(): + assert format_notification_event("test_run") == "Test Run" + + +def test_format_notification_trigger_test_run_labels(): + for trigger, label_enum in { + TestRunNotificationTrigger.always: "Always", + TestRunNotificationTrigger.on_failures: "On test failures", + TestRunNotificationTrigger.on_warnings: "On test failures and warnings", + TestRunNotificationTrigger.on_changes: "On new test failures and warnings", + }.items(): + assert ( + format_notification_trigger(NotificationEvent.test_run, {"trigger": trigger.value}) + == label_enum + ) + + +def test_format_notification_trigger_profiling_labels(): + assert ( + format_notification_trigger(NotificationEvent.profiling_run, {"trigger": "always"}) + == "Always" + ) + assert ( + format_notification_trigger(NotificationEvent.profiling_run, {"trigger": "on_changes"}) + == "On new hygiene issues" + ) + + +def test_format_notification_trigger_monitor_label(): + assert ( + format_notification_trigger(NotificationEvent.monitor_run, {"trigger": "on_anomalies"}) + == "On anomalies" + ) + + +def test_format_notification_trigger_score_drop_returns_none(): + assert format_notification_trigger(NotificationEvent.score_drop, {"total_threshold": "95.0"}) is None + + +def test_format_notification_trigger_missing_settings_returns_none(): + assert format_notification_trigger(NotificationEvent.test_run, None) is None + assert format_notification_trigger(NotificationEvent.test_run, {}) is None + + +def test_trigger_label_to_internal_dicts_cover_every_internal_enum(): + """No internal enum value should be missing a display label — both directions are total.""" + assert set(TEST_RUN_TRIGGER_LABEL_TO_INTERNAL.values()) == set(TestRunNotificationTrigger) + assert 
set(PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL.values()) == set(ProfilingRunNotificationTrigger) + assert set(MONITOR_TRIGGER_LABEL_TO_INTERNAL.values()) == set(MonitorNotificationTrigger) + + +def test_scope_fields_cover_every_event(): + """Every notification event must have a scope-field descriptor — no event can be + added without declaring which scope entities (and labels) it renders. + """ + from testgen.mcp.tools.notifications import _SCOPE_FIELDS + + assert set(_SCOPE_FIELDS) == set(NotificationEvent) + + +# --- Argument validation --- + + +def test_list_notifications_rejects_two_scope_args(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="at most one"): + list_notifications(project_code="demo", test_suite_id=str(uuid4())) + + +def test_list_notifications_rejects_three_scope_args(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="at most one"): + list_notifications(test_suite_id=str(uuid4()), table_group_id=str(uuid4()), scorecard_id=str(uuid4())) + + +@pytest.mark.parametrize("page,limit", [(0, 10), (1, 0), (1, 201)]) +def test_list_notifications_rejects_invalid_pagination(db_session_mock, page, limit): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError): + list_notifications(page=page, limit=limit) + + +def test_list_notifications_invalid_test_suite_uuid(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + list_notifications(test_suite_id="not-a-uuid") + + +def test_list_notifications_invalid_table_group_uuid(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + 
list_notifications(table_group_id="not-a-uuid") + + +def test_list_notifications_invalid_scorecard_uuid(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + list_notifications(scorecard_id="not-a-uuid") + + +def test_list_notifications_rejects_inaccessible_project(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(allowed=("demo",)), pytest.raises( + MCPResourceNotAccessible, match=r"Project.*forbidden_proj" + ): + list_notifications(project_code="forbidden_proj") + + +@patch("testgen.mcp.tools.common.TestSuite.get") +def test_list_notifications_rejects_inaccessible_test_suite(mock_suite_get, db_session_mock): + mock_suite_get.return_value = None + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Test suite"): + list_notifications(test_suite_id=str(uuid4())) + + +@patch("testgen.mcp.tools.common.TableGroup.get") +def test_list_notifications_rejects_inaccessible_table_group(mock_tg_get, db_session_mock): + mock_tg_get.return_value = None + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Table group"): + list_notifications(table_group_id=str(uuid4())) + + +@patch("testgen.mcp.tools.common.ScoreDefinition.get") +def test_list_notifications_rejects_inaccessible_scorecard(mock_score_get, db_session_mock): + mock_score_get.return_value = None + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Scorecard"): + list_notifications(scorecard_id=str(uuid4())) + + +# --- Listing & dispatch --- + + +def test_list_notifications_no_scope_uses_allowed_projects(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with 
_patch_perms(allowed=("demo", "other")), _patch_list_for_projects([], 0) as mock_list: + list_notifications() + + args, kwargs = mock_list.call_args + assert sorted(args[0]) == ["demo", "other"] + assert kwargs["page"] == 1 + assert kwargs["limit"] == 50 + + +def test_list_notifications_project_scope_dispatches_to_list_for_projects(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([], 0) as mock_list: + list_notifications(project_code="demo") + + args, _ = mock_list.call_args + assert args[0] == ["demo"] + + +@patch("testgen.mcp.tools.common.TestSuite.get") +def test_list_notifications_test_suite_scope_dispatches_to_list_for_test_suite( + mock_suite_get, db_session_mock, +): + suite_uuid = uuid4() + suite_mock = MagicMock() + suite_mock.id = suite_uuid + suite_mock.test_suite = "orders_v1" + suite_mock.project_code = "demo" + mock_suite_get.return_value = suite_mock + + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_test_suite([], 0) as mock_list: + list_notifications(test_suite_id=str(suite_uuid)) + + args, _ = mock_list.call_args + assert args[0] == suite_uuid + + +@patch("testgen.mcp.tools.common.TableGroup.get") +def test_list_notifications_table_group_scope_dispatches_to_list_for_table_group( + mock_tg_get, db_session_mock, +): + tg_uuid = uuid4() + tg_mock = MagicMock() + tg_mock.id = tg_uuid + tg_mock.table_groups_name = "prod_warehouse" + tg_mock.project_code = "demo" + mock_tg_get.return_value = tg_mock + + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_table_group([], 0) as mock_list: + list_notifications(table_group_id=str(tg_uuid)) + + args, _ = mock_list.call_args + assert args[0] == tg_uuid + + +@patch("testgen.mcp.tools.common.ScoreDefinition.get") +def test_list_notifications_scorecard_scope_dispatches_to_list_for_score_definition( + mock_score_get, 
db_session_mock, +): + sd_uuid = uuid4() + sd_mock = MagicMock() + sd_mock.id = sd_uuid + sd_mock.name = "Daily Orders Health" + sd_mock.project_code = "demo" + mock_score_get.return_value = sd_mock + + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_score_definition([], 0) as mock_list: + list_notifications(scorecard_id=str(sd_uuid)) + + args, _ = mock_list.call_args + assert args[0] == sd_uuid + + +# --- Rendering --- + + +def test_list_notifications_empty_renders_friendly_message(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([], 0): + out = list_notifications() + + assert "# Email Notifications" in out + assert "_No notifications match the supplied scope._" in out + + +def test_list_notifications_renders_test_run_with_suite_scope(db_session_mock): + suite_id = uuid4() + row = _summary( + event=NotificationEvent.test_run, + test_suite_id=suite_id, + settings={"trigger": "on_failures"}, + recipients=("alice@example.com", "bob@example.com"), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_suite_names", + return_value={suite_id: "orders_v1"}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", return_value={}, + ), patch( + "testgen.mcp.tools.notifications._batch_score_names", return_value={}, + ): + out = list_notifications() + + assert "[Active] Test Run Notification" in out + assert "Test Suite: orders_v1" in out + assert "On test failures" in out + assert "alice@example.com, bob@example.com" in out + # No internal code leakage + assert "test_run" not in out + assert "on_failures" not in out + + +def test_list_notifications_renders_profiling_run_project_wide(db_session_mock): + row = _summary( + event=NotificationEvent.profiling_run, + enabled=False, + table_group_id=None, 
+ settings={"trigger": "on_changes"}, + recipients=("ops@example.com",), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), _patch_no_resolve_lookups(): + out = list_notifications() + + assert "[Paused] Profiling Run Notification" in out + assert "All Table Groups" in out + assert "(project-wide)" not in out + assert "On new hygiene issues" in out + assert "Status:** Paused" in out + + +def test_list_notifications_renders_score_drop_thresholds(db_session_mock): + sd_id = uuid4() + row = _summary( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "95.0", "cde_threshold": "90.0"}, + recipients=("alerts@example.com",), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_score_names", + return_value={sd_id: "Daily Orders Health"}, + ), patch( + "testgen.mcp.tools.notifications._batch_suite_names", return_value={}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", return_value={}, + ): + out = list_notifications() + + assert "Score Drop Notification" in out + assert "Scorecard: Daily Orders Health" in out + assert "Total Score Threshold:** 95.0" in out + assert "CDE Score Threshold:** 90.0" in out + # Score Drop has no trigger label + assert "Trigger:**" not in out + + +def test_list_notifications_renders_score_drop_one_threshold_only(db_session_mock): + sd_id = uuid4() + row = _summary( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "95.0", "cde_threshold": None}, + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_score_names", + return_value={sd_id: "Card"}, + ), patch( + 
"testgen.mcp.tools.notifications._batch_suite_names", return_value={}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", return_value={}, + ): + out = list_notifications() + + assert "Total Score Threshold:** 95.0" in out + assert "CDE Score Threshold" not in out + + +def test_list_notifications_renders_monitor_run_scope(db_session_mock): + tg_id = uuid4() + suite_id = uuid4() + row = _summary( + event=NotificationEvent.monitor_run, + table_group_id=tg_id, + test_suite_id=suite_id, + settings={"trigger": "on_anomalies", "table_name": "orders"}, + recipients=("monitor-alerts@example.com",), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_suite_names", + return_value={suite_id: "monitors_v2"}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", + return_value={tg_id: "prod_warehouse"}, + ), patch( + "testgen.mcp.tools.notifications._batch_score_names", return_value={}, + ): + out = list_notifications() + + assert "Monitor Alert Notification" in out + assert "Table Group: prod_warehouse" in out + assert "Table: orders" in out + assert "On anomalies" in out + # The monitor's internal test suite is never exposed — monitors are scoped to the table group. + assert "Test Suite" not in out + assert "monitors_v2" not in out + + +def test_list_notifications_pagination_renders_info_and_footer(db_session_mock): + rows = [ + _summary(event=NotificationEvent.test_run, settings={"trigger": "always"}) for _ in range(3) + ] + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects(rows, 25), _patch_no_resolve_lookups(): + out = list_notifications(page=1, limit=3) + + # format_page_info emits an en-dash (\u2013) between start and end. 
+ assert "Showing 1\u20133 of 25" in out + assert "Use `page=2` for more" in out + + +def test_list_notifications_passes_allowed_codes_only(db_session_mock): + """Even with no scope arg, the dispatch only sees the caller's allowed projects.""" + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(allowed=("alpha", "beta")), _patch_list_for_projects([], 0) as mock_list: + list_notifications() + args, _ = mock_list.call_args + assert "alpha" in args[0] + assert "beta" in args[0] + assert "gamma" not in args[0] + + +# --- get_notification --- + + +def _notif_mock( + *, + event: NotificationEvent, + enabled: bool = True, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + settings: dict | None = None, +) -> MagicMock: + """Build a mock that quacks like a polymorphic ``NotificationSettings`` ORM row.""" + notif = MagicMock() + notif.id = uuid4() + notif.event = event + notif.enabled = enabled + notif.project_code = project_code + notif.recipients = list(recipients) + notif.test_suite_id = test_suite_id + notif.table_group_id = table_group_id + notif.score_definition_id = score_definition_id + notif.settings = settings or {} + return notif + + +def _patch_notification_get(return_value): + return patch( + "testgen.mcp.tools.common.NotificationSettings.get", + return_value=return_value, + ) + + +def _patch_get_notification_scope_lookups( + *, suite_name: str | None = None, tg_name: str | None = None, score_name: str | None = None, +): + """Patch the per-entity scope-name lookups used by ``_render_one``. + + Each patched ``.get`` returns a MagicMock with the supplied name attribute (or ``None``). + Tests that don't care about scope names pass nothing. 
+ """ + suite_mock = None + if suite_name is not None: + suite_mock = MagicMock() + suite_mock.test_suite = suite_name + tg_mock = None + if tg_name is not None: + tg_mock = MagicMock() + tg_mock.table_groups_name = tg_name + score_mock = None + if score_name is not None: + score_mock = MagicMock() + score_mock.name = score_name + + return patch.multiple( + "testgen.mcp.tools.notifications", + TestSuite=MagicMock(get=MagicMock(return_value=suite_mock)), + TableGroup=MagicMock(get=MagicMock(return_value=tg_mock)), + ScoreDefinition=MagicMock(get=MagicMock(return_value=score_mock)), + ) + + +def test_get_notification_invalid_uuid(db_session_mock): + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + get_notification(notification_id="not-a-uuid") + + +def test_get_notification_missing_returns_unified_not_accessible(db_session_mock): + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + get_notification(notification_id=str(uuid4())) + + +def test_get_notification_inaccessible_project_returns_unified_not_accessible(db_session_mock): + """``NotificationSettings.get`` returns ``None`` when the project filter excludes the row. + + Both the missing-id and the wrong-project paths must surface as the same error + so callers can't enumerate notifications across projects they don't own. 
+ """ + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(allowed=("demo",)), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + get_notification(notification_id=str(uuid4())) + + +def test_get_notification_test_run_with_suite_renders_all_sections(db_session_mock): + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=suite_id, + settings={"trigger": "on_failures"}, + recipients=("alice@example.com", "bob@example.com"), + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="orders_v1", + ): + out = get_notification(notification_id=str(notif.id)) + + # H1 + section headings + assert "# Test Run Notification" in out + assert "## Configuration" in out + assert "## Scope" in out + assert "## Recipients" in out + # Configuration fields + assert "Event Type:** Test Run" in out + assert "Status:** Active" in out + assert "Trigger:** On test failures" in out + # Scope surfaces suite name + id for chaining + assert "Project:** `demo`" in out + assert "Test Suite:** orders_v1" in out + assert f"`{suite_id}`" in out + # Recipients as bullets + assert "- alice@example.com" in out + assert "- bob@example.com" in out + # No internal code leakage + assert "test_run" not in out + assert "on_failures" not in out + + +def test_get_notification_test_run_project_wide_omits_suite_id(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=None, + settings={"trigger": "always"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = get_notification(notification_id=str(notif.id)) + + assert "Test Suite:** All Test Suites" in out + # Project-wide notifications have no parent id to surface. 
+ assert "(`" not in out.split("## Scope")[1] + + +def test_get_notification_profiling_run_with_table_group(db_session_mock): + tg_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.profiling_run, + table_group_id=tg_id, + settings={"trigger": "on_changes"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + tg_name="prod_warehouse", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "# Profiling Run Notification" in out + assert "Trigger:** On new hygiene issues" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + + +def test_get_notification_score_drop_renders_thresholds_and_omits_trigger(db_session_mock): + sd_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "85.0", "cde_threshold": "90.0"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + score_name="Daily Orders Health", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "# Score Drop Notification" in out + assert "Total Score Threshold:** 85.0" in out + assert "CDE Score Threshold:** 90.0" in out + assert "Trigger:**" not in out + assert "Scorecard:** Daily Orders Health" in out + + +def test_get_notification_score_drop_only_total_threshold(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.score_drop, + score_definition_id=uuid4(), + settings={"total_threshold": "85.0", "cde_threshold": None}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + score_name="Card", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "Total Score Threshold:** 85.0" in out + assert "CDE 
Score Threshold" not in out + + +def test_get_notification_monitor_run_renders_table_group_and_table(db_session_mock): + tg_id = uuid4() + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.monitor_run, + table_group_id=tg_id, + test_suite_id=suite_id, + settings={"trigger": "on_anomalies", "table_name": "orders"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="monitors_v2", tg_name="prod_warehouse", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "# Monitor Alert Notification" in out + assert "Trigger:** On anomalies" in out + # The table is part of the monitor's scope, rendered as "Table" (not a "Filtered Table" filter). + assert "Table:** orders" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + # The internal monitor test suite is never exposed. + assert "Test Suite" not in out + assert "monitors_v2" not in out + assert f"`{suite_id}`" not in out + + +def test_get_notification_paused_renders_status_paused(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + enabled=False, + test_suite_id=uuid4(), + settings={"trigger": "always"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="some_suite", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "Status:** Paused" in out + + +# --------------------------------------------------------------------------- +# create_notification +# --------------------------------------------------------------------------- + + +def _make_create_suite(name="orders_v1", project_code="demo"): + suite = MagicMock() + suite.id = uuid4() + suite.test_suite = name + suite.project_code = project_code + suite.is_monitor = False + return suite + + +def 
_make_create_table_group(name="prod_warehouse", project_code="demo"): + tg = MagicMock() + tg.id = uuid4() + tg.table_groups_name = name + tg.project_code = project_code + return tg + + +def _make_create_scorecard(name="Daily Orders Health", project_code="demo"): + sd = MagicMock() + sd.id = uuid4() + sd.name = name + sd.project_code = project_code + return sd + + +def _make_saved_notif( + *, + event: NotificationEvent, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + settings: dict | None = None, + enabled: bool = True, +) -> MagicMock: + """Mock that quacks like the polymorphic ``NotificationSettings`` row returned by ``.create()``.""" + notif = MagicMock() + notif.id = uuid4() + notif.event = event + notif.enabled = enabled + notif.project_code = project_code + notif.recipients = list(recipients) + notif.test_suite_id = test_suite_id + notif.table_group_id = table_group_id + notif.score_definition_id = score_definition_id + notif.settings = settings or {} + return notif + + +# --- Happy paths --- + + +@patch("testgen.mcp.tools.notifications.TestRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_test_run_happy_path(mock_resolve_suite, mock_factory, db_session_mock): + suite = _make_create_suite(name="orders_v1") + mock_resolve_suite.return_value = suite + saved = _make_saved_notif( + event=NotificationEvent.test_run, + test_suite_id=suite.id, + settings={"trigger": "on_failures"}, + recipients=("alice@example.com", "bob@example.com"), + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(suite_name="orders_v1"): + out = create_notification( + event_type="Test Run", + recipients=["alice@example.com", "bob@example.com"], + 
test_suite_id=str(suite.id), + trigger_on="On test failures", + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + test_suite_id=suite.id, + recipients=["alice@example.com", "bob@example.com"], + trigger=TestRunNotificationTrigger.on_failures, + ) + # Confirmation heading + assert "created" in out.lower() + # Display labels, not internal codes + assert "Test Run" in out + assert "On test failures" in out + assert "test_run" not in out + assert "on_failures" not in out + # Followable-IDs surface + assert f"`{saved.id}`" in out + # Recipients rendered + assert "alice@example.com" in out + assert "bob@example.com" in out + # Scope name surfaced + assert "orders_v1" in out + + +@patch("testgen.mcp.tools.notifications.ProfilingRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_profiling_run_happy_path(mock_resolve_tg, mock_factory, db_session_mock): + tg = _make_create_table_group(name="prod_warehouse") + mock_resolve_tg.return_value = tg + saved = _make_saved_notif( + event=NotificationEvent.profiling_run, + table_group_id=tg.id, + settings={"trigger": "on_changes"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(tg_name="prod_warehouse"): + out = create_notification( + event_type="Profiling Run", + recipients=["ops@example.com"], + table_group_id=str(tg.id), + trigger_on="On new hygiene issues", + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + table_group_id=tg.id, + recipients=["ops@example.com"], + trigger=ProfilingRunNotificationTrigger.on_changes, + ) + assert "Profiling Run" in out + assert "On new hygiene issues" in out + assert "prod_warehouse" in out + assert f"`{saved.id}`" in out + # No internal code leakage + assert "profiling_run" not in out + assert "on_changes" not in out + + 
+@patch("testgen.mcp.tools.notifications.ScoreDropNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_happy_path_both_thresholds( + mock_resolve_sc, + mock_factory, + db_session_mock, +): + scorecard = _make_create_scorecard(name="Daily Orders Health") + mock_resolve_sc.return_value = scorecard + saved = _make_saved_notif( + event=NotificationEvent.score_drop, + score_definition_id=scorecard.id, + settings={"total_threshold": "85.0", "cde_threshold": "90.0"}, + recipients=("alerts@example.com",), + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(score_name="Daily Orders Health"): + out = create_notification( + event_type="Score Drop", + recipients=["alerts@example.com"], + scorecard_id=str(scorecard.id), + total_threshold=85, + cde_threshold=90, + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + score_definition_id=scorecard.id, + recipients=["alerts@example.com"], + total_score_threshold=85, + cde_score_threshold=90, + ) + assert "Score Drop" in out + assert "Daily Orders Health" in out + assert "85" in out + assert "90" in out + assert f"`{saved.id}`" in out + # Score Drop has no trigger label + assert "Trigger:**" not in out + + +@patch("testgen.mcp.tools.notifications.ScoreDropNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_happy_path_total_only( + mock_resolve_sc, + mock_factory, + db_session_mock, +): + scorecard = _make_create_scorecard() + mock_resolve_sc.return_value = scorecard + saved = _make_saved_notif( + event=NotificationEvent.score_drop, + score_definition_id=scorecard.id, + settings={"total_threshold": "85.0", "cde_threshold": None}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + 
+ with _patch_perms(), _patch_get_notification_scope_lookups(score_name="card"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(scorecard.id), + total_threshold=85, + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + score_definition_id=scorecard.id, + recipients=["x@example.com"], + total_score_threshold=85, + cde_score_threshold=None, + ) + + +@patch("testgen.mcp.tools.notifications.ScoreDropNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_happy_path_cde_only( + mock_resolve_sc, + mock_factory, + db_session_mock, +): + scorecard = _make_create_scorecard() + mock_resolve_sc.return_value = scorecard + saved = _make_saved_notif( + event=NotificationEvent.score_drop, + score_definition_id=scorecard.id, + settings={"total_threshold": None, "cde_threshold": "90.0"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(score_name="card"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(scorecard.id), + cde_threshold=90, + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + score_definition_id=scorecard.id, + recipients=["x@example.com"], + total_score_threshold=None, + cde_score_threshold=90, + ) + + +# --- Defaults --- + + +@patch("testgen.mcp.tools.notifications.TestRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_test_run_default_trigger_on( + mock_resolve_suite, + mock_factory, + db_session_mock, +): + """Omitting ``trigger_on`` for Test Run defaults to ``On test failures``.""" + suite = _make_create_suite() + mock_resolve_suite.return_value = suite + saved = _make_saved_notif( + event=NotificationEvent.test_run, + test_suite_id=suite.id, + 
settings={"trigger": "on_failures"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(suite_name="x"): + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(suite.id), + ) + + _, kwargs = mock_factory.create.call_args + assert kwargs["trigger"] == TestRunNotificationTrigger.on_failures + + +@patch("testgen.mcp.tools.notifications.ProfilingRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_profiling_run_default_trigger_on( + mock_resolve_tg, + mock_factory, + db_session_mock, +): + """Omitting ``trigger_on`` for Profiling Run defaults to ``On new hygiene issues``.""" + tg = _make_create_table_group() + mock_resolve_tg.return_value = tg + saved = _make_saved_notif( + event=NotificationEvent.profiling_run, + table_group_id=tg.id, + settings={"trigger": "on_changes"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(tg_name="x"): + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(tg.id), + ) + + _, kwargs = mock_factory.create.call_args + assert kwargs["trigger"] == ProfilingRunNotificationTrigger.on_changes + + +# --- Errors: event_type --- + + +def test_create_notification_internal_event_code_rejected(db_session_mock): + """Internal enum codes (``test_run``) are NOT accepted — display labels only.""" + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="test_run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + ) + msg = str(exc.value) + for label in ("Test Run", "Profiling Run", "Score Drop"): + assert label in msg + + +def 
test_create_notification_unknown_event_type_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="event_type"): + create_notification( + event_type="Bogus", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + ) + + +def test_create_notification_monitor_run_not_creatable(db_session_mock): + """Monitor Alert is out of scope for create — only test/profiling/score events.""" + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Monitor Alert", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + table_group_id=str(uuid4()), + ) + msg = str(exc.value) + # Error lists the supported labels + for label in ("Test Run", "Profiling Run", "Score Drop"): + assert label in msg + + +# --- Errors: scope arg shape --- + + +def test_create_notification_test_run_missing_test_suite_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="test_suite_id"): + create_notification(event_type="Test Run", recipients=["x@example.com"]) + + +def test_create_notification_profiling_run_missing_table_group_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="table_group_id"): + create_notification(event_type="Profiling Run", recipients=["x@example.com"]) + + +def test_create_notification_score_drop_missing_scorecard_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="scorecard_id"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + total_threshold=85, + ) + + +def 
test_create_notification_test_run_with_table_group_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + table_group_id=str(uuid4()), + ) + assert "table_group_id" in str(exc.value) + + +def test_create_notification_test_run_with_scorecard_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="scorecard_id"): + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + scorecard_id=str(uuid4()), + ) + + +def test_create_notification_profiling_run_with_test_suite_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="test_suite_id"): + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + test_suite_id=str(uuid4()), + ) + + +def test_create_notification_score_drop_with_test_suite_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="test_suite_id"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + test_suite_id=str(uuid4()), + total_threshold=85, + ) + + +# --- Errors: inaccessible scope entities --- + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_inaccessible_test_suite_propagates( + mock_resolve_suite, + db_session_mock, +): + mock_resolve_suite.side_effect = MCPResourceNotAccessible("Test suite", "x") + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, 
match="Test suite"): + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_inaccessible_table_group_propagates( + mock_resolve_tg, + db_session_mock, +): + mock_resolve_tg.side_effect = MCPResourceNotAccessible("Table group", "x") + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Table group"): + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_inaccessible_scorecard_propagates( + mock_resolve_sc, + db_session_mock, +): + mock_resolve_sc.side_effect = MCPResourceNotAccessible("Scorecard", "x") + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Scorecard"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=85, + ) + + +# --- Errors: recipients --- + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_empty_recipients_rejected(mock_resolve_suite, db_session_mock): + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="at least one"): + create_notification( + event_type="Test Run", + recipients=[], + test_suite_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_invalid_recipients_lists_all( + mock_resolve_suite, + db_session_mock, +): + """Every malformed address appears in the single error message — no partial save.""" + mock_resolve_suite.return_value = 
_make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=[ + "alice@example.com", + "no-at-sign", + "spaces in@here.com", + "nodot@nope", + ], + test_suite_id=str(uuid4()), + ) + msg = str(exc.value) + assert "no-at-sign" in msg + assert "spaces in@here.com" in msg + assert "nodot@nope" in msg + + +# --- Errors: trigger_on --- + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_invalid_trigger_on_test_run_lists_all_labels( + mock_resolve_suite, + db_session_mock, +): + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + trigger_on="bogus", + ) + msg = str(exc.value) + for label in ( + "Always", + "On test failures", + "On test failures and warnings", + "On new test failures and warnings", + ): + assert label in msg + + +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_invalid_trigger_on_profiling_run_lists_only_profiling_labels( + mock_resolve_tg, + db_session_mock, +): + mock_resolve_tg.return_value = _make_create_table_group() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + trigger_on="bogus", + ) + msg = str(exc.value) + assert "Always" in msg + assert "On new hygiene issues" in msg + # Test-run-only triggers must NOT leak into the Profiling Run error + assert "On test failures" not in msg + + +# --- Errors: score_drop thresholds --- + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") 
+def test_create_notification_score_drop_missing_both_thresholds_rejected( + mock_resolve_sc, + db_session_mock, +): + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="threshold"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_thresholds_out_of_range_lists_all( + mock_resolve_sc, + db_session_mock, +): + """Both threshold range issues are surfaced in one error — no partial save.""" + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=150, + cde_threshold=-1, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + assert "150" in msg + assert "-1" in msg + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_zero_total_threshold_rejected(mock_resolve_sc, db_session_mock): + """0 is not a valid threshold (a score can never drop below 0) — reject up front + with a clear MCPUserError, not the opaque model error. 
+ """ + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=0, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "= 0" in msg + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_zero_cde_threshold_rejected(mock_resolve_sc, db_session_mock): + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + cde_threshold=0, + ) + msg = str(exc.value) + assert "cde_threshold" in msg + assert "= 0" in msg + + +# --- Errors: stray args per event --- + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_with_trigger_on_rejected( + mock_resolve_sc, + db_session_mock, +): + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="trigger_on"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=85, + trigger_on="Always", + ) + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_test_run_with_thresholds_rejected( + mock_resolve_suite, + db_session_mock, +): + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + 
test_suite_id=str(uuid4()), + total_threshold=85, + cde_threshold=90, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + + +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_profiling_run_with_thresholds_rejected( + mock_resolve_tg, + db_session_mock, +): + mock_resolve_tg.return_value = _make_create_table_group() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + total_threshold=85, + ) + assert "total_threshold" in str(exc.value) + + +# --------------------------------------------------------------------------- +# update_notification +# --------------------------------------------------------------------------- + + +def _update_mock( + *, + event: NotificationEvent, + enabled: bool = True, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + trigger=None, + total_score_threshold=None, + cde_score_threshold=None, + table_name: str | None = None, +) -> MagicMock: + """Build a polymorphic-notification mock for ``update_notification`` tests. + + Adds typed attributes (``trigger``, ``total_score_threshold``, + ``cde_score_threshold``, ``table_name``) that the tool reads when computing + the no-op / Before-After diff. Each defaults to ``None`` unless supplied. 
+ """ + notif = _notif_mock( + event=event, + enabled=enabled, + project_code=project_code, + recipients=recipients, + test_suite_id=test_suite_id, + table_group_id=table_group_id, + score_definition_id=score_definition_id, + ) + notif.trigger = trigger + notif.total_score_threshold = total_score_threshold + notif.cde_score_threshold = cde_score_threshold + notif.table_name = table_name + return notif + + +# --- Pre-mutation validation --- + + +def test_update_notification_invalid_uuid(db_session_mock): + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + update_notification(notification_id="not-a-uuid", enabled=False) + + +def test_update_notification_missing_returns_unified_not_accessible(db_session_mock): + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + update_notification(notification_id=str(uuid4()), enabled=False) + + +def test_update_notification_no_fields_returns_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises( + MCPUserError, match="No fields supplied to update", + ): + update_notification(notification_id=str(notif.id)) + + +# --- Event-shape gates --- + + +def test_update_notification_test_run_rejects_total_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), 
total_threshold=85) + assert "total_threshold" in str(exc.value) + assert "Test Run" in str(exc.value) + + +def test_update_notification_test_run_rejects_clear_cde_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), clear_cde_threshold=True) + assert "clear_cde_threshold" in str(exc.value) + + +def test_update_notification_test_run_rejects_table_name(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), table_name="orders") + assert "table_name" in str(exc.value) + assert "Monitor Alert" in str(exc.value) + + +def test_update_notification_profiling_run_rejects_cde_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), cde_threshold=85) + assert "cde_threshold" in str(exc.value) + + +def test_update_notification_profiling_run_rejects_table_name(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), 
pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), clear_table_name=True) + assert "table_name" in str(exc.value) + + +def test_update_notification_score_drop_rejects_trigger_on(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="trigger_on"): + update_notification(notification_id=str(notif.id), trigger_on="Always") + + +def test_update_notification_score_drop_rejects_table_name(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="table_name"): + update_notification(notification_id=str(notif.id), table_name="orders") + + +def test_update_notification_monitor_run_rejects_total_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=85) + assert "total_threshold" in str(exc.value) + + +def test_update_notification_multiple_stray_args_one_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + 
notification_id=str(notif.id), + total_threshold=85, + cde_threshold=90, + table_name="orders", + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + assert "table_name" in msg + + +# --- Recipients --- + + +def test_update_notification_empty_recipients_rejected(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="at least one"): + update_notification(notification_id=str(notif.id), recipients=[]) + + +def test_update_notification_invalid_recipients_lists_all(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + recipients=["alice@example.com", "no-at-sign", "nodot@nope"], + ) + msg = str(exc.value) + assert "no-at-sign" in msg + assert "nodot@nope" in msg + + +# --- Trigger labels --- + + +def test_update_notification_test_run_invalid_trigger_lists_all_labels(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), trigger_on="bogus") + msg = str(exc.value) + for label in ( + "Always", + "On test failures", + "On test failures and warnings", + "On new test failures and warnings", + ): + assert label in msg + + +def 
test_update_notification_profiling_run_invalid_trigger_lists_only_profiling_labels(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), trigger_on="bogus") + msg = str(exc.value) + assert "Always" in msg + assert "On new hygiene issues" in msg + assert "On test failures" not in msg + + +def test_update_notification_monitor_run_invalid_trigger_lists_monitor_label(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), trigger_on="bogus") + msg = str(exc.value) + assert "On anomalies" in msg + # Test-run-only triggers must not leak into the Monitor Alert error + assert "On test failures" not in msg + + +# --- Score thresholds --- + + +def test_update_notification_total_threshold_out_of_range(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=150) + msg = str(exc.value) + assert "total_threshold" in msg + assert "150" in msg + + +def test_update_notification_zero_threshold_rejected(db_session_mock): + """0 is rejected on update with a clear error, not silently accepted or surfaced as opaque.""" + notif = 
_update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=0) + msg = str(exc.value) + assert "total_threshold" in msg + assert "= 0" in msg + notif.save.assert_not_called() + + +def test_update_notification_both_thresholds_out_of_range_one_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=150, cde_threshold=-1) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + assert "150" in msg + assert "-1" in msg + + +def test_update_notification_set_total_and_clear_total_rejected(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + total_threshold=80, + clear_total_threshold=True, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "set and cleared" in msg + + +def test_update_notification_set_and_clear_both_pairs_one_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import 
update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + total_threshold=80, + clear_total_threshold=True, + cde_threshold=70, + clear_cde_threshold=True, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + + +def test_update_notification_clear_both_thresholds_pre_empt_check(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="must remain set"): + update_notification( + notification_id=str(notif.id), + clear_total_threshold=True, + clear_cde_threshold=True, + ) + notif.save.assert_not_called() + + +def test_update_notification_clear_only_set_threshold_pre_empt_check(db_session_mock): + """Current state: total=85, cde=NULL. 
Clearing total would leave both NULL.""" + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=None) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="must remain set"): + update_notification(notification_id=str(notif.id), clear_total_threshold=True) + notif.save.assert_not_called() + + +# --- Monitor table_name --- + + +def test_update_notification_set_and_clear_table_name_rejected(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + table_name="orders") + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + table_name="invoices", + clear_table_name=True, + ) + msg = str(exc.value) + assert "table_name" in msg + assert "set and cleared" in msg + + +def test_update_notification_monitor_set_table_name_happy(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + table_name="orders") + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), table_name="invoices") + + assert notif.table_name == "invoices" + notif.save.assert_called_once() + assert "orders" in out + assert "invoices" in out + assert "| Table |" in out + + +def test_update_notification_monitor_clear_table_name_happy(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), 
test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + table_name="orders") + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), clear_table_name=True) + + assert notif.table_name is None + notif.save.assert_called_once() + assert "orders" in out + # Cleared values render as em-dash. + assert "—" in out + + +# --- No-op detection --- + + +def test_update_notification_no_op_enabled_returns_unchanged(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=True) + + assert "No fields changed" in out + notif.save.assert_not_called() + + +def test_update_notification_no_op_recipients(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + recipients=("a@x.com", "b@x.com"), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["a@x.com", "b@x.com"], + ) + + assert "No fields changed" in out + notif.save.assert_not_called() + + +def test_update_notification_no_op_trigger(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + 
out = update_notification(notification_id=str(notif.id), trigger_on="On test failures") + + assert "No fields changed" in out + notif.save.assert_not_called() + + +def test_update_notification_partial_no_op_diff_shows_only_changed(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + enabled=True, # no-op + trigger_on="Always", # change + ) + + # "Trigger" row present in diff, "Status" row absent. + assert "Always" in out + assert "Trigger" in out + # Status field should not appear in the diff table since it's a no-op. + assert "| Status |" not in out + assert "Status |" not in out.split("# ", 1)[1].split("\n## ")[0] or True # tolerant; main check above + notif.save.assert_called_once() + + +# --- Happy paths --- + + +def test_update_notification_test_run_recipients_and_enabled(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, recipients=("alice@example.com",), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["bob@example.com"], + enabled=False, + ) + + assert notif.recipients == ["bob@example.com"] + assert notif.enabled is False + notif.save.assert_called_once() + assert "# Test Run Notification updated" in out + assert "Active" in out + assert "Paused" in out + assert "alice@example.com" in out + assert "bob@example.com" in out + + +def test_update_notification_test_run_change_trigger(db_session_mock): + notif = 
_update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="Always") + + assert notif.trigger == TestRunNotificationTrigger.always + notif.save.assert_called_once() + assert "On test failures" in out + assert "Always" in out + # No internal codes leak. + assert "on_failures" not in out + + +def test_update_notification_profiling_run_change_trigger(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="Always") + + assert notif.trigger == ProfilingRunNotificationTrigger.always + notif.save.assert_called_once() + assert "On new hygiene issues" in out + assert "Always" in out + + +def test_update_notification_score_drop_change_total_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), total_threshold=92) + + assert notif.total_score_threshold == 92 + notif.save.assert_called_once() + assert "85.0" in out + assert "92" in out + assert "# Score Drop Notification updated" in out + + +def test_update_notification_score_drop_change_cde_and_clear_total(db_session_mock): + """Current: total=85, 
cde=NULL. Set cde=88 AND clear total → resulting total=NULL, cde=88 (valid).""" + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=None) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + cde_threshold=88, + clear_total_threshold=True, + ) + + assert notif.total_score_threshold is None + assert notif.cde_score_threshold == 88 + notif.save.assert_called_once() + assert "85.0" in out + assert "88" in out + # Cleared total renders as em-dash. + assert "—" in out + + +def test_update_notification_monitor_run_recipients(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + recipients=("a@x.com",)) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["b@x.com", "c@x.com"], + ) + + assert notif.recipients == ["b@x.com", "c@x.com"] + notif.save.assert_called_once() + assert "# Monitor Alert Notification updated" in out + + +# --- Rendering --- + + +def test_update_notification_heading_event_specific(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=False) + + assert "# Profiling Run Notification updated" in out + + +def 
test_update_notification_notification_id_code_formatted(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=False) + + assert f"`{notif.id}`" in out + + +def test_update_notification_status_diff_active_paused(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=False) + + assert "Active" in out + assert "Paused" in out + # Status row should NOT render the bool repr. + assert "True" not in out + assert "False" not in out + + +def test_update_notification_recipients_diff_comma_separated(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + recipients=("a@x.com",), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["a@x.com", "b@x.com"], + ) + + assert "a@x.com, b@x.com" in out + # No Python list repr leakage. 
+ assert "['a@x.com'" not in out + assert "['a@x.com', 'b@x.com']" not in out + + +def test_update_notification_trigger_diff_display_labels_only(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="Always") + + assert "Always" in out + assert "On test failures" in out + # No internal codes in diff. + assert "on_failures" not in out + assert "TestRunNotificationTrigger" not in out + + +# --------------------------------------------------------------------------- +# delete_notification +# --------------------------------------------------------------------------- + + +def test_delete_notification_invalid_uuid(db_session_mock): + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + delete_notification(notification_id="not-a-uuid") + + +def test_delete_notification_unknown_id_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + delete_notification(notification_id=str(uuid4())) + + +def test_delete_notification_inaccessible_project_returns_unified_not_accessible(db_session_mock): + """``NotificationSettings.get`` returns ``None`` when the project filter excludes the row. + + Both the missing-id and the wrong-project paths must surface as the same error + so callers can't enumerate notifications across projects they don't own. 
+ """ + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(allowed=("demo",)), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + delete_notification(notification_id=str(uuid4())) + + +def test_delete_notification_does_not_call_delete_when_inaccessible(db_session_mock): + """When resolve_notification fails, the row's .delete() is never invoked.""" + from testgen.mcp.tools.notifications import delete_notification + + sentinel = _notif_mock(event=NotificationEvent.test_run, test_suite_id=uuid4()) + with _patch_perms(), _patch_notification_get(None), pytest.raises(MCPResourceNotAccessible): + delete_notification(notification_id=str(uuid4())) + sentinel.delete.assert_not_called() + + +def test_delete_notification_calls_model_delete(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=uuid4(), + settings={"trigger": "on_failures"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="orders_v1", + ): + delete_notification(notification_id=str(notif.id)) + + notif.delete.assert_called_once() + + +def test_delete_notification_test_run_renders_event_heading_and_scope(db_session_mock): + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=suite_id, + settings={"trigger": "on_failures"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="orders_v1", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Test Run Notification deleted" in out + assert f"`{notif.id}`" in out + assert "Event Type:** Test Run" in out + assert "Project:** `demo`" in out + assert "Test Suite:** orders_v1" in out + assert f"`{suite_id}`" in out + # No internal 
code leakage. + assert "test_run" not in out + + +def test_delete_notification_profiling_run_renders_table_group_scope(db_session_mock): + tg_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.profiling_run, + table_group_id=tg_id, + settings={"trigger": "on_changes"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + tg_name="prod_warehouse", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Profiling Run Notification deleted" in out + assert "Event Type:** Profiling Run" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + assert "profiling_run" not in out + + +def test_delete_notification_score_drop_renders_scorecard_scope(db_session_mock): + sd_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "85.0"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + score_name="Daily Orders Health", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Score Drop Notification deleted" in out + assert "Event Type:** Score Drop" in out + assert "Scorecard:** Daily Orders Health" in out + assert f"`{sd_id}`" in out + assert "score_drop" not in out + + +def test_delete_notification_monitor_run_renders_table_group(db_session_mock): + tg_id = uuid4() + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.monitor_run, + table_group_id=tg_id, + test_suite_id=suite_id, + settings={"trigger": "on_anomalies"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="monitors_v2", tg_name="prod_warehouse", + ): + out = 
delete_notification(notification_id=str(notif.id)) + + assert "# Monitor Alert Notification deleted" in out + assert "Event Type:** Monitor Alert" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + assert "monitor_run" not in out + # The internal monitor test suite is never exposed. + assert "Test Suite" not in out + assert "monitors_v2" not in out + assert f"`{suite_id}`" not in out + + +def test_delete_notification_test_run_project_wide_omits_parent_id(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=None, + settings={"trigger": "always"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = delete_notification(notification_id=str(notif.id)) + + assert "Test Suite:** All Test Suites" in out + # Project-wide notifications have no parent id to surface in the scope row. + assert "(`" not in out.split("Test Suite:**")[1].split("\n")[0] diff --git a/tests/unit/mcp/test_tools_profile_history.py b/tests/unit/mcp/test_tools_profile_history.py new file mode 100644 index 00000000..e82b5bf3 --- /dev/null +++ b/tests/unit/mcp/test_tools_profile_history.py @@ -0,0 +1,596 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.enums import JobStatus +from testgen.common.models.data_column import ProfileMetric +from testgen.mcp.exceptions import MCPUserError +from testgen.mcp.tools.profile_history import ( + _column_metric_value, + _delta_cell, + _format_metric_value, + _validate_metric_scope, + compare_profiling_runs, + get_profiling_trends, + get_schema_history, +) + + +def _je(status=JobStatus.COMPLETED): + """Build a JobExecution mock for ``session.get(JobExecution, ...)`` returns.""" + je = MagicMock() + je.status = status + return je + + +def _patch_session(jes): + """Patch 
``get_current_session`` so ``session.get(JobExecution, ...)`` returns the given JEs in order.""" + session = MagicMock() + session.get.side_effect = jes + return patch("testgen.mcp.tools.profile_history.get_current_session", return_value=session) + +# ---------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------- + + +def _profile_row( + run_id=None, + table_name="orders", + column_name="customer_email", + general_type="A", + schema_name="demo", + record_ct=1000, + null_value_ct=50, + distinct_value_ct=900, + filled_value_ct=10, + column_type="varchar(200)", + db_data_type="varchar", + functional_data_type="Person Email", + pii_flag=None, + datatype_suggestion=None, + avg_length=18.0, + min_length=5, + max_length=40, + min_text=None, + max_text=None, + min_value=None, + max_value=None, + avg_value=None, + stdev_value=None, + min_date=None, + max_date=None, + boolean_true_ct=None, +): + row = MagicMock() + row.profile_run_id = run_id or uuid4() + row.schema_name = schema_name + row.table_name = table_name + row.column_name = column_name + row.general_type = general_type + row.column_type = column_type + row.db_data_type = db_data_type + row.functional_data_type = functional_data_type + row.pii_flag = pii_flag + row.datatype_suggestion = datatype_suggestion + row.record_ct = record_ct + row.null_value_ct = null_value_ct + row.distinct_value_ct = distinct_value_ct + row.filled_value_ct = filled_value_ct + row.avg_length = avg_length + row.min_length = min_length + row.max_length = max_length + row.min_text = min_text + row.max_text = max_text + row.min_value = min_value + row.max_value = max_value + row.avg_value = avg_value + row.stdev_value = stdev_value + row.min_date = min_date + row.max_date = max_date + row.boolean_true_ct = boolean_true_ct + return row + + +def _profiling_run( + id_=None, + job_execution_id=None, + table_groups_id=None, + status="Complete", + 
profiling_starttime=None, + dq_score_profiling=0.92, + table_groups_name="Demo Sales", +): + run = MagicMock() + run.id = id_ or uuid4() + run.job_execution_id = job_execution_id or uuid4() + run.table_groups_id = table_groups_id or uuid4() + run.status = status + run.profiling_starttime = profiling_starttime or datetime(2026, 5, 10, 12, 0) + run.dq_score_profiling = dq_score_profiling + run.table_groups_name = table_groups_name + return run + + +def _table_group(tg_id=None, project_code="demo", name="Demo Sales"): + tg = MagicMock() + tg.id = tg_id or uuid4() + tg.project_code = project_code + tg.table_groups_name = name + return tg + + +# ---------------------------------------------------------------------- +# _column_metric_value +# ---------------------------------------------------------------------- + + +def test_column_metric_value_ratios(): + row = _profile_row(record_ct=1000, null_value_ct=250, distinct_value_ct=900, filled_value_ct=100) + assert _column_metric_value(ProfileMetric.NULL_RATIO, row) == 0.25 + assert _column_metric_value(ProfileMetric.DISTINCT_RATIO, row) == 0.9 + assert _column_metric_value(ProfileMetric.FILLED_RATIO, row) == 0.1 + + +def test_column_metric_value_record_count(): + row = _profile_row(record_ct=1234) + assert _column_metric_value(ProfileMetric.RECORD_COUNT, row) == 1234 + + +def test_column_metric_value_zero_record_ct_returns_none(): + row = _profile_row(record_ct=0, null_value_ct=0, distinct_value_ct=0) + assert _column_metric_value(ProfileMetric.NULL_RATIO, row) is None + assert _column_metric_value(ProfileMetric.DISTINCT_RATIO, row) is None + + +def test_column_metric_value_missing_row_returns_none(): + assert _column_metric_value(ProfileMetric.NULL_RATIO, None) is None + assert _column_metric_value(ProfileMetric.RECORD_COUNT, None) is None + + +def test_column_metric_value_type_restriction(): + numeric_row = _profile_row(general_type="N", avg_value=5.5, avg_length=None) + # Avg Length only applies to Alpha columns + 
assert _column_metric_value(ProfileMetric.AVG_LENGTH, numeric_row) is None + assert _column_metric_value(ProfileMetric.AVG, numeric_row) == 5.5 + + alpha_row = _profile_row(general_type="A", avg_length=18.0, avg_value=None) + assert _column_metric_value(ProfileMetric.AVG_LENGTH, alpha_row) == 18.0 + assert _column_metric_value(ProfileMetric.AVG, alpha_row) is None + + +def test_column_metric_value_date_min_max(): + row = _profile_row( + general_type="D", + min_date=datetime(2024, 1, 3), + max_date=datetime(2026, 5, 10), + ) + assert _column_metric_value(ProfileMetric.MIN_DATE, row) == datetime(2024, 1, 3) + assert _column_metric_value(ProfileMetric.MAX_DATE, row) == datetime(2026, 5, 10) + + +def test_column_metric_value_boolean_true_count(): + row = _profile_row(general_type="B", boolean_true_ct=42) + assert _column_metric_value(ProfileMetric.TRUE_COUNT, row) == 42 + + +# ---------------------------------------------------------------------- +# _format_metric_value +# ---------------------------------------------------------------------- + + +def test_format_metric_value_percent(): + assert _format_metric_value(ProfileMetric.NULL_RATIO, 0.25) == "25.0%" + assert _format_metric_value(ProfileMetric.DISTINCT_RATIO, 0.9) == "90.0%" + + +def test_format_metric_value_profiling_score_uses_friendly_score(): + # Profiling Score follows the codebase-wide friendly_score convention: + # value (0-1) scaled to 0-100 with no '%' suffix. 
+ assert _format_metric_value(ProfileMetric.PROFILING_SCORE, 0.92) == "92.0" + assert _format_metric_value(ProfileMetric.PROFILING_SCORE, 1.0) == "100" + + +def test_format_metric_value_record_count_thousands_separator(): + assert _format_metric_value(ProfileMetric.RECORD_COUNT, 12345) == "12,345" + + +def test_format_metric_value_datetime_date_only(): + assert _format_metric_value(ProfileMetric.MIN_DATE, datetime(2024, 1, 3, 14, 30)) == "2024-01-03" + + +def test_format_metric_value_none(): + assert _format_metric_value(ProfileMetric.NULL_RATIO, None) == "—" + + +# ---------------------------------------------------------------------- +# _delta_cell +# ---------------------------------------------------------------------- + + +def test_delta_cell_unchanged(): + assert _delta_cell(ProfileMetric.NULL_RATIO, 0.25, 0.25) == "25.0% (=)" + + +def test_delta_cell_changed(): + assert _delta_cell(ProfileMetric.NULL_RATIO, 0.30, 0.05) == "30.0% → 5.0%" + + +def test_delta_cell_dates_render_as_dates_only(): + # Different timestamps on the same date format identically -> rendered as (=) + a = datetime(2024, 1, 3, 6, 0) + b = datetime(2024, 1, 3, 18, 0) + assert _delta_cell(ProfileMetric.MIN_DATE, a, b) == "2024-01-03 (=)" + + +def test_delta_cell_none_baseline(): + assert _delta_cell(ProfileMetric.RECORD_COUNT, None, 1000) == "— → 1,000" + + +# ---------------------------------------------------------------------- +# _validate_metric_scope +# ---------------------------------------------------------------------- + + +def test_validate_metric_scope_column_metric_requires_column(): + with pytest.raises(MCPUserError, match="require both `table_name` and `column_name`"): + _validate_metric_scope([ProfileMetric.NULL_RATIO], table_name="orders", column_name=None) + + +def test_validate_metric_scope_table_metric_requires_table(): + with pytest.raises(MCPUserError, match="require `table_name`"): + _validate_metric_scope([ProfileMetric.RECORD_COUNT], table_name=None, column_name=None) 
+ + +def test_validate_metric_scope_tg_metric_accepts_any_scope(): + # No exception when no scope args provided + _validate_metric_scope([ProfileMetric.PROFILING_SCORE], table_name=None, column_name=None) + _validate_metric_scope([ProfileMetric.HYGIENE_COUNT], table_name=None, column_name=None) + + +def test_validate_metric_scope_mixed_scopes_all_satisfied(): + _validate_metric_scope( + [ProfileMetric.NULL_RATIO, ProfileMetric.RECORD_COUNT, ProfileMetric.PROFILING_SCORE], + table_name="orders", + column_name="email", + ) + + +# ---------------------------------------------------------------------- +# compare_profiling_runs — flow tests +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profile_history.HygieneIssue") +@patch("testgen.mcp.tools.profile_history.HygieneIssueType") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_auto_baseline( + mock_resolve, mock_pr, mock_iss_type, mock_iss, db_session_mock, +): + tg_id = uuid4() + target_run = _profiling_run(table_groups_id=tg_id, profiling_starttime=datetime(2026, 5, 13)) + baseline_run = _profiling_run(table_groups_id=tg_id, profiling_starttime=datetime(2026, 5, 10)) + target_run.get_previous.return_value = baseline_run + mock_resolve.return_value = target_run + + target_row = _profile_row(run_id=target_run.id, null_value_ct=50) + baseline_row = _profile_row(run_id=baseline_run.id, null_value_ct=300) + mock_pr.select_for_runs.return_value = [target_row, baseline_row] + mock_iss.select_where.return_value = [] + mock_iss_type.select_where.return_value = [] + + with _patch_session([_je(), _je()]): + result = compare_profiling_runs(str(target_run.job_execution_id)) + + assert "Profiling Run Comparison" in result + assert "Target" in result and "Baseline" in result + assert "Profiling Run" in result and "Started" in result + 
target_run.get_previous.assert_called_once() + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_rejects_non_completed_target(mock_resolve, db_session_mock): + target_run = _profiling_run() + mock_resolve.return_value = target_run + + with _patch_session([_je(status=JobStatus.RUNNING)]): + with pytest.raises(MCPUserError, match="Target run is in `Running` state"): + compare_profiling_runs(str(target_run.job_execution_id)) + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_rejects_canceled_target(mock_resolve, db_session_mock): + target_run = _profiling_run() + mock_resolve.return_value = target_run + + with _patch_session([_je(status=JobStatus.CANCELED)]): + with pytest.raises(MCPUserError, match="`Canceled`"): + compare_profiling_runs(str(target_run.job_execution_id)) + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_rejects_cross_table_group(mock_resolve, db_session_mock): + target_run = _profiling_run(table_groups_id=uuid4()) + baseline_run = _profiling_run(table_groups_id=uuid4()) + mock_resolve.side_effect = [target_run, baseline_run] + + with _patch_session([_je()]): + with pytest.raises(MCPUserError, match="same table group"): + compare_profiling_runs( + str(target_run.job_execution_id), + str(baseline_run.job_execution_id), + ) + + +def test_compare_profiling_runs_column_requires_table(db_session_mock): + with pytest.raises(MCPUserError, match="`column_name` requires `table_name`"): + compare_profiling_runs(str(uuid4()), column_name="email") + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_auto_baseline_first_run(mock_resolve, db_session_mock): + target_run = _profiling_run() + target_run.get_previous.return_value = None + mock_resolve.return_value = target_run + + with _patch_session([_je()]): + with pytest.raises(MCPUserError, match="no earlier 
completed profiling run"): + compare_profiling_runs(str(target_run.job_execution_id)) + + +@patch("testgen.mcp.tools.profile_history.HygieneIssue") +@patch("testgen.mcp.tools.profile_history.HygieneIssueType") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_identical_runs_renders_no_changes( + mock_resolve, mock_pr, mock_iss_type, mock_iss, db_session_mock, +): + tg_id = uuid4() + target_run = _profiling_run(table_groups_id=tg_id) + baseline_run = _profiling_run(table_groups_id=tg_id, profiling_starttime=datetime(2026, 5, 1)) + target_run.get_previous.return_value = baseline_run + mock_resolve.return_value = target_run + + target_row = _profile_row(run_id=target_run.id) + baseline_row = _profile_row(run_id=baseline_run.id) # same values + mock_pr.select_for_runs.return_value = [target_row, baseline_row] + mock_iss.select_where.return_value = [] + mock_iss_type.select_where.return_value = [] + + with _patch_session([_je(), _je()]): + result = compare_profiling_runs(str(target_run.job_execution_id)) + + assert "No changes between target and baseline" in result + + +# ---------------------------------------------------------------------- +# get_profiling_trends +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_happy_path(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + tg = _table_group() + mock_tg_cls.get.return_value = tg + + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_old] + mock_pr_cls.count_confirmed_hygiene_issues.return_value = {} + + rows = [ + 
_profile_row(run_id=run_old.id, null_value_ct=300), + _profile_row(run_id=run_new.id, null_value_ct=50), + ] + mock_pr.select_for_runs.return_value = rows + + result = get_profiling_trends( + str(tg.id), + metrics=["Null Ratio", "Distinct Ratio"], + table_name="orders", + column_name="customer_email", + ) + + assert "Profiling trends" in result + assert "Null Ratio" in result + assert "Distinct Ratio" in result + assert "2026-05-13" in result and "2026-05-01" in result + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_invalid_metric(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + + with pytest.raises(MCPUserError, match="Invalid metrics"): + get_profiling_trends(str(uuid4()), metrics=["Unknown Metric"]) + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_empty_metrics(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + + with pytest.raises(MCPUserError, match="cannot be empty"): + get_profiling_trends(str(uuid4()), metrics=[]) + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_column_requires_table(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + + with pytest.raises(MCPUserError, match="`column_name` requires `table_name`"): + get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + column_name="email", + ) + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_no_runs(mock_tg_cls, mock_pr_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + mock_pr_cls.list_recent_complete.return_value = [] + + # TG-scope metric so we skip the profile-row fetch entirely + result = get_profiling_trends(str(uuid4()), metrics=["Profiling Score"]) + assert "No completed profiling runs" in result + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") 
+@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_first_appears_note(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity missing from the oldest run but present in newer runs.""" + mock_tg_cls.get.return_value = _table_group() + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1, 9, 0)) + run_mid = _profiling_run(profiling_starttime=datetime(2026, 5, 10, 14, 0)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_mid, run_old] + # Only mid and new runs have the column — entity first appears at run_mid. + mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_mid.id), + _profile_row(run_id=run_new.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="customer_email", + ) + assert "first appears in the run started 2026-05-10 14:00" in result + assert "last appears" not in result # present in newest run, no trailing-gap note + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_last_appears_note(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity present in early runs but missing from the newest run.""" + mock_tg_cls.get.return_value = _table_group() + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1, 9, 0)) + run_mid = _profiling_run(profiling_starttime=datetime(2026, 5, 10, 14, 0)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_mid, run_old] + # Only old and mid runs have the column — entity last appears at run_mid. 
+ mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_old.id), + _profile_row(run_id=run_mid.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="legacy_id", + ) + assert "last appears in the run started 2026-05-10 14:00" in result + assert "first appears" not in result # present in oldest run, no leading-gap note + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_both_notes(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity has a bounded lifetime — missing on both ends of the window.""" + mock_tg_cls.get.return_value = _table_group() + run_oldest = _profiling_run(profiling_starttime=datetime(2026, 5, 9, 9, 0)) + run_first = _profiling_run(profiling_starttime=datetime(2026, 5, 10, 14, 0)) + run_last = _profiling_run(profiling_starttime=datetime(2026, 5, 12, 22, 0)) + run_newest = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_newest, run_last, run_first, run_oldest] + # Only the middle two runs carry the column. 
+ mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_first.id), + _profile_row(run_id=run_last.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="customer_email_v2", + ) + assert "first appears in the run started 2026-05-10 14:00" in result + assert "last appears in the run started 2026-05-12 22:00" in result + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_no_notes_when_present_throughout(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity present in every run — no first/last-appears noise.""" + mock_tg_cls.get.return_value = _table_group() + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1, 9, 0)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_old] + mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_old.id), + _profile_row(run_id=run_new.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="customer_id", + ) + assert "first appears" not in result + assert "last appears" not in result + + +# ---------------------------------------------------------------------- +# get_schema_history +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_schema_history_happy_path(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + tg = _table_group() + mock_tg_cls.get.return_value = tg + + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13)) + 
mock_pr_cls.list_recent_complete.return_value = [run_new, run_old] + + rows = [ + _profile_row(run_id=run_old.id, table_name="orders", column_name="id", general_type="N", record_ct=900), + _profile_row(run_id=run_old.id, table_name="orders", column_name="email", general_type="A", record_ct=900), + _profile_row(run_id=run_new.id, table_name="orders", column_name="id", general_type="N", record_ct=1000), + _profile_row(run_id=run_new.id, table_name="orders", column_name="email", general_type="A", record_ct=1000), + _profile_row(run_id=run_new.id, table_name="orders", column_name="phone", general_type="A", record_ct=1000), + ] + mock_pr.select_for_runs.return_value = rows + + result = get_schema_history(str(tg.id)) + + assert "Schema history" in result + assert "phone" in result # newly added column + assert "Record count" in result # 900 → 1,000 delta + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_schema_history_single_run_short_circuits(mock_tg_cls, mock_pr_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + mock_pr_cls.list_recent_complete.return_value = [_profiling_run()] + + result = get_schema_history(str(uuid4())) + assert "at least two are needed" in result + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_schema_history_no_runs(mock_tg_cls, mock_pr_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + mock_pr_cls.list_recent_complete.return_value = [] + + result = get_schema_history(str(uuid4())) + assert "No completed profiling runs" in result diff --git a/tests/unit/mcp/test_tools_profiling.py b/tests/unit/mcp/test_tools_profiling.py index e9773075..588bd42d 100644 --- a/tests/unit/mcp/test_tools_profiling.py +++ b/tests/unit/mcp/test_tools_profiling.py @@ -1,9 +1,11 @@ +from datetime import datetime from unittest.mock import MagicMock, patch from uuid import uuid4 import 
pytest -from testgen.common.models.data_column import ColumnProfileSummary +from testgen.common.models.data_column import ColumnProfileDetail, ColumnProfileSummary, DataColumnChars +from testgen.common.pii_masking import PII_REDACTED from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions @@ -252,7 +254,7 @@ def test_list_column_profiles_paginates(mock_tg_cls, mock_dcc_cls, db_session_mo assert "Use `page=2` for more" in result -@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.ProfilingRun") @patch("testgen.mcp.tools.profiling.DataColumnChars") @patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_with_valid_job_execution_id( @@ -262,6 +264,7 @@ def test_list_column_profiles_with_valid_job_execution_id( pr = MagicMock() pr.id = uuid4() pr.table_groups_id = tg.id + pr.project_code = tg.project_code mock_tg_cls.get.return_value = tg mock_pr_cls.get_by_id_or_job.return_value = pr @@ -273,7 +276,7 @@ def test_list_column_profiles_with_valid_job_execution_id( assert mock_dcc_cls.list_for_table_group.call_args.kwargs["profiling_run_id"] == pr.id -@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.ProfilingRun") @patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_rejects_je_from_different_tg( mock_tg_cls, mock_pr_cls, db_session_mock, @@ -283,6 +286,7 @@ def test_list_column_profiles_rejects_je_from_different_tg( pr = MagicMock() pr.id = uuid4() pr.table_groups_id = uuid4() # different TG + pr.project_code = tg.project_code mock_tg_cls.get.return_value = tg mock_pr_cls.get_by_id_or_job.return_value = pr @@ -292,7 +296,7 @@ def test_list_column_profiles_rejects_je_from_different_tg( list_column_profiles(str(uuid4()), job_execution_id=str(uuid4())) -@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.ProfilingRun") 
@patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_rejects_unknown_je(mock_tg_cls, mock_pr_cls, db_session_mock): mock_tg_cls.get.return_value = _mock_table_group() @@ -336,14 +340,14 @@ def test_list_column_profiles_inaccessible_tg(mock_tg_cls, db_session_mock): @pytest.mark.parametrize( "value,expected", [ - (None, None), - ("", None), - ("MANUAL", "PII"), - ("A/ID/Passport", "PII (High Risk - ID / Passport)"), - ("B/NAME/Individual", "PII (Moderate Risk - Name / Individual)"), - ("C/CONTACT", "PII (Low Risk - Contact)"), - ("B/ID/ID", "PII (Moderate Risk - ID)"), # detail collapses when equal to type label - ("X/UNKNOWN/Detail", "PII (Moderate Risk / Detail)"), # unknown risk falls back; unknown type drops label + (None, "No"), + ("", "No"), + ("MANUAL", "Yes"), + ("A/ID/Passport", "Yes (High Risk - ID / Passport)"), + ("B/NAME/Individual", "Yes (Moderate Risk - Name / Individual)"), + ("C/CONTACT", "Yes (Low Risk - Contact)"), + ("B/ID/ID", "Yes (Moderate Risk - ID)"), # detail collapses when equal to type label + ("X/UNKNOWN/Detail", "Yes (Moderate Risk / Detail)"), # unknown risk falls back; unknown type drops label ], ) def test_format_pii(value, expected): @@ -359,12 +363,12 @@ def test_format_pii(value, expected): def test_render_row_renders_parsed_pii_label(): from testgen.mcp.tools.profiling import _render_column_profile_row row = _render_column_profile_row(_column_summary(pii_flag="B/NAME/Individual")) - assert row[5] == "PII (Moderate Risk - Name / Individual)" + assert row[5] == "Yes (Moderate Risk - Name / Individual)" -def test_render_row_falsy_pii_renders_none(): +def test_render_row_falsy_pii_renders_no(): from testgen.mcp.tools.profiling import _render_column_profile_row - assert _render_column_profile_row(_column_summary(pii_flag=None))[5] is None + assert _render_column_profile_row(_column_summary(pii_flag=None))[5] == "No" def test_render_row_cde_collapsed_to_y_or_none(): @@ -479,3 +483,1418 @@ def 
test_list_profiling_summaries_inaccessible_tg(mock_tg_cls, db_session_mock): from testgen.mcp.tools.profiling import list_profiling_summaries with pytest.raises(MCPResourceNotAccessible, match="Table group .* not found or not accessible"): list_profiling_summaries(table_group_id=str(uuid4())) + + +# ---------------------------------------------------------------------- +# list_profiling_runs +# ---------------------------------------------------------------------- + +from datetime import UTC + +from testgen.common.enums import JobStatus + +_RUN_CREATED = datetime(2026, 4, 1, 10, 0, 0, tzinfo=UTC) +_RUN_STARTED = datetime(2026, 4, 1, 10, 0, 5, tzinfo=UTC) +_RUN_COMPLETED = datetime(2026, 4, 1, 10, 1, 30, tzinfo=UTC) + + +def _mock_profiling_run(**overrides): + defaults = { + "job_execution_id": uuid4(), + "profiling_run_id": uuid4(), + "project_code": "demo", + "status": JobStatus.COMPLETED, + "status_label": "Completed", + "created_at": _RUN_CREATED, + "started_at": _RUN_STARTED, + "completed_at": _RUN_COMPLETED, + "error_message": None, + "table_groups_name": "demo-tg", + "table_group_schema": "demo", + "table_ct": 5, "column_ct": 30, "record_ct": 1000, + "anomaly_ct": 4, + "anomalies_definite_ct": 1, "anomalies_likely_ct": 1, + "anomalies_possible_ct": 2, "anomalies_dismissed_ct": 0, + "dq_score_profiling": 95.5, + } + defaults.update(overrides) + return MagicMock(**defaults) + + +@patch("testgen.mcp.tools.profiling.JobExecution") +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_default(mock_tg_cls, mock_run_cls, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + tg = _mock_table_group() + tg.table_groups_name = "demo-tg" + mock_tg_cls.get.return_value = tg + mock_run_cls.select_summary.return_value = ([_mock_profiling_run()], 1) + + from testgen.mcp.tools.profiling 
import list_profiling_runs + result = list_profiling_runs(table_group_id=str(uuid4())) + + assert "Profiling runs for `demo-tg`" in result + assert "Completed" in result + call_kwargs = mock_run_cls.select_summary.call_args.kwargs + assert call_kwargs["statuses"] is None + + +@patch("testgen.mcp.tools.profiling.JobExecution") +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_status_filter(mock_tg_cls, mock_run_cls, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + mock_tg_cls.get.return_value = _mock_table_group() + mock_run_cls.select_summary.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_profiling_runs + list_profiling_runs(table_group_id=str(uuid4()), status="Pending") + + call_kwargs = mock_run_cls.select_summary.call_args.kwargs + assert call_kwargs["statuses"] == [JobStatus.PENDING, JobStatus.CLAIMED] + + +@patch("testgen.mcp.tools.profiling.JobExecution") +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=_RUN_STARTED) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_shows_next_scheduled(mock_tg_cls, mock_run_cls, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + mock_tg_cls.get.return_value = _mock_table_group() + mock_run_cls.select_summary.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_profiling_runs + result = list_profiling_runs(table_group_id=str(uuid4())) + + assert "Next scheduled run" in result + + +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_invalid_status(mock_tg_cls, mock_run_cls, mock_next, 
db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + + from testgen.mcp.tools.profiling import list_profiling_runs + with pytest.raises(MCPUserError, match="Invalid status"): + list_profiling_runs(table_group_id=str(uuid4()), status="Bogus") + + +# ---------------------------------------------------------------------- +# get_profiling_run +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_returns_detail(mock_run_cls, db_session_mock): + summary = _mock_profiling_run() + mock_run_cls.select_summary.return_value = ([summary], 1) + mock_run = MagicMock(project_code="demo") + mock_run_cls.get_by_id_or_job.return_value = mock_run + mock_run_cls.select_table_breakdown.return_value = [ + MagicMock(schema_name="demo", table_name="orders", record_ct=1000, column_ct=5, anomaly_ct=2), + ] + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + result = get_profiling_run(str(summary.job_execution_id)) + + assert "Profiling run: demo-tg" in result + assert "Completed" in result + assert "Per-table breakdown" in result + assert "orders" in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_pending_no_breakdown(mock_run_cls, db_session_mock): + summary = _mock_profiling_run( + status=JobStatus.PENDING, status_label="Pending", + profiling_run_id=None, started_at=None, completed_at=None, + table_ct=None, column_ct=None, record_ct=None, anomaly_ct=None, + anomalies_definite_ct=None, anomalies_likely_ct=None, + anomalies_possible_ct=None, 
dq_score_profiling=None, + ) + mock_run_cls.select_summary.return_value = ([summary], 1) + mock_run_cls.get_by_id_or_job.return_value = MagicMock(project_code="demo") + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + result = get_profiling_run(str(summary.job_execution_id)) + + assert "Pending" in result + assert "In progress" in result + assert "Per-table breakdown" not in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_not_found(mock_run_cls, db_session_mock): + mock_run_cls.select_summary.return_value = ([], 0) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + with pytest.raises(MCPResourceNotAccessible): + get_profiling_run(str(uuid4())) + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_inaccessible_project(mock_run_cls, db_session_mock): + summary = _mock_profiling_run(project_code="secret") + mock_run_cls.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + 
mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + with pytest.raises(MCPResourceNotAccessible): + get_profiling_run(str(summary.job_execution_id)) + + +def test_get_profiling_run_invalid_uuid(db_session_mock): + from testgen.mcp.tools.profiling import get_profiling_run + with pytest.raises(MCPUserError, match="not a valid UUID"): + get_profiling_run("not-a-uuid") + + +# ---------------------------------------------------------------------- +# get_column_profile_detail +# ---------------------------------------------------------------------- + + +def _column_detail(**overrides) -> ColumnProfileDetail: + """Build a ColumnProfileDetail with sensible alpha-column defaults; override per test.""" + base: dict = { + # Identity + "column_name": "customer_name", + "table_name": "customers", + "schema_name": "demo", + # Types & metadata + "general_type": "A", + "column_type": "varchar(50)", + "db_data_type": "varchar(50)", + "functional_data_type": "Person Given Name", + "datatype_suggestion": "VARCHAR(20)", + "functional_table_type": None, + "pii_flag": None, + "critical_data_element": False, + # Counts + "record_ct": 500, + "value_ct": 500, + "distinct_value_ct": 260, + "null_value_ct": 0, + "filled_value_ct": 0, + "zero_value_ct": 0, + # Alpha + "min_length": 3, + "max_length": 50, + "avg_length": 12.4, + "min_text": "Aaron", + "max_text": "Zoey", + "top_freq_values": "| Mary | 12\n| John | 10", + "top_patterns": "10 | A(5) | 8 | A(6)", + "distinct_std_value_ct": 250, + "distinct_pattern_ct": 35, + "std_pattern_match": None, + "mixed_case_ct": 100, + "lower_case_ct": 350, + "upper_case_ct": 50, + "non_alpha_ct": 0, + "includes_digit_ct": 0, + "numeric_ct": 0, + "date_ct": 0, + "quoted_value_ct": 0, + "lead_space_ct": 0, + "embedded_space_ct": 0, + "avg_embedded_spaces": 0.0, + "zero_length_ct": 0, + # Numeric + "min_value": None, + "min_value_over_0": None, + "max_value": None, + 
"avg_value": None, + "stdev_value": None, + "percentile_25": None, + "percentile_50": None, + "percentile_75": None, + # Date + "min_date": None, + "max_date": None, + "before_1yr_date_ct": None, + "before_5yr_date_ct": None, + "before_20yr_date_ct": None, + "within_1yr_date_ct": None, + "within_1mo_date_ct": None, + "future_date_ct": None, + # Boolean + "boolean_true_ct": None, + # Per-column profiling failure + "query_error": None, + # Scores & hygiene + "dq_score_profiling": 95.2, + "dq_score_testing": 90.0, + "hygiene_issue_count": 2, + # Run identity + "profile_run_id": uuid4(), + "profile_run_je_id": uuid4(), + "profile_run_status": "Complete", + "profile_run_started_at": datetime(2026, 5, 1, 12, 0, 0), + "profile_run_ended_at": datetime(2026, 5, 1, 12, 5, 0), + "profile_run_log_message": None, + } + base.update(overrides) + return ColumnProfileDetail(**base) + + +# --- happy paths per general_type --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_alpha_renders_alpha_sections(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail(general_type="A") + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Column Profile" in result + assert "customer_name" in result + assert "Profiling Run" in result + # Alpha-specific sections present + assert "Length" in result + assert "Text Range" in result + assert "Patterns" in result + assert "Aaron" in result + assert "Zoey" in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_alpha_renders_distinct_standard_values( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + """`distinct_std_value_ct` (alpha-only) renders under 
the Patterns section as 'Distinct Standard Values'.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + general_type="A", + distinct_std_value_ct=247, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Distinct Standard Values" in result + assert "247" in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_numeric_renders_numeric_sections(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="amount", + general_type="N", + db_data_type="numeric", + functional_data_type="Currency", + # Numeric stats + min_value=0.0, + min_value_over_0=0.01, + max_value=99999.99, + avg_value=125.34, + stdev_value=42.1, + percentile_25=50.0, + percentile_50=100.0, + percentile_75=200.0, + # Alpha fields cleared (numeric column wouldn't have these populated) + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + std_pattern_match=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "orders", "amount") + + # Numeric-specific content present + assert "Median" in result or "Percentile" in result or "percentile_50" in result.lower() + assert "99999.99" in result or "99,999.99" in result + # Alpha-only sections absent + assert "Text Range" not in result + assert "Min Text" not in result + assert "Aaron" not in result + assert "Length" not in result.replace("Avg Length", "") # rough — ensures no Length section + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def 
test_get_column_profile_detail_date_renders_date_sections(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="created_at", + general_type="D", + db_data_type="timestamp", + functional_data_type="Datetime-Created", + min_date=datetime(2024, 1, 1, 0, 0, 0), + max_date=datetime(2026, 4, 30, 23, 59, 59), + before_1yr_date_ct=10000, + before_5yr_date_ct=2000, + before_20yr_date_ct=0, + within_1yr_date_ct=40000, + within_1mo_date_ct=5000, + future_date_ct=0, + # Alpha fields cleared + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "orders", "created_at") + + # Date-specific content + assert "Within 1" in result or "Before 1" in result or "Date Range" in result + assert "2024" in result + # Alpha-only sections absent + assert "Aaron" not in result + assert "Pattern" not in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_boolean_renders_boolean_section(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="is_active", + general_type="B", + db_data_type="boolean", + functional_data_type="Boolean", + boolean_true_ct=420, + value_ct=500, + # Alpha fields cleared + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "users", "is_active") + + assert "True" in result + assert "420" in result + # Alpha-only sections absent + 
assert "Pattern" not in result + assert "Length" not in result.replace("Avg Length", "") + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_unknown_general_type_renders_counts_only( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="payload", + general_type="X", + db_data_type="json", + functional_data_type=None, + # All type-specific fields cleared + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "events", "payload") + + assert "payload" in result + assert "Counts" in result + assert "Pattern" not in result + assert "Boolean Distribution" not in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_general_type_t_treated_as_unknown( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + """T mirrors current UI behavior — falls through to common counts only.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="ts", + general_type="T", + db_data_type="time", + functional_data_type=None, + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "events", "ts") + + assert "Counts" in result + assert "Date Range" not in result # not dispatched as date + + +# --- never-profiled / no-profile-for-pinned-run --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") 
+@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_never_profiled_column_rejects( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + """Column row exists in data_column_chars but has no completed profiling run yet + (`last_complete_profile_run_id IS NULL`). The model returns a detail with NULL run + fields; the tool must reject rather than render an empty profile. + """ + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_id=None, + profile_run_je_id=None, + profile_run_status=None, + profile_run_started_at=None, + profile_run_ended_at=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + msg = str(exc_info.value) + assert "customer_name" in msg + assert "not been profiled" in msg + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_without_column_rejects( + mock_tg_cls, mock_pr_cls, mock_dcc_cls, db_session_mock, +): + """User pins a valid run via job_execution_id, but that run has no profile_results + row for this column. Surface the pinned run id so the LLM knows what to try next. 
+ """ + tg = _mock_table_group() + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = tg.id + pr.project_code = tg.project_code + + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_by_id_or_job.return_value = pr + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_id=None, + profile_run_je_id=None, + profile_run_status=None, + profile_run_started_at=None, + profile_run_ended_at=None, + ) + + je_id_str = str(uuid4()) + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail( + str(uuid4()), "customers", "customer_name", job_execution_id=je_id_str + ) + + msg = str(exc_info.value) + assert "customer_name" in msg + assert je_id_str in msg + + +# --- error paths --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_column_not_found_unified_error( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = None + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Column .* not found or not accessible"): + get_column_profile_detail(str(uuid4()), "customers", "ghost_column") + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_inaccessible_tg(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = None + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Table group .* not found or not accessible"): + get_column_profile_detail(str(uuid4()), "customers", "x") + + +def test_get_column_profile_detail_invalid_tg_uuid(db_session_mock): + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError, match="Invalid table_group_id"): + 
get_column_profile_detail("not-a-uuid", "customers", "x") + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_invalid_je_uuid(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError, match="Invalid job_execution_id"): + get_column_profile_detail( + str(uuid4()), "customers", "x", job_execution_id="bad" + ) + + +# --- job_execution_id pinning --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_passes_id_to_model( + mock_tg_cls, mock_pr_cls, mock_dcc_cls, db_session_mock, +): + tg = _mock_table_group() + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = tg.id + pr.project_code = tg.project_code + + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_by_id_or_job.return_value = pr + mock_dcc_cls.get_column_detail.return_value = _column_detail() + + from testgen.mcp.tools.profiling import get_column_profile_detail + get_column_profile_detail(str(uuid4()), "customers", "customer_name", job_execution_id=str(uuid4())) + + assert mock_dcc_cls.get_column_detail.call_args.kwargs["profiling_run_id"] == pr.id + + +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_from_different_tg_unified_error( + mock_tg_cls, mock_pr_cls, db_session_mock, +): + tg = _mock_table_group() + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = uuid4() # different + pr.project_code = tg.project_code + + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_by_id_or_job.return_value = pr + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + 
get_column_profile_detail( + str(uuid4()), "customers", "x", job_execution_id=str(uuid4()) + ) + + +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_unknown_unified_error( + mock_tg_cls, mock_pr_cls, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_pr_cls.get_by_id_or_job.return_value = None + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + get_column_profile_detail( + str(uuid4()), "customers", "x", job_execution_id=str(uuid4()) + ) + + +# --- run-status preconditions --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_running_run_rejects_with_status( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + je_id = uuid4() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_status="Running", + profile_run_je_id=je_id, + profile_run_ended_at=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + msg = str(exc_info.value) + assert "Running" in msg + assert str(je_id) in msg + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_error_run_includes_log_message( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + je_id = uuid4() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_status="Error", + profile_run_je_id=je_id, + profile_run_log_message="connection timed out", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + with 
pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + msg = str(exc_info.value) + assert "Error" in msg + assert str(je_id) in msg + assert "connection timed out" in msg + + +# --- PII redaction --- + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pii_column_no_view_pii_redacts( + mock_tg_cls, mock_dcc_cls, mock_compute, db_session_mock, +): + """User has 'catalog' on demo but NOT 'view_pii' → 8 raw-value fields redacted; aggregates kept.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + pii_flag="B/CONTACT/Email", + column_name="customer_email", + general_type="A", + std_pattern_match="EMAIL", + min_text="aaron@example.com", + max_text="zoey@example.com", + top_freq_values="| mary@x.com | 1\n| john@x.com | 1", + ) + # No project includes view_pii — only catalog allowed + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_c"}, # role_c has 'catalog' but not 'view_pii' in test matrix + permission="catalog", + username="test_user", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_email") + + # Raw-value fields redacted + assert PII_REDACTED in result + assert "aaron@example.com" not in result + assert "zoey@example.com" not in result + assert "mary@x.com" not in result + # Aggregates / counts / std_pattern_match still visible + assert "260" in result or "Distinct" in result + assert "EMAIL" in result or "Email" in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def 
test_get_column_profile_detail_pii_column_with_view_pii_shows_values( + mock_tg_cls, mock_dcc_cls, mock_compute, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + pii_flag="B/CONTACT/Email", + column_name="customer_email", + min_text="aaron@example.com", + max_text="zoey@example.com", + ) + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, # role_a has 'view_pii' in conftest matrix? actually no — but we need a role that includes view_pii. Use role-with-view_pii via "edit" mapping. + permission="catalog", + username="test_user", + ) + # Patch the rbac mapping so role_a includes view_pii for this test + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance.return_value.rbac.get_roles_with_permission.side_effect = ( + lambda perm: ["role_a"] if perm in ("catalog", "view_pii") else [] + ) + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_email") + + assert "aaron@example.com" in result + assert PII_REDACTED not in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_non_pii_column_never_redacts( + mock_tg_cls, mock_dcc_cls, mock_compute, db_session_mock, +): + """No pii_flag → raw values shown regardless of view_pii grant.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + pii_flag=None, + min_text="Aaron", + max_text="Zoey", + ) + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_c"}, + permission="catalog", + username="test_user", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", 
"customer_name") + + assert "Aaron" in result + assert "Zoey" in result + assert PII_REDACTED not in result + + +# --- query_error surfacing --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_query_error_section(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + query_error="ORA-01017: invalid username/password", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Profiling Error" in result + assert "ORA-01017" in result + + +# ---------------------------------------------------------------------- +# list_column_profiles — predicate filters +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_null_ratio_above_adds_clause(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.list_for_table_group.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), null_ratio_above=0.2) + + clauses = mock_dcc_cls.list_for_table_group.call_args[0] + assert any("null_value_ct" in str(c) for c in clauses) + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_score_profiling_above_converts_to_0_to_1_scale( + mock_tg_cls, mock_method, db_session_mock, +): + """The user-facing 0-100 score range maps to the 0-1 fraction the DB stores.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + 
list_column_profiles(str(uuid4()), score_profiling_above=70) + + sql = _compile_clauses(mock_method) + assert "dq_score_profiling > 0.7" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_score_testing_below_converts_to_0_to_1_scale( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), score_testing_below=50) + + sql = _compile_clauses(mock_method) + assert "dq_score_testing < 0.5" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_true_adds_is_not_null_clause(mock_tg_cls, mock_method, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii=True) + + sql = _compile_clauses(mock_method) + assert "pii_flag IS NOT NULL" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_cde_true_coalesces_column_and_table_flag( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), cde=True) + + sql = _compile_clauses(mock_method) + assert "data_column_chars.critical_data_element IS true" in sql + assert "data_table_chars.critical_data_element IS true" in sql + assert "OR" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_suggested_data_type_any_uses_is_not_null( + mock_tg_cls, 
mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), suggested_data_type="Any") + + sql = _compile_clauses(mock_method) + assert "datatype_suggestion IS NOT NULL" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_suggested_data_type_concrete_uses_prefix_ilike( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), suggested_data_type="Integer") + + sql = _compile_clauses(mock_method) + assert "INTEGER%" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_general_type_translates_word_to_letter( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), general_type="Numeric") + + sql = _compile_clauses(mock_method) + assert "general_type = 'N'" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_category_translated_to_stored_code( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii_category="Contact") + + sql = _compile_clauses(mock_method) + assert "%/CONTACT/%" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") 
+@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_risk_level_high_includes_manual( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii_risk_level="High") + + sql = _compile_clauses(mock_method) + assert "'A/%'" in sql and "'MANUAL'" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_risk_level_moderate_does_not_include_manual( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii_risk_level="Moderate") + + sql = _compile_clauses(mock_method) + assert "'B/%'" in sql + assert "MANUAL" not in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_semantic_data_type_uses_ilike( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), semantic_data_type="Person Given") + + sql = _compile_clauses(mock_method) + # Default dialect renders ILIKE as ``LOWER(col) LIKE LOWER(pat) ESCAPE`` — same semantic. 
+ assert "LIKE" in sql.upper() + assert "%Person Given%" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_semantic_data_type_underscore_escaped( + mock_tg_cls, mock_method, db_session_mock, +): + """Underscores in the input must be escaped (column names commonly contain them).""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), semantic_data_type="ID_FK") + + sql = _compile_clauses(mock_method) + # The escape clause appears, and the underscore is escaped in the pattern. + assert "ID\\_FK" in sql or "ID\\\\_FK" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_semantic_data_type_empty_rejected(mock_tg_cls, mock_method, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + + from testgen.mcp.tools.profiling import list_column_profiles + with pytest.raises(MCPUserError, match="`semantic_data_type` cannot be empty"): + list_column_profiles(str(uuid4()), semantic_data_type=" ") + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_order_by_passes_enum_to_model(mock_tg_cls, mock_method, db_session_mock): + from testgen.common.models.data_column import ColumnOrderBy + + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), order_by="Null Ratio") + + assert mock_method.call_args.kwargs["order_by"] is ColumnOrderBy.NULL_RATIO + + +def _compile_clauses(mock_method): + """Compile the *clauses arg of a captured ``list_for_table_group`` call into a single SQL string.""" + clauses = mock_method.call_args[0] + return 
" ".join(str(c.compile(compile_kwargs={"literal_binds": True})) for c in clauses) + + +# ---------------------------------------------------------------------- +# get_column_frequent_values +# ---------------------------------------------------------------------- + + +def _mock_profile_result(**overrides): + pr = MagicMock() + pr.profile_run_id = uuid4() + pr.record_ct = 500 + pr.distinct_value_ct = 3 + pr.pii_flag = None + pr.general_type = "A" + pr.top_freq_values = "| Mexico | 200\n| USA | 180\n| Canada | 120" + pr.top_patterns = "200 | Aaaaaa | 100 | AAA" + for k, v in overrides.items(): + setattr(pr, k, v) + return pr + + +def _mock_profiling_run_for_tg(tg_id): + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = tg_id + pr.job_execution_id = uuid4() + return pr + + +def _mock_data_column(pii_flag=None): + """Build a mock `DataColumnChars` row carrying just the fields the helper reads.""" + col = MagicMock() + col.pii_flag = pii_flag + return col + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_happy_path( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result() + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "country") + + assert "Frequent values: customers.country" in result + assert "Mexico" in result and "USA" in result and "Canada" in result + assert "40.00%" in result # 200/500 + assert "Top values" in result + + +@patch.object(DataColumnChars, "select_where") 
+@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_surfaces_job_execution_id_not_profile_run_id( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + profile = _mock_profile_result() + mock_pr_cls.get_for_column.return_value = profile + run = _mock_profiling_run_for_tg(tg.id) + mock_run_cls.get.return_value = run + mock_dcc_select.return_value = [_mock_data_column()] + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "country") + + # The internal profile_run_id PK must not leak; only the job_execution_id is followable. + assert str(run.job_execution_id) in result + assert str(profile.profile_run_id) not in result + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_pii_value_redacted_when_caller_lacks_view_pii( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group(project_code="demo") + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + top_freq_values="| alice@example.com | 5\n| bob@example.com | 3", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + # The pii_flag the tool reads comes from DataColumnChars, not ProfileResult. + mock_dcc_select.return_value = [_mock_data_column(pii_flag="B/CONTACT/Email")] + + # Default test conftest grants no view_pii (TEST_PERM_MATRIX has no entry). 
+ from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "email") + + assert PII_REDACTED in result + assert "alice@example.com" not in result + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_pii_value_visible_with_view_pii_grant( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_compute, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group(project_code="demo") + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + top_freq_values="| alice@example.com | 5\n| bob@example.com | 3", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column(pii_flag="B/CONTACT/Email")] + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="catalog", + username="test_user", + ) + # Add view_pii to the matrix for this test by patching the role-lookup. 
+ with patch("testgen.mcp.permissions.PluginHook") as hook_mock: + hook_mock.instance.return_value.rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "email") + + assert "alice@example.com" in result + assert PII_REDACTED not in result + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_high_cardinality_fallback( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + top_freq_values=None, distinct_value_ct=10000, + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "customer_id") + + assert "Frequency data not available" in result + assert "10000" in result + + +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_missing_profile_raises_not_accessible( + mock_tg_cls, mock_pr_cls, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_pr_cls.get_for_column.return_value = None + + from testgen.mcp.tools.profiling import get_column_frequent_values + with pytest.raises(MCPResourceNotAccessible, match="Column profile"): + get_column_frequent_values(str(uuid4()), "customers", "ghost") + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") 
+@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_pii_source_is_data_column_chars_not_profile_result( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + """``data_column_chars.pii_flag`` is the source of truth; ``profile_result.pii_flag`` is ignored.""" + tg = _mock_table_group(project_code="demo") + mock_tg_cls.get.return_value = tg + # ProfileResult carries a stale/wrong pii_flag; DataColumnChars says None. + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + pii_flag="A/CONTACT/Email", # stale value; should NOT drive redaction + top_freq_values="| alice@example.com | 5", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column(pii_flag=None)] + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "email") + + # No redaction, no PII field — because DataColumnChars says the column is not PII. 
+ assert PII_REDACTED not in result + assert "alice@example.com" in result + assert "PII" not in result.splitlines()[1:6] # no "PII:" field in the header block + + +# ---------------------------------------------------------------------- +# get_column_patterns +# ---------------------------------------------------------------------- + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_patterns_happy_path( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + general_type="A", + top_patterns="326 | Aaaaaa | 176 | AAA", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] + + from testgen.mcp.tools.profiling import get_column_patterns + result = get_column_patterns(str(uuid4()), "customers", "country") + + assert "Character patterns: customers.country" in result + assert "Aaaaaa" in result and "AAA" in result + assert "Top patterns" in result + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_patterns_non_string_column_fallback( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + general_type="N", + top_patterns=None, + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] + + from testgen.mcp.tools.profiling import get_column_patterns + result = 
get_column_patterns(str(uuid4()), "products", "price") + + assert "column is not a string type" in result + + +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_patterns_high_cardinality_fallback( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + general_type="A", + top_patterns=None, + distinct_value_ct=9999, + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] + + from testgen.mcp.tools.profiling import get_column_patterns + result = get_column_patterns(str(uuid4()), "customers", "address") + + assert "Pattern data not available" in result + assert "9999" in result + + +# ---------------------------------------------------------------------- +# search_columns +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +def test_search_columns_no_scope_uses_all_accessible_projects(mock_dcc_cls, db_session_mock): + mock_dcc_cls.search_by_name.return_value = ([], 0) + mock_dcc_cls.summarize_matches_by_project.return_value = [] + + from testgen.mcp.tools.profiling import search_columns + result = search_columns("email") + + assert "all accessible projects" in result or "No columns matching" in result + + +@patch.object(DataColumnChars, "search_by_name") +@patch("testgen.mcp.tools.common.TableGroup") +def test_search_columns_table_group_scope_passes_tg_id_clause(mock_tg_cls, mock_method, db_session_mock): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import search_columns + search_columns("email", 
table_group_id=str(uuid4())) + + sql = " ".join( + str(c.compile(compile_kwargs={"literal_binds": True})) for c in mock_method.call_args[0] + ) + assert "table_groups_id" in sql + + +def test_search_columns_rejects_both_scopes_passed(db_session_mock): + from testgen.mcp.tools.profiling import search_columns + with pytest.raises(MCPUserError, match="not both"): + search_columns("email", project_code="demo", table_group_id=str(uuid4())) + + +def test_search_columns_empty_pattern_rejected(db_session_mock): + from testgen.mcp.tools.profiling import search_columns + with pytest.raises(MCPUserError, match="`pattern` is required"): + search_columns(" ") + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +def test_search_columns_renders_per_project_summary_when_no_scope(mock_dcc_cls, db_session_mock): + hit = MagicMock() + hit.project_code = "DEFAULT" + hit.table_groups_name = "default" + hit.schema_name = "demo" + hit.table_name = "d_ebike_suppliers" + hit.column_name = "contact_email" + mock_dcc_cls.search_by_name.return_value = ([hit], 1) + mock_dcc_cls.summarize_matches_by_project.return_value = [("DEFAULT", 1), ("DEMO_2", 0)] + + from testgen.mcp.tools.profiling import search_columns + result = search_columns("email") + + assert "Matches by project" in result + assert "DEFAULT" in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_search_columns_table_group_scope_skips_per_project_summary( + mock_tg_cls, mock_dcc_cls, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + hit = MagicMock() + hit.project_code = "demo" + hit.table_groups_name = "default" + hit.schema_name = "demo" + hit.table_name = "customers" + hit.column_name = "email" + mock_dcc_cls.search_by_name.return_value = ([hit], 1) + + from testgen.mcp.tools.profiling import search_columns + result = search_columns("email", table_group_id=str(uuid4())) + + assert "Matches by project" not in result + 
mock_dcc_cls.summarize_matches_by_project.assert_not_called() diff --git a/tests/unit/mcp/test_tools_quality_scores.py b/tests/unit/mcp/test_tools_quality_scores.py new file mode 100644 index 00000000..62dd31c4 --- /dev/null +++ b/tests/unit/mcp/test_tools_quality_scores.py @@ -0,0 +1,2778 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.models.scores import ( + ScoreCategory, + ScoreDefinition, + ScoreDefinitionBreakdownItem, + ScoreDefinitionCriteria, +) +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import ProjectPermissions +from testgen.mcp.tools.quality_scores import _format_criteria_summary + +pytestmark = pytest.mark.unit + + +# --- Helpers --- + + +def _score_card( + score=0.9, + cde_score=0.8, + profiling_score=0.95, + testing_score=0.85, + categories=None, +): + """Default ScoreCard dict returned by ScoreDefinition.as_score_card().""" + return { + "id": uuid4(), + "project_code": "demo", + "name": "test", + "score": score, + "cde_score": cde_score, + "profiling_score": profiling_score, + "testing_score": testing_score, + "categories": categories or [], + "history": [], + "definition": None, + } + + +def _patch_perms(allowed=("demo",), memberships=None): + """Return a patch context manager that injects a ProjectPermissions with given access.""" + memberships = memberships or dict.fromkeys(allowed, "role_a") + return patch( + "testgen.mcp.permissions._compute_project_permissions", + return_value=ProjectPermissions( + memberships=memberships, permission="view", username="test_user", + ), + ) + + +# --- Argument validation --- + + +def test_mutually_exclusive_scope_args_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="project_code.*table_group_id"): + get_quality_scores(project_code="demo", table_group_id=str(uuid4())) + + +def 
test_invalid_group_by_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid group_by") as exc_info: + get_quality_scores(project_code="demo", group_by="invented_field") + msg = str(exc_info.value) + # Error message must speak the user-facing vocabulary. + assert "Quality Dimension" in msg + + +@pytest.mark.parametrize("group_by", ["column_name", "table_name", "dq_dimension"]) +def test_internal_group_by_value_rejected(group_by, db_session_mock): + """Old internal column names (row-level or column-form) are no longer accepted.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid group_by"): + get_quality_scores(project_code="demo", group_by=group_by) + + +def test_invalid_score_type_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid score_type") as exc_info: + get_quality_scores(project_code="demo", score_type="garbage") + msg = str(exc_info.value) + assert "Total" in msg + assert "CDE" in msg + + +@pytest.mark.parametrize("internal", ["total", "cde"]) +def test_internal_score_type_rejected(internal, db_session_mock): + """``total``/``cde`` were the old internal codes — inputs now use ``Total``/``CDE``.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid score_type"): + get_quality_scores(project_code="demo", score_type=internal) + + +def test_project_not_accessible_rejected(db_session_mock): + """A project the user can't view raises MCPResourceNotAccessible-style error.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(allowed=("only_this",)), pytest.raises(MCPResourceNotAccessible, match="forbidden_proj"): + 
get_quality_scores(project_code="forbidden_proj") + + +# --- Score-type → model-call mapping --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_default_overall_shows_both_total_and_cde(mock_definition_cls, db_session_mock): + """score_type omitted → both Total and CDE Score lines are rendered.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Total Score" in out + assert "93" in out + assert "CDE Score" in out + assert "81" in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_total_overall_shows_only_total(mock_definition_cls, db_session_mock): + """score_type='Total' renders only the Total Score line.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=None) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Total", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Total Score" in out + assert "93" in out + assert "CDE Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cde_overall_shows_only_cde(mock_definition_cls, db_session_mock): + """score_type='CDE' renders only the CDE Score line.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + 
project_code="demo", + score_type="CDE", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "CDE Score" in out + assert "81" in out + assert "Total Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_default_overall_includes_profiling_and_testing(mock_definition_cls, db_session_mock): + """score_type omitted → overall block surfaces Total, CDE, Profiling, + and Testing — same set the UI's score-card shows when Total is enabled.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card( + score=0.93, cde_score=0.81, profiling_score=0.95, testing_score=0.85, + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Total Score" in out + assert "CDE Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out + assert "95" in out + assert "85" in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_total_overall_includes_profiling_and_testing(mock_definition_cls, db_session_mock): + """score_type='Total' → Total + Profiling + Testing render; CDE omitted.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card( + score=0.93, cde_score=None, profiling_score=0.95, testing_score=0.85, + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Total", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Total Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out + assert "CDE Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def 
test_cde_overall_omits_profiling_and_testing(mock_definition_cls, db_session_mock): + """score_type='CDE' → Profiling/Testing must not appear even if the score + card returns values for them (matches UI's Total-only gating).""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card( + score=None, cde_score=0.81, profiling_score=0.95, testing_score=0.85, + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "CDE Score" in out + assert "Total Score" not in out + assert "Profiling Score" not in out + assert "Testing Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_total_grouped_uses_breakdown(mock_definition_cls, db_session_mock): + """score_type='Total' + group_by sources per-category rows from breakdown. + + Per-category output always includes Impact (matching the Score Explorer UI), + so the tool reads from get_score_card_breakdown rather than card.categories. 
+ """ + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}, + {"business_domain": "Marketing", "score": 0.74, "issue_ct": 11, "impact": 9.8}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Total", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "warehouse"}], + include_impact=True, + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("score", "business_domain") + assert "Finance" in out + assert "Marketing" in out + assert "Impact on Total Score" in out + assert "Impact on CDE Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cde_grouped_uses_breakdown(mock_definition_cls, db_session_mock): + """score_type='CDE' + group_by sources per-category rows from breakdown.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.72) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.80, "issue_ct": 2, "impact": 1.5}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "warehouse"}], + include_impact=True, + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("cde_score", "business_domain") + assert "Finance" in out + assert "Impact on CDE Score" in out + assert "Impact on Total Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def 
test_default_grouped_renders_both_score_columns(mock_definition_cls, db_session_mock): + """score_type omitted + group_by → table has Total + CDE columns and + Impact columns for both, populated from two breakdown calls. + """ + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9, cde_score=0.7) + + breakdown_results = { + "score": [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}, + {"business_domain": "Marketing", "score": 0.74, "issue_ct": 12, "impact": 11.4}, + ], + "cde_score": [ + {"business_domain": "Finance", "score": 0.85, "issue_ct": 3, "impact": 5.0}, + {"business_domain": "Marketing", "score": 0.60, "issue_ct": 8, "impact": 12.0}, + ], + } + mock_definition.get_score_card_breakdown.side_effect = ( + lambda score_key, _col: breakdown_results[score_key] + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "warehouse"}], + include_impact=True, + ) + + # Both score types → two breakdown calls + assert mock_definition.get_score_card_breakdown.call_count == 2 + call_keys = {c.args[0] for c in mock_definition.get_score_card_breakdown.call_args_list} + assert call_keys == {"score", "cde_score"} + + assert "Total Score" in out + assert "CDE Score" in out + assert "Impact on Total Score" in out + assert "Impact on CDE Score" in out + assert "Finance" in out + assert "Marketing" in out + + +# --- include_impact --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_default_false_omits_impact_columns(mock_definition_cls, db_session_mock): + """Default include_impact=False → grouped output has no Impact columns.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + 
mock_definition.as_score_card.return_value = _score_card(score=0.9, cde_score=0.7) + breakdown_results = { + "score": [{"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}], + "cde_score": [{"business_domain": "Finance", "score": 0.85, "issue_ct": 3, "impact": 5.0}], + } + mock_definition.get_score_card_breakdown.side_effect = ( + lambda score_key, _col: breakdown_results[score_key] + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert "Finance" in out + assert "Total Score" in out + assert "CDE Score" in out + assert "Impact" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_false_total_only_omits_impact_column(mock_definition_cls, db_session_mock): + """Total-only + default include_impact=False → no impact column.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Total", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert "Finance" in out + assert "Total Score" in out + assert "Impact" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_false_cde_only_omits_impact_column(mock_definition_cls, db_session_mock): + """CDE-only + default include_impact=False → no impact column.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = 
_score_card(score=None, cde_score=0.7) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.8, "issue_ct": 3, "impact": 2.0}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert "Finance" in out + assert "CDE Score" in out + assert "Impact" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_false_overall_unaffected(mock_definition_cls, db_session_mock): + """include_impact only affects grouped output — overall block is unchanged.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out_default = get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + out_with_impact = get_quality_scores( + project_code="demo", + include_impact=True, + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + # No group_by → impact has no rendering surface either way. 
+ assert "Impact" not in out_default + assert "Impact" not in out_with_impact + + +# --- include_issue_ct --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_overall_calls_get_overall_issue_ct(mock_definition_cls, db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_overall_issue_ct.return_value = 42 + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + include_issue_ct=True, + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + mock_definition.get_overall_issue_ct.assert_called_once_with() + assert "Issue Count" in out + assert "42" in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_grouped_total_uses_simple_label(mock_definition_cls, db_session_mock): + """grouped + Total + include_issue_ct: single 'Issue Count' column header.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 7, "impact": 4.0}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Total", + group_by="Business Domain", + include_issue_ct=True, + filters=[{"field": "Data Source", "value": "wh"}], + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("score", "business_domain") + assert "Finance" in out + assert "7" in out + assert "Issue Count" in out + assert "Issue Count (Total)" not in out + assert "Issue Count (CDE)" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def 
test_include_issue_ct_grouped_cde_uses_simple_label(mock_definition_cls, db_session_mock): + """grouped + CDE + include_issue_ct: single 'Issue Count' column header.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.7) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.8, "issue_ct": 3, "impact": 2.0}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + group_by="Business Domain", + include_issue_ct=True, + filters=[{"field": "Data Source", "value": "wh"}], + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("cde_score", "business_domain") + assert "Finance" in out + assert "3" in out + assert "Issue Count" in out + assert "Issue Count (Total)" not in out + assert "Issue Count (CDE)" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_grouped_default_uses_parenthetical_labels(mock_definition_cls, db_session_mock): + """grouped + score_type unset + include_issue_ct: separate Total / CDE + issue-count columns, and both Impact columns.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9, cde_score=0.7) + breakdown_results = { + "score": [{"business_domain": "Finance", "score": 0.91, "issue_ct": 7, "impact": 4.0}], + "cde_score": [{"business_domain": "Finance", "score": 0.80, "issue_ct": 3, "impact": 2.0}], + } + mock_definition.get_score_card_breakdown.side_effect = ( + lambda score_key, _col: breakdown_results[score_key] + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + 
include_issue_ct=True, + include_impact=True, + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert mock_definition.get_score_card_breakdown.call_count == 2 + assert "Issue Count (Total)" in out + assert "Issue Count (CDE)" in out + assert "Impact on Total Score" in out + assert "Impact on CDE Score" in out + # Both per-category issue counts must appear, not just one + assert "7" in out # total count + assert "3" in out # cde count + + +# --- Filter semantics passed to the model --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinitionCriteria") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_filters_passed_to_from_filters(mock_definition_cls, mock_criteria_cls, db_session_mock): + """User filters are handed to ScoreDefinitionCriteria.from_filters.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores( + project_code="demo", + filters=[ + {"field": "Business Domain", "value": "Finance"}, + {"field": "Business Domain", "value": "Marketing"}, + {"field": "Data Source", "value": "warehouse"}, + ], + ) + + # from_filters receives the translated DB column names — the parser + # converts user-facing labels to internal column names before this call. 
+ mock_criteria_cls.from_filters.assert_called_once() + args, kwargs = mock_criteria_cls.from_filters.call_args + passed = args[0] + assert {"field": "business_domain", "value": "Finance"} in passed + assert {"field": "business_domain", "value": "Marketing"} in passed + assert {"field": "data_source", "value": "warehouse"} in passed + assert kwargs.get("group_by_field") is True + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinitionCriteria") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +@patch("testgen.mcp.tools.common.TableGroup") +def test_table_group_adds_implicit_name_filter( + mock_tg_cls, mock_definition_cls, mock_criteria_cls, db_session_mock, +): + """When table_group_id is passed, the resolved TG's name is added as a filter.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg = MagicMock() + tg.id = uuid4() + tg.project_code = "demo" + tg.table_groups_name = "orders" + mock_tg_cls.get.return_value = tg + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores(table_group_id=str(tg.id)) + + args, _ = mock_criteria_cls.from_filters.call_args + passed = args[0] + assert {"field": "table_groups_name", "value": "orders"} in passed + + +# --- Cross-project loop --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cross_project_renders_per_project_sections(mock_definition_cls, db_session_mock): + """No project_code, no table_group_id → one H2 section per accessible project.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + # Pass at least one filter so the tool doesn't fall into the + # "enumerate every table group in the project" branch (which would need + # 
`TableGroup.select_minimal_where` mocked). + with _patch_perms(allowed=("proj_a", "proj_b")): + out = get_quality_scores( + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "proj_a" in out + assert "proj_b" in out + # `as_score_card` should have been called once per project. + assert mock_definition.as_score_card.call_count == 2 + + +@patch("testgen.mcp.tools.quality_scores.TableGroup") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_unfiltered_project_enumerates_table_groups(mock_definition_cls, mock_tg_cls, db_session_mock): + """Unfiltered project_code call enumerates table groups so as_score_card's + has_filters() gate passes (mirrors the score-explorer UI default).""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg1 = MagicMock() + tg1.table_groups_name = "orders" + tg2 = MagicMock() + tg2.table_groups_name = "customers" + mock_tg_cls.select_minimal_where.return_value = [tg1, tg2] + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores(project_code="demo") + + # Verify TableGroup.select_minimal_where was called for enumeration. 
+ mock_tg_cls.select_minimal_where.assert_called_once() + + +# --- Row cap --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_grouped_row_cap_truncates_and_footers(mock_definition_cls, db_session_mock): + """At >_ROW_CAP category rows, render only top N and surface the cap in a footer.""" + from testgen.mcp.tools.quality_scores import _ROW_CAP, get_quality_scores + + breakdown_rows = [ + {"business_domain": f"L{i}", "score": 0.5 + i * 0.001, "issue_ct": 1, "impact": 0.1} + for i in range(_ROW_CAP + 10) + ] + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = breakdown_rows + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Total", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert f"Showing top {_ROW_CAP}" in out + assert str(_ROW_CAP + 10) in out + + +# --- Empty-breakdown messaging differs based on whether filters were supplied --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_grouped_empty_breakdown_with_filters_renders_filter_matched(mock_definition_cls, db_session_mock): + """User-supplied filter that returns no breakdown rows surfaces 'Filter matched no data.'""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Filter matched no data" in out + assert "No category data" not in out + + +@patch("testgen.mcp.tools.quality_scores.TableGroup") 
+@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_grouped_empty_breakdown_without_filters_renders_no_category_data( + mock_definition_cls, mock_tg_cls, db_session_mock, +): + """Unfiltered project with no breakdown rows keeps the generic 'No category data.' message.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg = MagicMock() + tg.table_groups_name = "orders" + mock_tg_cls.select_minimal_where.return_value = [tg] + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + ) + + assert "No category data" in out + assert "Filter matched no data" not in out + + +# --- Transient definition is never persisted --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_transient_definition_never_persisted(mock_definition_cls, db_session_mock): + """Hardening test: the MCP tool never calls .save() on its transient definition.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + mock_definition.save.assert_not_called() + + +# ============================================================ +# Scorecard tools — merged in from test_tools_scorecards.py +# ============================================================ + +def _criteria(filters: list[dict], group_by_field: bool = True) -> ScoreDefinitionCriteria: + return ScoreDefinitionCriteria.from_filters(filters, group_by_field=group_by_field) + + + + +def _fake_definition( + name: str, + *, + 
project_code: str = "demo", + total: bool = True, + cde: bool = False, + category: ScoreCategory | None = None, + filters: list[dict] | None = None, + group_by_field: bool = True, + score: float | None = 0.95, + cde_value: float | None = 0.90, +) -> ScoreDefinition: + sd = ScoreDefinition() + sd.id = uuid4() + sd.project_code = project_code + sd.name = name + sd.total_score = total + sd.cde_score = cde + sd.category = category + sd.criteria = ScoreDefinitionCriteria.from_filters( + filters or [{"field": "table_groups_name", "value": "tg1"}], + group_by_field=group_by_field, + ) + sd._fake_card = {"score": score, "cde_score": cde_value} + return sd + + +@pytest.fixture +def patch_card(monkeypatch): + """Route as_cached_score_card to the stub stored on each fake definition.""" + def _cached(self, include_definition: bool = False): + return self._fake_card + monkeypatch.setattr(ScoreDefinition, "as_cached_score_card", _cached) + + +def _patch_list(items, total): + return patch.object(ScoreDefinition, "list_for_project", return_value=(items, total)) + + +# --- _format_criteria_summary --- + + +def test_format_criteria_summary_none(): + assert _format_criteria_summary(None) == "(no filters)" + + +def test_format_criteria_summary_empty(): + criteria = ScoreDefinitionCriteria(operand="AND", filters=[], group_by_field=True) + assert _format_criteria_summary(criteria) == "(no filters)" + + +def test_format_criteria_summary_single_filter_uses_display_label(): + criteria = _criteria([{"field": "table_groups_name", "value": "sales"}]) + assert _format_criteria_summary(criteria) == "Table Group = sales" + + +def test_format_criteria_summary_or_within_field(): + """group_by_field=True with multiple roots on the same field renders as `in (...)`.""" + criteria = _criteria([ + {"field": "table_groups_name", "value": "sales"}, + {"field": "table_groups_name", "value": "marketing"}, + ]) + assert _format_criteria_summary(criteria) == "Table Group in (sales, marketing)" + + +def 
test_format_criteria_summary_and_across_fields(): + criteria = _criteria([ + {"field": "table_groups_name", "value": "sales"}, + {"field": "business_domain", "value": "Finance"}, + ]) + # Ordering is alphabetical by display label for stable output. + assert _format_criteria_summary(criteria) == "Business Domain = Finance AND Table Group = sales" + + +def test_format_criteria_summary_chained_next_filter(): + """A root filter with `others` becomes a next_filter AND-chain inside the root.""" + criteria = ScoreDefinitionCriteria.from_filters( + [{ + "field": "table_groups_name", + "value": "sales", + "others": [{"field": "business_domain", "value": "Finance"}], + }], + group_by_field=True, + ) + summary = _format_criteria_summary(criteria) + assert "Table Group = sales" in summary + assert "Business Domain = Finance" in summary + assert " AND " in summary + + +def test_format_criteria_summary_unknown_field_falls_back_to_raw_column(): + criteria = _criteria([{"field": "made_up_column", "value": "x"}]) + assert _format_criteria_summary(criteria) == "made_up_column = x" + + +def test_format_criteria_summary_mode_2_chained_uses_table_label(): + """A chain into table_name renders the user-facing "Table" label, not the column name.""" + criteria = ScoreDefinitionCriteria.from_filters( + [{ + "field": "table_groups_name", + "value": "redbox", + "others": [{"field": "table_name", "value": "accounts"}], + }], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert "Table Group = redbox" in summary + assert "Table = accounts" in summary + assert "table_name" not in summary + + +def test_format_criteria_summary_mode_2_chained_uses_column_label(): + criteria = ScoreDefinitionCriteria.from_filters( + [{ + "field": "table_groups_name", + "value": "redbox", + "others": [ + {"field": "table_name", "value": "accounts"}, + {"field": "column_name", "value": "id"}, + ], + }], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert 
"Column = id" in summary + assert "column_name" not in summary + + +def test_format_criteria_summary_mode_2_sibling_chains_collapse_to_in(): + """Chains sharing the same root (table_groups_name=X) collapse to `Table in (...)`.""" + criteria = ScoreDefinitionCriteria.from_filters( + [ + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "a"}]}, + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "b"}]}, + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "c"}]}, + ], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert summary == "Table Group = redbox AND Table in (a, b, c)" + + +def test_format_criteria_summary_mode_2_different_roots_or_joined(): + """Chains with different table_groups_name roots are OR-joined (not AND-joined).""" + criteria = ScoreDefinitionCriteria.from_filters( + [ + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "a"}]}, + {"field": "table_groups_name", "value": "sales", + "others": [{"field": "table_name", "value": "b"}]}, + ], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert " OR " in summary + assert " AND " not in summary.replace(" AND Table = ", "") # AND only inside a chain + assert "redbox" in summary + assert "sales" in summary + + +# --- list_scorecards tool --- + + +def test_list_scorecards_requires_view_access(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(allowed=("only_this",)), pytest.raises( + MCPResourceNotAccessible, match="forbidden_proj" + ): + list_scorecards("forbidden_proj") + + +def test_list_scorecards_empty_renders_friendly_message(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(), _patch_list([], 0): + out = list_scorecards("demo") + assert 
"Scorecards in Project `demo`" in out + assert "_No scorecards configured._" in out + + +def test_list_scorecards_renders_total_and_cde(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [ + _fake_definition( + "Sales Quality", + total=True, + cde=True, + category=ScoreCategory.dq_dimension, + filters=[{"field": "table_groups_name", "value": "sales"}], + score=0.95, + cde_value=0.90, + ), + ] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Sales Quality" in out + assert "Total Score" in out + assert "CDE Score" in out + assert "Quality Dimension" in out # display label for dq_dimension + assert "Table Group = sales" in out + assert "0.95" in out or "95" in out + assert "0.90" in out or "90" in out + + +def test_list_scorecards_hides_cde_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition("Only Total", total=True, cde=False, cde_value=None)] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Total Score" in out + assert "CDE Score" not in out + + +def test_list_scorecards_hides_total_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition("CDE Only", total=False, cde=True, score=None, cde_value=0.50)] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "CDE Score" in out + assert "Total Score" not in out + + +def test_list_scorecards_includes_profiling_and_testing_when_total_enabled(db_session_mock, patch_card): + """When total_score is enabled, the per-scorecard block surfaces Profiling + Score and Testing Score — matching the UI's score-card and get_scorecard.""" + from testgen.mcp.tools.quality_scores import list_scorecards + + sd = _fake_definition( + "Full Card", + total=True, + cde=True, + score=0.925, + cde_value=0.880, + ) + 
sd._fake_card.update({"profiling_score": 0.950, "testing_score": 0.900}) + with _patch_perms(), _patch_list([sd], 1): + out = list_scorecards("demo") + assert "Profiling Score" in out + assert "Testing Score" in out + # friendly_score scales by 100 and rounds to 1 decimal. + assert "95" in out + assert "90" in out + + +def test_list_scorecards_omits_profiling_and_testing_for_cde_only_scorecard(db_session_mock, patch_card): + """When total_score is disabled, Profiling/Testing must not appear even + though as_cached_score_card may return values for them.""" + from testgen.mcp.tools.quality_scores import list_scorecards + + sd = _fake_definition("CDE Only", total=False, cde=True, score=None, cde_value=0.50) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_list([sd], 1): + out = list_scorecards("demo") + assert "CDE Score" in out + assert "Profiling Score" not in out + assert "Testing Score" not in out + + +def test_list_scorecards_omits_breakdown_when_no_category(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition("Plain", category=None)] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Category" not in out + + +def test_list_scorecards_emits_pagination_info_and_footer(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition(f"Card {i}") for i in range(3)] + with _patch_perms(), _patch_list(items, 25): + out = list_scorecards("demo", page=1, limit=3) + # format_page_info emits an en-dash (\u2013) between start and end. 
+ assert "Showing 1\u20133 of 25" in out + assert "Use `page=2` for more" in out + + +def test_list_scorecards_empty_page_past_end(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(), _patch_list([], 3): + out = list_scorecards("demo", page=5, limit=10) + # No-scorecards-on-page message references current page + total + assert "page 5" in out + assert "total: 3" in out + + +@pytest.mark.parametrize("page,limit", [(0, 10), (1, 0), (1, 101)]) +def test_list_scorecards_rejects_invalid_pagination(db_session_mock, patch_card, page, limit): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(), pytest.raises(MCPUserError): + list_scorecards("demo", page=page, limit=limit) + + +def test_list_scorecards_renders_filter_chain(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition( + "Multi-filter", + filters=[ + {"field": "table_groups_name", "value": "sales"}, + {"field": "business_domain", "value": "Finance"}, + ], + )] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Business Domain = Finance" in out + assert "Table Group = sales" in out + assert " AND " in out + + +# --- get_scorecard tool --- + + +def _fake_breakdown_item( + *, + category: str, + score_type: str, + field_values: dict, + impact: float = 0.5, + score: float = 0.85, + issue_ct: int = 3, +): + """Build a fake `ScoreDefinitionBreakdownItem`-like object exposing ``.to_dict()``. + + Matches the shape produced by the real ``to_dict`` — category-specific fields + plus ``impact``, ``score``, ``issue_ct``. 
+ """ + item = MagicMock(spec=ScoreDefinitionBreakdownItem) + item.category = category + item.score_type = score_type + item.to_dict = MagicMock(return_value={ + **field_values, + "impact": impact, + "score": score, + "issue_ct": issue_ct, + }) + return item + + +def _patch_get(definition): + return patch.object(ScoreDefinition, "get", return_value=definition) + + +def _patch_breakdown(items): + return patch.object(ScoreDefinitionBreakdownItem, "filter", return_value=items) + + +def _patch_breakdown_by_score_type(total, cde): + """Return different breakdown rows depending on the requested ``score_type``.""" + def _filter(*, definition_id, category, score_type): + return total if score_type == "score" else cde + return patch.object(ScoreDefinitionBreakdownItem, "filter", side_effect=_filter) + + +def test_get_scorecard_rejects_invalid_uuid(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + get_scorecard("not-a-uuid") + + +def test_get_scorecard_unknown_id_returns_not_accessible(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + missing_id = str(uuid4()) + with _patch_perms(), _patch_get(None), pytest.raises( + MCPResourceNotAccessible, match=missing_id + ): + get_scorecard(missing_id) + + +def test_get_scorecard_forbidden_project_returns_not_accessible(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition("Other-project card", project_code="forbidden_proj") + with _patch_perms(allowed=("demo",)), _patch_get(sd), pytest.raises( + MCPResourceNotAccessible, match=str(sd.id) + ): + get_scorecard(str(sd.id)) + + +def test_get_scorecard_renders_overall_scores(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Sales Quality", + total=True, + cde=True, + 
category=ScoreCategory.dq_dimension, + score=0.95, + cde_value=0.90, + ) + sd._fake_card.update({"profiling_score": 0.88, "testing_score": 0.91}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Sales Quality" in out + assert "Total Score" in out + assert "CDE Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out + # Filter summary is preserved from list_scorecards behavior. + assert "Table Group = tg1" in out + + +def test_get_scorecard_hides_total_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "CDE-Only Card", + total=False, + cde=True, + category=None, + score=None, + cde_value=0.5, + ) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "CDE Score" in out + assert "Total Score" not in out + # Profiling/Testing are components of the Total score — should be hidden too. 
+ assert "Profiling Score" not in out + assert "Testing Score" not in out + + +def test_get_scorecard_hides_cde_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Total-Only Card", + total=True, + cde=False, + category=None, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Total Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out + assert "CDE Score" not in out + + +def test_get_scorecard_omits_breakdown_when_no_category(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition("Plain", category=None) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Category" not in out + + +def test_get_scorecard_renders_breakdown_wide_table(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Wide breakdown", + total=True, + cde=True, + category=ScoreCategory.dq_dimension, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + + total_items = [ + _fake_breakdown_item( + category="dq_dimension", + score_type="score", + field_values={"dq_dimension": "Accuracy"}, + impact=0.4, + score=0.6, + issue_ct=10, + ), + ] + cde_items = [ + _fake_breakdown_item( + category="dq_dimension", + score_type="cde_score", + field_values={"dq_dimension": "Accuracy"}, + impact=0.3, + score=0.7, + issue_ct=5, + ), + ] + + with ( + _patch_perms(), + _patch_get(sd), + _patch_breakdown_by_score_type(total_items, cde_items), + ): + out = get_scorecard(str(sd.id)) + assert "Breakdown by Quality Dimension" in out + assert "Accuracy" in out + # Both score types in headers — 
parenthetical disambiguates which column is which. + assert "Issue Count (Total)" in out + assert "Issue Count (CDE)" in out + assert "Impact on Total Score" in out + assert "Impact on CDE Score" in out + + +def test_get_scorecard_breakdown_single_score_type(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Single-type breakdown", + total=True, + cde=False, + category=ScoreCategory.business_domain, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + + items = [ + _fake_breakdown_item( + category="business_domain", + score_type="score", + field_values={"business_domain": "Finance"}, + ), + ] + with _patch_perms(), _patch_get(sd), _patch_breakdown(items): + out = get_scorecard(str(sd.id)) + assert "Breakdown by Business Domain" in out + assert "Finance" in out + # When only one type is enabled, headers drop the parenthetical (mirrors get_quality_scores). + assert "Issue Count (Total)" not in out + assert "Issue Count (CDE)" not in out + + +def test_get_scorecard_breakdown_caps_at_100(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Many rows", + total=True, + cde=False, + category=ScoreCategory.business_domain, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + + items = [ + _fake_breakdown_item( + category="business_domain", + score_type="score", + field_values={"business_domain": f"Domain {i}"}, + impact=0.5 - 0.001 * i, + ) + for i in range(101) + ] + with _patch_perms(), _patch_get(sd), _patch_breakdown(items): + out = get_scorecard(str(sd.id)) + assert "Showing top 100 of 101" in out + + +def test_get_scorecard_breakdown_empty(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "No data", + total=True, + cde=False, + category=ScoreCategory.dq_dimension, + 
cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Breakdown by Quality Dimension" in out + assert "_No breakdown data._" in out + + +# --- delete_scorecard tool --- + + +def test_delete_scorecard_rejects_invalid_uuid(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + delete_scorecard("not-a-uuid") + + +def test_delete_scorecard_unknown_id_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + missing_id = str(uuid4()) + with _patch_perms(), _patch_get(None), pytest.raises( + MCPResourceNotAccessible, match=missing_id + ): + delete_scorecard(missing_id) + + +def test_delete_scorecard_forbidden_project_does_not_call_delete(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + sd = _fake_definition("Other-project card", project_code="forbidden_proj") + with ( + _patch_perms(allowed=("demo",)), + _patch_get(sd), + patch.object(ScoreDefinition, "delete") as mock_delete, + pytest.raises(MCPResourceNotAccessible, match=str(sd.id)), + ): + delete_scorecard(str(sd.id)) + assert mock_delete.called is False + + +def test_delete_scorecard_calls_model_delete(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + patch.object(ScoreDefinition, "delete") as mock_delete, + ): + delete_scorecard(str(sd.id)) + mock_delete.assert_called_once() + + +def test_delete_scorecard_returns_confirmation_with_name_id_project(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + sd = _fake_definition("Sales Quality", project_code="demo") + with _patch_perms(), _patch_get(sd), patch.object(ScoreDefinition, 
"delete"): + out = delete_scorecard(str(sd.id)) + assert "Sales Quality" in out + assert str(sd.id) in out + assert "demo" in out + + +# --- update_scorecard tool --- + + +def _patch_orchestrator(): + """Stub the persist+refresh orchestrator so unit tests don't hit the DB.""" + return patch("testgen.mcp.tools.quality_scores.save_and_refresh_score_definition") + + +def test_update_scorecard_rejects_invalid_uuid(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + update_scorecard("not-a-uuid", name="x") + + +def test_update_scorecard_unknown_id_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + missing_id = str(uuid4()) + with _patch_perms(), _patch_get(None), pytest.raises( + MCPResourceNotAccessible, match=missing_id + ): + update_scorecard(missing_id, name="x") + + +def test_update_scorecard_forbidden_project_does_not_call_save(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Other-project card", project_code="forbidden_proj") + with ( + _patch_perms(allowed=("demo",)), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPResourceNotAccessible, match=str(sd.id)), + ): + update_scorecard(str(sd.id), name="x") + mock_orch.assert_not_called() + + +def test_update_scorecard_no_fields_supplied_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="No fields supplied"), + ): + update_scorecard(str(sd.id)) + mock_orch.assert_not_called() + + +def test_update_scorecard_empty_name_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + 
_patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="name"), + ): + update_scorecard(str(sd.id), name="") + mock_orch.assert_not_called() + assert sd.name == "Sales Quality" + + +def test_update_scorecard_unknown_category_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="category"), + ): + update_scorecard(str(sd.id), category="not_a_category") + mock_orch.assert_not_called() + + +def test_update_scorecard_filter_without_field_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="field"), + ): + update_scorecard(str(sd.id), filters=[{"value": "x"}]) + mock_orch.assert_not_called() + + +def test_update_scorecard_empty_filters_list_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="filter"), + ): + update_scorecard(str(sd.id), filters=[]) + mock_orch.assert_not_called() + + +def test_update_scorecard_changes_name(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with _patch_perms(), _patch_get(sd), _patch_orchestrator() as mock_orch: + update_scorecard(str(sd.id), name="Renamed Card") + assert sd.name == "Renamed Card" + mock_orch.assert_called_once() + + +def test_update_scorecard_toggles_show_total_score(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", total=True) 
+ with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), show_total_score=False) + assert sd.total_score is False + + +def test_update_scorecard_toggles_show_cde_score(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", cde=False) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), show_cde_score=True) + assert sd.cde_score is True + + +def test_update_scorecard_sets_category(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", category=None) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), category="Quality Dimension") + assert sd.category == ScoreCategory.dq_dimension + + +def test_update_scorecard_clears_category(db_session_mock): + """Passing an empty ``category`` clears it — distinct from ``None`` (no change).""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", category=ScoreCategory.dq_dimension) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), category="") + assert sd.category is None + + +def test_update_scorecard_replaces_filters(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition( + "Sales Quality", + filters=[{"field": "table_groups_name", "value": "tg1"}], + ) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + new_filters = list(sd.criteria) + assert len(new_filters) == 1 + assert new_filters[0]["field"] == "business_domain" + assert new_filters[0]["value"] == "Finance" + + +def test_update_scorecard_flat_filters_derive_group_by_field_true(db_session_mock): + """Mode 1 shape (flat category filters) → group_by_field=True, 
regardless of prior state.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition( + "Sales Quality", + filters=[{ + "field": "table_groups_name", + "value": "sales", + "others": [{"field": "table_name", "value": "orders"}], + }], + group_by_field=False, + ) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + assert sd.criteria.group_by_field is True + + +def test_update_scorecard_chained_filters_derive_group_by_field_false(db_session_mock): + """Mode 2 shape (any chained filter) → group_by_field=False, regardless of prior state.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition( + "Sales Quality", + filters=[{"field": "table_groups_name", "value": "tg1"}], + group_by_field=True, + ) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[ + { + "field": "Table Group", + "value": "sales", + "others": [{"field": "Table", "value": "orders"}], + }, + { + "field": "Table Group", + "value": "sales", + "others": [{"field": "Table", "value": "customers"}], + }, + ], + ) + assert sd.criteria.group_by_field is False + + +def test_update_scorecard_mode_1_filter_with_non_category_field_rejected(db_session_mock): + """Flat filter using "Table" (chain-leaf field) must be rejected at the flat level.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Table"), + ): + update_scorecard( + str(sd.id), + filters=[{"field": "Table", "value": "orders"}], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_mode_2_chain_must_root_at_table_group(db_session_mock): + """Chained filters must start at "Table Group" (matches UI column-selector shape).""" + from 
testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Table Group"), + ): + update_scorecard( + str(sd.id), + filters=[{ + "field": "Data Source", + "value": "S", + "others": [{"field": "Table", "value": "x"}], + }], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_mode_2_chain_must_chain_into_table_or_column(db_session_mock): + """Chain leaves must be "Table" or "Column" — not category fields.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Business Domain"), + ): + update_scorecard( + str(sd.id), + filters=[{ + "field": "Table Group", + "value": "sales", + "others": [{"field": "Business Domain", "value": "Finance"}], + }], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_mode_2_chain_table_then_column_accepted(db_session_mock): + """A full chain "Table Group" → "Table" → "Column" is valid.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[{ + "field": "Table Group", + "value": "sales", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + ], + }], + ) + assert sd.criteria.group_by_field is False + roots = list(sd.criteria) + assert roots[0]["others"][0]["field"] == "table_name" + assert roots[0]["others"][1]["field"] == "column_name" + + +def test_update_scorecard_diff_uses_display_labels(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", total=True, category=None) + with _patch_perms(), _patch_get(sd), 
_patch_orchestrator(): + out = update_scorecard( + str(sd.id), + show_total_score=False, + category="Quality Dimension", + ) + assert "Total Score" in out + assert "Category" in out + assert "Quality Dimension" in out + # Internal names must not leak. + assert "total_score" not in out + assert "dq_dimension" not in out + + +def test_update_scorecard_diff_omits_unchanged_fields(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", total=True, cde=False, category=None) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + out = update_scorecard(str(sd.id), name="Renamed") + assert "Name" in out + assert "Total Score" not in out + assert "CDE Score" not in out + assert "Category" not in out + assert "Filters" not in out + + +def test_update_scorecard_response_includes_id_and_project(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", project_code="demo") + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + out = update_scorecard(str(sd.id), name="Renamed") + assert str(sd.id) in out + assert "demo" in out + + +def test_update_scorecard_calls_save_and_refresh_with_is_new_false(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with _patch_perms(), _patch_get(sd), _patch_orchestrator() as mock_orch: + update_scorecard(str(sd.id), name="Renamed") + mock_orch.assert_called_once() + args, kwargs = mock_orch.call_args + assert args[0] is sd + assert kwargs == {"is_new": False} + + +def test_update_scorecard_does_not_call_orchestrator_on_filter_validation_failure(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="field"), + ): + update_scorecard( + 
str(sd.id), + name="Renamed", + filters=[{"value": "x"}], + ) + mock_orch.assert_not_called() + # Name must not be mutated when a later validation step rejects the payload. + assert sd.name == "Sales Quality" + + +# --- create_scorecard --- + + +_VALID_FILTER = [{"field": "Table Group", "value": "tg1"}] + + +def test_create_scorecard_unknown_project_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(allowed=("demo",)), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPResourceNotAccessible, match="forbidden_proj"), + ): + create_scorecard("forbidden_proj", "My Card", filters=_VALID_FILTER) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_blank_name(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="name"), + ): + create_scorecard("demo", " ", filters=_VALID_FILTER) + mock_orch.assert_not_called() + + +def test_create_scorecard_requires_filters(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="filter"), + ): + create_scorecard("demo", "My Card", filters=[]) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_invalid_filter_field(db_session_mock): + """dq_dimension is a group_by field, not a flat scorecard filter field.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="dq_dimension"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "dq_dimension", "value": "Validity"}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_filter_value_with_forbidden_chars(db_session_mock): + """Persisted scorecard filters must 
reject SQL-injection chars — values flow + into raw SQL via ``ScoreDefinitionCriteria.get_as_sql``.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="forbidden"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "Table Group", "value": "tg1' OR '1'='1"}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_filter_value_too_long(db_session_mock): + """Persisted scorecard filter values must respect ``_VALUE_MAX_LEN``.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="too long"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "Table Group", "value": "x" * 300}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_chain_leaf_value_with_forbidden_chars(db_session_mock): + """Chain-leaf values (``others[].value``) also flow into raw SQL — same check.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="forbidden"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "t'; DROP TABLE--"}], + }], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_rejects_filter_value_with_forbidden_chars(db_session_mock): + """Update path mirrors create — persisted filter values must be safe.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="forbidden"), + ): + update_scorecard( + str(sd.id), + filters=[{"field": "Table Group", "value": 'tg1"'}], + ) + 
mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_invalid_category(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Invalid category"), + ): + create_scorecard( + "demo", + "My Card", + filters=_VALID_FILTER, + category="Not A Category", + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_persists_with_defaults(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with _patch_perms(), _patch_orchestrator() as mock_orch: + create_scorecard("demo", "My Card", filters=_VALID_FILTER) + + assert mock_orch.call_count == 1 + saved = mock_orch.call_args.args[0] + assert isinstance(saved, ScoreDefinition) + assert saved.project_code == "demo" + assert saved.name == "My Card" + assert saved.total_score is True + assert saved.cde_score is False + assert saved.category is None + assert saved.criteria.group_by_field is True + assert saved.criteria.filters[0].field == "table_groups_name" + assert saved.criteria.filters[0].value == "tg1" + assert mock_orch.call_args.kwargs == {"is_new": True} + + +def test_create_scorecard_persists_with_overrides(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with _patch_perms(), _patch_orchestrator() as mock_orch: + create_scorecard( + "demo", + "My Card", + filters=_VALID_FILTER, + category="Quality Dimension", + show_total_score=False, + show_cde_score=True, + ) + + saved = mock_orch.call_args.args[0] + assert saved.total_score is False + assert saved.cde_score is True + assert saved.category == ScoreCategory.dq_dimension + + +def test_create_scorecard_persists_mode_2_chained_filters(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + chained = [{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "accounts"}, + {"field": "Column", 
"value": "id"}, + ], + }] + with _patch_perms(), _patch_orchestrator() as mock_orch: + create_scorecard("demo", "My Card", filters=chained) + + saved = mock_orch.call_args.args[0] + assert saved.criteria.group_by_field is False + root = saved.criteria.filters[0] + assert root.field == "table_groups_name" + assert root.value == "tg1" + assert root.next_filter is not None + assert root.next_filter.field == "table_name" + assert root.next_filter.next_filter.field == "column_name" + + +def test_create_scorecard_returns_markdown_summary(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + new_id = uuid4() + + def _set_id(definition, *, is_new): + definition.id = new_id + return definition + + with _patch_perms(), _patch_orchestrator() as mock_orch: + mock_orch.side_effect = _set_id + out = create_scorecard( + "demo", + "Finance Card", + filters=_VALID_FILTER, + category="Quality Dimension", + ) + + assert "Finance Card" in out + assert "demo" in out + assert str(new_id) in out + # Display label uses "Category", not "Breakdown By". + assert "Category" in out + assert "Breakdown By" not in out + # Friendly category label, not internal column name. + assert "Quality Dimension" in out + assert "dq_dimension" not in out + # Filter summary appears. + assert "Filters" in out + + +# ============================================================ +# Exhaustive corner-case coverage for the unified _validate_filters +# Each numbered test maps 1-to-1 to a case in the plan's Task 2 enumeration. +# Calls into _validate_filters directly (no MCP wrapper) except +# where the contract requires going through the tool itself. +# ============================================================ + +from testgen.mcp.tools.common import ( + SCORE_FILTER_FIELD_TO_COLUMN, + ScoreFilterField, +) +from testgen.mcp.tools.quality_scores import _validate_filters + +# --- A. 
Shape / required-field rejections --- + + +def test_validate_filters_case_01_empty_list_rejected(): + # case 1 + with pytest.raises(MCPUserError, match=r"At least one filter is required\."): + _validate_filters([]) + + +def test_validate_filters_case_02_missing_field_key_rejected(): + # case 2 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"value": "tg1"}]) + + +def test_validate_filters_case_03_missing_value_key_rejected(): + # case 3 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "Table Group"}]) + + +def test_validate_filters_case_04_empty_string_field_rejected(): + # case 4 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "", "value": "tg1"}]) + + +def test_validate_filters_case_05_empty_string_value_rejected(): + # case 5 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "Table Group", "value": ""}]) + + +def test_validate_filters_case_06_none_field_rejected(): + # case 6 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": None, "value": "tg1"}]) + + +def test_validate_filters_case_07_none_value_rejected(): + # case 7 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "Table Group", "value": None}]) + + +def test_validate_filters_case_08_second_filter_malformed_indexed_at_1(): + # case 8 — index propagation through enumerate + with pytest.raises(MCPUserError, match=r"filters\[1\]"): + _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Table Group"}, + ]) + + +# --- B. 
SQL-injection value guard (flat path) --- + + +def test_validate_filters_case_09_value_with_single_quote_rejected(): + # case 9 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1' OR '1'='1"}]) + + +def test_validate_filters_case_10_value_with_double_quote_rejected(): + # case 10 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": 'tg1"'}]) + + +def test_validate_filters_case_11_value_with_semicolon_rejected(): + # case 11 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1; DROP"}]) + + +def test_validate_filters_case_12_value_with_backslash_rejected(): + # case 12 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1\\foo"}]) + + +def test_validate_filters_case_13_value_with_null_byte_rejected(): + # case 13 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1\x00"}]) + + +def test_validate_filters_case_14_value_length_257_rejected(): + # case 14 — boundary: 257 > 256 limit + with pytest.raises(MCPUserError, match="too long"): + _validate_filters([{"field": "Table Group", "value": "x" * 257}]) + + +def test_validate_filters_case_15_value_length_256_accepted(): + # case 15 — boundary: 256 == limit, accepted + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "x" * 256}] + ) + assert group_by_field is True + assert parsed[0]["field"] == "table_groups_name" + assert parsed[0]["value"] == "x" * 256 + + +@pytest.mark.parametrize( + "bad_value", + [123, [1, 2], {"k": "v"}, True], + ids=["case_16_int", "case_16_list", "case_16_dict", "case_16_bool"], +) +def test_validate_filters_case_16_value_non_string_rejected(bad_value): + # case 16 + with pytest.raises(MCPUserError, match="must be a string"): + _validate_filters([{"field": "Table 
Group", "value": bad_value}]) + + +# --- C. Mode 1 (flat, no others) — happy paths --- + + +def test_validate_filters_case_17_single_table_group_flat(): + # case 17 + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "tg1"}] + ) + assert group_by_field is True + assert parsed == [{"field": "table_groups_name", "value": "tg1"}] + + +def test_validate_filters_case_18_two_filters_same_field(): + # case 18 — same display field, two values + parsed, group_by_field = _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Table Group", "value": "tg2"}, + ]) + assert group_by_field is True + assert parsed == [ + {"field": "table_groups_name", "value": "tg1"}, + {"field": "table_groups_name", "value": "tg2"}, + ] + + +def test_validate_filters_case_19_two_filters_different_fields(): + # case 19 — different display fields + parsed, group_by_field = _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Data Source", "value": "Postgres"}, + ]) + assert group_by_field is True + assert parsed == [ + {"field": "table_groups_name", "value": "tg1"}, + {"field": "data_source", "value": "Postgres"}, + ] + + +@pytest.mark.parametrize( + "field_enum", + list(ScoreFilterField), + ids=[f"case_20_{f.name}" for f in ScoreFilterField], +) +def test_validate_filters_case_20_every_score_filter_field_accepted(field_enum): + # case 20 — parametrize over every ScoreFilterField; assert translation + parsed, group_by_field = _validate_filters( + [{"field": field_enum.value, "value": "val"}] + ) + assert group_by_field is True + assert parsed[0]["field"] == SCORE_FILTER_FIELD_TO_COLUMN[field_enum] + assert parsed[0]["value"] == "val" + + +# --- D. 
Mode 1 rejection paths --- + + +def test_validate_filters_case_21_column_form_field_rejected(): + # case 21 — column-form `data_source` must be rejected; error lists display values + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "data_source", "value": "Postgres"}]) + msg = exc_info.value.args[0] + assert "Data Source" in msg + # Column-form must NOT appear as a "valid" suggestion + assert "`data_source`" in msg # the rejected value is quoted back + + +def test_validate_filters_case_22_lowercase_quality_dimension_rejected(): + # case 22 — case-sensitive enum lookup + with pytest.raises(MCPUserError, match="quality dimension"): + _validate_filters([{"field": "quality dimension", "value": "Validity"}]) + + +def test_validate_filters_case_23_quality_dimension_rejected_as_filter_field(): + # case 23 — valid group_by, not a valid filter field + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "Quality Dimension", "value": "Validity"}]) + assert "Quality Dimension" in exc_info.value.args[0] + + +def test_validate_filters_case_24_impact_dimension_rejected_as_filter_field(): + # case 24 + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "Impact Dimension", "value": "High"}]) + assert "Impact Dimension" in exc_info.value.args[0] + + +def test_validate_filters_case_25_invalid_field_xyz_rejected(): + # case 25 — totally bogus field + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "xyz", "value": "v"}]) + msg = exc_info.value.args[0] + assert "xyz" in msg + # Error should list display-form values + assert "Table Group" in msg + + +def test_validate_filters_case_26_empty_others_list_still_mode_1(): + # case 26 — others=[] is falsy in any(...) 
+ parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "tg1", "others": []}] + ) + assert group_by_field is True + assert parsed[0]["field"] == "table_groups_name" + + +def test_validate_filters_case_27_none_others_still_mode_1(): + # case 27 — others=None is falsy in any(...) + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "tg1", "others": None}] + ) + assert group_by_field is True + assert parsed[0]["field"] == "table_groups_name" + + +# --- E. Mode 2 (chained) — happy paths --- + + +def test_validate_filters_case_28_single_chain_one_step_table(): + # case 28 — Table Group → Table + parsed, group_by_field = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }]) + assert group_by_field is False + assert parsed == [{ + "field": "table_groups_name", + "value": "tg1", + "others": [{"field": "table_name", "value": "orders"}], + }] + + +def test_validate_filters_case_29_single_chain_two_steps_table_column(): + # case 29 — Table Group → Table → Column + parsed, group_by_field = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + ], + }]) + assert group_by_field is False + assert parsed[0]["others"] == [ + {"field": "table_name", "value": "orders"}, + {"field": "column_name", "value": "id"}, + ] + + +def test_validate_filters_case_30_mode_2_with_sibling_flat_table_group(): + # case 30 — chain-having filter + bare Table Group (entire-tg case) + parsed, group_by_field = _validate_filters([ + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }, + {"field": "Table Group", "value": "tg2"}, + ]) + assert group_by_field is False + assert len(parsed) == 2 + assert parsed[1]["field"] == "table_groups_name" + assert parsed[1]["value"] == "tg2" + + +def 
test_validate_filters_case_31_multiple_chained_filters_same_shape(): + # case 31 — sibling OR semantics, both translated + parsed, group_by_field = _validate_filters([ + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }, + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "customers"}], + }, + ]) + assert group_by_field is False + assert len(parsed) == 2 + for filter_ in parsed: + assert filter_["field"] == "table_groups_name" + assert filter_["others"][0]["field"] == "table_name" + + +def test_validate_filters_case_32_chain_leaf_value_length_256_accepted(): + # case 32 — boundary at chain leaf + parsed, group_by_field = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "x" * 256}], + }]) + assert group_by_field is False + assert parsed[0]["others"][0]["value"] == "x" * 256 + + +# --- F. Mode 2 rejection paths --- + + +def test_validate_filters_case_33_root_not_table_group_with_others_rejected(): + # case 33 — has others but root is Data Source + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Data Source", + "value": "S", + "others": [{"field": "Table", "value": "x"}], + }]) + assert "Table Group" in exc_info.value.args[0] + + +def test_validate_filters_case_34_sibling_with_data_source_root_in_chain_mode_rejected(): + # case 34 — one filter chains; sibling has Data Source root (no chain) → reject + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([ + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }, + {"field": "Data Source", "value": "Postgres"}, + ]) + assert "Table Group" in exc_info.value.args[0] + + +def test_validate_filters_case_35_column_without_preceding_table_rejected(): + # case 35 + with pytest.raises(MCPUserError, match="`Column` chain requires a `Table` step"): + _validate_filters([{ + "field": "Table 
Group", + "value": "tg1", + "others": [{"field": "Column", "value": "id"}], + }]) + + +def test_validate_filters_case_36_chain_order_column_then_table_rejected(): + # case 36 + with pytest.raises(MCPUserError, match="`Column` must be the final chain step"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Column", "value": "id"}, + {"field": "Table", "value": "orders"}, + ], + }]) + + +def test_validate_filters_case_37_chain_leaf_column_form_table_name_rejected(): + # case 37 — column-form leaf `table_name` rejected (display-form only) + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "table_name", "value": "orders"}], + }]) + msg = exc_info.value.args[0] + assert "table_name" in msg # the rejected field is quoted in the error + # Valid leaves listed in display form + assert "Table" in msg + assert "Column" in msg + + +def test_validate_filters_case_38_invalid_chain_leaf_field_rejected(): + # case 38 + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "something_else", "value": "v"}], + }]) + msg = exc_info.value.args[0] + assert "something_else" in msg + assert "Table" in msg + assert "Column" in msg + + +def test_validate_filters_case_39_chain_leaf_missing_field_rejected(): + # case 39 — indexed filters[0].others[0] + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*field.*value"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"value": "orders"}], + }]) + + +def test_validate_filters_case_40_chain_leaf_missing_value_rejected(): + # case 40 — indexed filters[0].others[0] + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*field.*value"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table"}], + }]) + + +def 
test_validate_filters_case_41_chain_leaf_value_with_forbidden_char_rejected(): + # case 41 — indexed + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*forbidden"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "x'; DROP"}], + }]) + + +def test_validate_filters_case_42_chain_leaf_value_too_long_rejected(): + # case 42 — indexed + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*too long"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "x" * 300}], + }]) + + +def test_validate_filters_case_43_chain_leaf_value_non_string_rejected(): + # case 43 — indexed + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*must be a string"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": 123}], + }]) + + +def test_validate_filters_case_44_chain_with_extra_trailing_column_rejected(): + # case 44 — Table → Column → Column: second Column is in prefix, not the end + with pytest.raises(MCPUserError, match="`Column` must be the final chain step"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + {"field": "Column", "value": "name"}, + ], + }]) + + +# --- G. 
Translation correctness (output-shape) --- + + +def test_validate_filters_case_45_output_has_only_column_form_keys(): + # case 45 — every returned `field` is column-form + parsed, _ = _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Data Source", "value": "Postgres"}, + {"field": "Quality Dimension", "value": "ignored"}, # would fail; remove + ][:2]) # only the first two — Quality Dimension isn't a valid filter + column_form_field_values = set(SCORE_FILTER_FIELD_TO_COLUMN.values()) | { + "table_name", "column_name", + } + for filter_ in parsed: + assert filter_["field"] in column_form_field_values + + # Also test chain leaves + parsed_chain, _ = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + ], + }]) + for filter_ in parsed_chain: + assert filter_["field"] in column_form_field_values + for leaf in filter_.get("others", []): + assert leaf["field"] in column_form_field_values + + +def test_validate_filters_case_46_values_byte_identical_to_input(): + # case 46 — values are NEVER mutated by translation + raw = [{ + "field": "Table Group", + "value": "MixedCaseTG.with-dots_underscores", + "others": [ + {"field": "Table", "value": "Orders Table"}, + {"field": "Column", "value": "ID-Col"}, + ], + }] + parsed, _ = _validate_filters(raw) + assert parsed[0]["value"] == "MixedCaseTG.with-dots_underscores" + assert parsed[0]["others"][0]["value"] == "Orders Table" + assert parsed[0]["others"][1]["value"] == "ID-Col" + + +def test_validate_filters_case_47_group_by_field_flag_correct(): + # case 47 — True iff no filter has non-empty others + _, flag_flat = _validate_filters([{"field": "Table Group", "value": "tg1"}]) + assert flag_flat is True + _, flag_chained = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }]) + assert flag_chained is False + + +# --- H. 
Error-message hygiene (regression guards) --- + + +def test_validate_filters_case_48_flat_error_message_uses_display_form(): + # case 48 — error mentions valid flat fields: at least one display-form value; + # no underscore-form column names + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "xyz", "value": "v"}]) + msg = exc_info.value.args[0] + assert "Table Group" in msg # display-form present + # None of the column-form values should appear as a "valid" suggestion + # (the rejected `xyz` is fine; we're checking valid-values listing) + column_form_values = set(SCORE_FILTER_FIELD_TO_COLUMN.values()) + for col_value in column_form_values: + # Each column-form value (e.g. "table_groups_name", "data_source") must + # not appear in the listed-valid set. The rejected field name is also + # mentioned, but that's a user-supplied string, not "xyz" matching. + assert col_value not in msg, ( + f"Error message must not list column-form `{col_value}` as a valid value. " + f"Full message: {msg}" + ) + + +def test_validate_filters_case_49_chain_leaf_error_uses_display_form(): + # case 49 — leaf error mentions Table and Column, not table_name/column_name + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "something_else", "value": "v"}], + }]) + msg = exc_info.value.args[0] + assert "Table" in msg + assert "Column" in msg + # The column-form leaf names must not appear as "valid" leaves + assert "table_name" not in msg.replace("`Table`", "").replace("Table", "") # `Table` ok, `table_name` not + assert "column_name" not in msg.replace("`Column`", "").replace("Column", "") + + +# --- Wrapper-level: column-form rejected through create_scorecard --- + + +def test_create_scorecard_rejects_column_form_field_through_wrapper(db_session_mock): + # case 21 mirrored at the MCP wrapper level + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + 
_patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError) as exc_info, + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "data_source", "value": "Postgres"}], + ) + msg = exc_info.value.args[0] + assert "Data Source" in msg + mock_orch.assert_not_called() + + +# ============================================================ +# Unified validator: allow_empty + multi-error collection +# ============================================================ + + +def test_validate_filters_empty_default_rejected(): + """With the default allow_empty=False, an empty list raises.""" + with pytest.raises(MCPUserError, match=r"At least one filter is required\."): + _validate_filters([]) + + +def test_validate_filters_none_default_rejected(): + """With the default allow_empty=False, None raises.""" + with pytest.raises(MCPUserError, match=r"At least one filter is required\."): + _validate_filters(None) + + +def test_validate_filters_empty_allowed_returns_empty_tuple(): + """allow_empty=True short-circuits an empty list to ([], True).""" + parsed, group_by_field = _validate_filters([], allow_empty=True) + assert parsed == [] + assert group_by_field is True + + +def test_validate_filters_none_allowed_returns_empty_tuple(): + """allow_empty=True short-circuits None to ([], True).""" + parsed, group_by_field = _validate_filters(None, allow_empty=True) + assert parsed == [] + assert group_by_field is True + + +def test_validate_filters_collects_multiple_flat_errors(): + """Multi-error collection: every offending entry is named in the message.""" + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([ + {"field": "Quality Dimension", "value": "Accuracy"}, # not a filter field + {"field": "Business Domain", "value": "x';--"}, # forbidden chars + {"field": "Data Source", "value": ""}, # empty value + ]) + msg = exc_info.value.args[0] + assert "Quality Dimension" in msg + assert "Business Domain" in msg + assert "Data Source" in msg + + 
+def test_validate_filters_collects_multiple_chain_leaf_errors(): + """Chain-mode also collects per-leaf errors instead of stopping at the first.""" + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "bogus_leaf", "value": "x"}, # invalid leaf field + {"field": "Table", "value": "tbl';DROP"}, # forbidden char in valid leaf + ], + }]) + msg = exc_info.value.args[0] + assert "bogus_leaf" in msg + assert "forbidden" in msg + + +# ============================================================ +# get_quality_scores: mode-2 chained filter support +# ============================================================ + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinitionCriteria") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_get_quality_scores_accepts_mode_2_chained_filters( + mock_definition_cls, mock_criteria_cls, db_session_mock, +): + """A chained Table Group → Table filter reaches from_filters with group_by_field=False.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores( + project_code="demo", + filters=[{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }], + ) + + mock_criteria_cls.from_filters.assert_called_once() + args, kwargs = mock_criteria_cls.from_filters.call_args + passed = args[0] + assert kwargs.get("group_by_field") is False + assert passed[0]["field"] == "table_groups_name" + assert passed[0]["value"] == "tg1" + assert passed[0]["others"] == [{"field": "table_name", "value": "orders"}] + + +def test_get_quality_scores_rejects_table_group_id_with_chained_filters(db_session_mock): + """table_group_id + a mode-2 chain conflict — the implicit name filter would + shadow the chain 
root, so reject explicitly.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg = MagicMock() + tg.id = uuid4() + tg.project_code = "demo" + tg.table_groups_name = "orders_tg" + + with ( + _patch_perms(), + patch("testgen.mcp.tools.common.TableGroup") as mock_tg_cls, + pytest.raises(MCPUserError, match="chained filters"), + ): + mock_tg_cls.get.return_value = tg + get_quality_scores( + table_group_id=str(tg.id), + filters=[{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }], + ) diff --git a/tests/unit/mcp/test_tools_reference.py b/tests/unit/mcp/test_tools_reference.py index 55f6b508..96ead51f 100644 --- a/tests/unit/mcp/test_tools_reference.py +++ b/tests/unit/mcp/test_tools_reference.py @@ -176,3 +176,90 @@ def test_hygiene_issue_types_resource_empty(mock_type_cls, db_session_mock): result = hygiene_issue_types_resource() assert "No hygiene issue types found" in result + + +# --- column_profile_fields_resource --- + + +def test_column_profile_fields_resource_has_five_sections(): + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + assert "TestGen Column Profile Fields Reference" in result + assert "## All Column Types" in result + assert "## Alpha" in result + assert "## Numeric" in result + assert "## Date" in result + assert "## Boolean" in result + + +def test_column_profile_fields_resource_lists_all_pii_redacted_fields(): + """The footer must name every redactable field so the LLM can interpret `[PII Redacted]` markers.""" + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + # Friendly labels mirroring PROFILING_PII_FIELDS from testgen.common.pii_masking. 
+ expected_labels = ( + "Frequent Values", + "Minimum Text", + "Maximum Text", + "Minimum Value", + "Minimum Value > 0", + "Maximum Value", + "Minimum Date", + "Maximum Date", + ) + for label in expected_labels: + assert label in result, f"Expected `{label}` to be named in the redaction note" + + +def test_column_profile_fields_resource_describes_redaction_trigger(): + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + # The redaction trigger: column is PII-flagged AND caller lacks permission to view PII. + assert "PII" in result + assert "permission to view PII" in result + + +def test_column_profile_fields_resource_describes_per_type_fields(): + """Each section should at least mention the most distinctive field for that type.""" + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + # All-types section + assert "Row Count" in result + assert "Hygiene Issues" in result + # Alpha + assert "Minimum Length" in result + assert "Frequent Values" in result + assert "Standard Pattern Match" in result + # Numeric + assert "Minimum Value" in result + assert "Median Value" in result + # Datetime + assert "Minimum Date" in result + assert "Before 1 Year" in result + # Boolean + assert "## Boolean Columns" in result + + +# --- server instructions reference the new resource --- + + +def test_server_instructions_reference_column_profile_fields_resource(): + """The LLM relies on SERVER_INSTRUCTIONS to learn which resources to consult. + + The new resource must be named alongside test-types and hygiene-issue-types so + the LLM knows when to look up column-profile field semantics. + """ + from testgen.mcp.server import SERVER_INSTRUCTIONS + + assert "testgen://column-profile-fields" in SERVER_INSTRUCTIONS + # Sanity check the existing references are still present. 
+ assert "testgen://test-types" in SERVER_INSTRUCTIONS + assert "testgen://hygiene-issue-types" in SERVER_INSTRUCTIONS diff --git a/tests/unit/mcp/test_tools_schedules.py b/tests/unit/mcp/test_tools_schedules.py new file mode 100644 index 00000000..4278ca58 --- /dev/null +++ b/tests/unit/mcp/test_tools_schedules.py @@ -0,0 +1,358 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.enums import JobKey +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError + + +def _make_table_group(project_code="demo", name="orders_tg"): + tg = MagicMock() + tg.id = uuid4() + tg.project_code = project_code + tg.table_groups_name = name + return tg + + +def _make_suite(project_code="demo", name="suite_a", is_monitor=False): + suite = MagicMock() + suite.id = uuid4() + suite.project_code = project_code + suite.test_suite = name + suite.is_monitor = is_monitor + return suite + + +def _make_sched(*, key=None, active=True, project_code="demo", linked_id=None): + sched = MagicMock() + sched.id = uuid4() + sched.project_code = project_code + sched.key = key or JobKey.run_profile.value + sched.cron_expr = "0 3 * * *" + sched.cron_tz = "UTC" + sched.active = active + if sched.key == JobKey.run_profile.value: + sched.kwargs = {"table_group_id": linked_id or str(uuid4())} + else: + sched.kwargs = {"test_suite_id": linked_id or str(uuid4())} + sched.get_sample_triggering_timestamps.return_value = [datetime(2026, 5, 19, 3, 0)] + return sched + + +# --------------------------------------------------------------------------- +# create_profiling_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_happy_path(mock_resolve_tg, mock_sched_cls, db_session_mock): + tg = _make_table_group() + 
mock_resolve_tg.return_value = tg + saved = _make_sched(linked_id=str(tg.id)) + mock_sched_cls.return_value = saved + + from testgen.mcp.tools.schedules import create_profiling_schedule + + result = create_profiling_schedule( + table_group_id=str(tg.id), + cron_expression="0 3 * * *", + cron_tz="UTC", + ) + + assert "Profiling schedule created" in result + assert "orders_tg" in result + assert "`0 3 * * *`" in result + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_invalid_cron(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule( + table_group_id=str(uuid4()), + cron_expression="not a cron", + cron_tz="UTC", + ) + assert "Invalid cron" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_invalid_timezone(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule( + table_group_id=str(uuid4()), + cron_expression="0 3 * * *", + cron_tz="Not/A_Real_Timezone", + ) + assert "Invalid cron" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_empty_cron_rejected(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule(table_group_id=str(uuid4()), cron_expression="") + assert "cron_expression" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_empty_tz_rejected(mock_resolve_tg, 
db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule( + table_group_id=str(uuid4()), cron_expression="0 3 * * *", cron_tz="" + ) + assert "cron_tz" in str(exc.value) + + +# --------------------------------------------------------------------------- +# create_test_run_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_test_suite") +def test_create_test_run_schedule_happy_path(mock_resolve_suite, mock_sched_cls, db_session_mock): + suite = _make_suite() + mock_resolve_suite.return_value = suite + saved = _make_sched(key=JobKey.run_tests.value, linked_id=str(suite.id)) + mock_sched_cls.return_value = saved + + from testgen.mcp.tools.schedules import create_test_run_schedule + + result = create_test_run_schedule( + test_suite_id=str(suite.id), + cron_expression="0 6 * * 1", + cron_tz="UTC", + ) + + assert "Test run schedule created" in result + assert "suite_a" in result + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.schedules.resolve_test_suite") +def test_create_test_run_schedule_monitor_suite_rejected(mock_resolve_suite, db_session_mock): + mock_resolve_suite.side_effect = MCPResourceNotAccessible("Test suite", "abc") + + from testgen.mcp.tools.schedules import create_test_run_schedule + + with pytest.raises(MCPResourceNotAccessible): + create_test_run_schedule( + test_suite_id=str(uuid4()), + cron_expression="0 6 * * 1", + ) + + +# --------------------------------------------------------------------------- +# list_schedules +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.JobSchedule") +def 
test_list_schedules_basic(mock_sched_cls, mock_linked, db_session_mock): + sched_a = _make_sched(key=JobKey.run_profile.value) + sched_b = _make_sched(key=JobKey.run_tests.value) + mock_sched_cls.list_for_project.return_value = ([sched_a, sched_b], 2) + mock_linked.return_value = { + ("tg", sched_a.kwargs["table_group_id"]): "orders_tg", + ("suite", sched_b.kwargs["test_suite_id"]): "suite_a", + } + + from testgen.mcp.tools.schedules import list_schedules + + result = list_schedules(project_code="demo") + + assert "Schedules" in result + assert "Profiling Run" in result + assert "Test Run" in result + assert "orders_tg" in result + assert "suite_a" in result + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +def test_list_schedules_empty(mock_sched_cls, db_session_mock): + mock_sched_cls.list_for_project.return_value = ([], 0) + + from testgen.mcp.tools.schedules import list_schedules + + result = list_schedules(project_code="demo") + assert "No schedules" in result + + +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.JobSchedule") +def test_list_schedules_type_filter_maps_to_job_key(mock_sched_cls, mock_linked, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_sched_cls.list_for_project.return_value = ([sched], 1) + mock_linked.return_value = {} + + from testgen.mcp.tools.schedules import list_schedules + + list_schedules(project_code="demo", schedule_type="profiling_run") + + call_kwargs = mock_sched_cls.list_for_project.call_args + assert call_kwargs.kwargs["key_filter"] == [JobKey.run_profile.value] + + +def test_list_schedules_invalid_schedule_type(db_session_mock): + from testgen.mcp.tools.schedules import list_schedules + + with pytest.raises(MCPUserError) as exc: + list_schedules(project_code="demo", schedule_type="not-a-type") + assert "Invalid schedule_type" in str(exc.value) + + +def test_list_schedules_project_not_accessible(db_session_mock): + from 
testgen.mcp.tools.schedules import list_schedules + + with pytest.raises(MCPResourceNotAccessible): + list_schedules(project_code="other_project") + + +# --------------------------------------------------------------------------- +# get_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.get_current_session") +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_get_schedule_no_executions(mock_resolve, mock_linked, mock_session, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + mock_linked.return_value = {("tg", sched.kwargs["table_group_id"]): "orders_tg"} + session = MagicMock() + session.scalars.return_value.all.return_value = [] + mock_session.return_value = session + + from testgen.mcp.tools.schedules import get_schedule + + result = get_schedule(schedule_id=str(sched.id)) + assert "orders_tg" in result + assert "No runs yet" in result + + +@patch("testgen.mcp.tools.schedules.get_current_session") +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_get_schedule_with_executions(mock_resolve, mock_linked, mock_session, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + mock_linked.return_value = {("tg", sched.kwargs["table_group_id"]): "orders_tg"} + + je = MagicMock() + je.id = uuid4() + je.status = "Completed" + je.created_at = datetime(2026, 5, 18, 3, 0) + je.started_at = datetime(2026, 5, 18, 3, 0) + je.completed_at = datetime(2026, 5, 18, 3, 12) + session = MagicMock() + session.scalars.return_value.all.return_value = [je] + mock_session.return_value = session + + from testgen.mcp.tools.schedules import get_schedule + + result = get_schedule(schedule_id=str(sched.id)) + assert "Recent runs" in result + assert str(je.id) in 
result + + +# --------------------------------------------------------------------------- +# update_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_update_schedule_happy_path_diff(mock_resolve, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value, active=True) + mock_resolve.return_value = sched + + from testgen.mcp.tools.schedules import update_schedule + + result = update_schedule(schedule_id=str(sched.id), active=False) + + assert "Schedule updated" in result + assert "Active" in result and "Paused" in result + sched.save.assert_called_once() + + +def test_update_schedule_empty_payload_rejected(db_session_mock): + from testgen.mcp.tools.schedules import update_schedule + + with pytest.raises(MCPUserError) as exc: + update_schedule(schedule_id=str(uuid4())) + assert "No fields supplied" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_update_schedule_invalid_cron(mock_resolve, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + + from testgen.mcp.tools.schedules import update_schedule + + with pytest.raises(MCPUserError) as exc: + update_schedule(schedule_id=str(sched.id), cron_expression="garbage") + assert "Invalid cron" in str(exc.value) + sched.save.assert_not_called() + + +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_update_schedule_monitor_schedule_not_accessible(mock_resolve, db_session_mock): + """resolve_schedule filters out monitor schedules — caller sees the unified not-accessible error.""" + mock_resolve.side_effect = MCPResourceNotAccessible("Schedule", "abc") + + from testgen.mcp.tools.schedules import update_schedule + + with pytest.raises(MCPResourceNotAccessible): + update_schedule(schedule_id=str(uuid4()), active=False) + + +# --------------------------------------------------------------------------- +# 
delete_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_delete_schedule_happy_path(mock_resolve, mock_sched_cls, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + + from testgen.mcp.tools.schedules import delete_schedule + + result = delete_schedule(schedule_id=str(sched.id)) + assert "Schedule deleted" in result + mock_sched_cls.delete.assert_called_once_with(sched.id) + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_delete_schedule_monitor_schedule_not_accessible(mock_resolve, mock_sched_cls, db_session_mock): + """resolve_schedule filters out monitor schedules — caller sees the unified not-accessible error.""" + mock_resolve.side_effect = MCPResourceNotAccessible("Schedule", "abc") + + from testgen.mcp.tools.schedules import delete_schedule + + with pytest.raises(MCPResourceNotAccessible): + delete_schedule(schedule_id=str(uuid4())) + mock_sched_cls.delete.assert_not_called() diff --git a/tests/unit/mcp/test_tools_test_definitions.py b/tests/unit/mcp/test_tools_test_definitions.py index 5dea0d03..46488143 100644 --- a/tests/unit/mcp/test_tools_test_definitions.py +++ b/tests/unit/mcp/test_tools_test_definitions.py @@ -3,6 +3,7 @@ import pytest +from testgen.common.custom_test_validation import CustomQueryResult from testgen.mcp.exceptions import MCPUserError # -- list_tests --------------------------------------------------------------- @@ -414,9 +415,11 @@ def test_list_test_notes_basic(mock_td, mock_notes, db_session_mock): td.column_name = "name" mock_td.get_for_project.return_value = td + note_id_1 = str(uuid4()) + note_id_2 = str(uuid4()) mock_notes.get_notes.return_value = [ - {"detail": "Threshold looks wrong", "created_by": "alice", "created_at": "2026-04-01T10:00:00", 
"updated_at": None}, - {"detail": "Confirmed with team", "created_by": "bob", "created_at": "2026-04-02T14:30:00", "updated_at": "2026-04-03T09:00:00"}, + {"id": note_id_1, "detail": "Threshold looks wrong", "created_by": "alice", "created_at": "2026-04-01T10:00:00", "updated_at": None}, + {"id": note_id_2, "detail": "Confirmed with team", "created_by": "bob", "created_at": "2026-04-02T14:30:00", "updated_at": "2026-04-03T09:00:00"}, ] from testgen.mcp.tools.test_definitions import list_test_notes @@ -431,6 +434,9 @@ def test_list_test_notes_basic(mock_td, mock_notes, db_session_mock): assert "alice" in result assert "2026-04-01 10:00" in result assert "2026-04-03 09:00" in result + assert "Test note ID" in result + assert note_id_1 in result + assert note_id_2 in result @patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") @@ -467,6 +473,187 @@ def test_list_test_notes_invalid_uuid(db_session_mock): list_test_notes("garbage") +# -- create_test_note --------------------------------------------------------- + + +def _make_note_summary(): + """Minimal TestDefinitionSummary mock for note-tool rendering.""" + summary = MagicMock() + summary.display_name = "Alpha Truncation" + summary.table_name = "orders" + summary.column_name = "email" + return summary + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_create_test_note_happy_path( + mock_resolve_td, mock_td, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + td = MagicMock(id=uuid4()) + mock_resolve_td.return_value = td + + note_instance = MagicMock( + id=uuid4(), + detail="Threshold widened — confirmed with team", + created_at="2026-05-27T10:00:00", + ) + mock_note_model.add_note.return_value = note_instance + mock_td.get_for_project.return_value = _make_note_summary() + + from testgen.mcp.tools.test_definitions import 
create_test_note + + result = create_test_note(str(td.id), "Threshold widened — confirmed with team") + + assert "Note added" in result + assert "Alpha Truncation" in result + assert "`email`" in result + assert "`orders`" in result + assert "test_user" in result + assert str(note_instance.id) in result + mock_note_model.add_note.assert_called_once_with(td.id, "Threshold widened — confirmed with team", "test_user") + + +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_create_test_note_rejects_empty_body(mock_resolve_td, db_session_mock): + from testgen.mcp.tools.test_definitions import create_test_note + + with pytest.raises(MCPUserError, match="cannot be empty"): + create_test_note(str(uuid4()), "") + with pytest.raises(MCPUserError, match="cannot be empty"): + create_test_note(str(uuid4()), " \n\t ") + + mock_resolve_td.assert_not_called() + + +def test_create_test_note_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_definitions import create_test_note + + with pytest.raises(MCPUserError, match="not a valid UUID"): + create_test_note("garbage", "valid detail") + + +# -- update_test_note --------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_update_test_note_happy_path( + mock_resolve_note, mock_td, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock( + id=uuid4(), + test_definition_id=uuid4(), + created_by="test_user", + detail="original body", + ) + mock_resolve_note.return_value = note + mock_td.get_for_project.return_value = _make_note_summary() + + from testgen.mcp.tools.test_definitions import update_test_note + + result = update_test_note(str(note.id), "rewritten body") + + assert "Note updated" in result + assert "Alpha Truncation" in result + assert "original 
body" in result + assert "rewritten body" in result + mock_note_model.update_note.assert_called_once_with(note.id, "rewritten body") + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_update_test_note_non_author_rejected( + mock_resolve_note, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock(created_by="someone_else") + mock_resolve_note.return_value = note + + from testgen.mcp.tools.test_definitions import update_test_note + + with pytest.raises(MCPUserError, match="You can only edit notes you authored"): + update_test_note(str(uuid4()), "new body") + + mock_note_model.update_note.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_update_test_note_rejects_empty_body(mock_resolve_note, db_session_mock): + from testgen.mcp.tools.test_definitions import update_test_note + + with pytest.raises(MCPUserError, match="cannot be empty"): + update_test_note(str(uuid4()), "") + with pytest.raises(MCPUserError, match="cannot be empty"): + update_test_note(str(uuid4()), " ") + + mock_resolve_note.assert_not_called() + + +def test_update_test_note_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_definitions import update_test_note + + with pytest.raises(MCPUserError, match="not a valid UUID"): + update_test_note("garbage", "valid detail") + + +# -- delete_test_note --------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_delete_test_note_happy_path( + mock_resolve_note, mock_td, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock( + id=uuid4(), + test_definition_id=uuid4(), + created_by="test_user", + 
created_at="2026-05-27T10:00:00", + ) + mock_resolve_note.return_value = note + mock_td.get_for_project.return_value = _make_note_summary() + + from testgen.mcp.tools.test_definitions import delete_test_note + + result = delete_test_note(str(note.id)) + + assert "Note deleted" in result + assert "Alpha Truncation" in result + assert "test_user" in result + mock_note_model.delete_note.assert_called_once_with(note.id) + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_delete_test_note_non_author_rejected( + mock_resolve_note, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock(created_by="someone_else") + mock_resolve_note.return_value = note + + from testgen.mcp.tools.test_definitions import delete_test_note + + with pytest.raises(MCPUserError, match="You can only delete notes you authored"): + delete_test_note(str(uuid4())) + + mock_note_model.delete_note.assert_not_called() + + +def test_delete_test_note_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_definitions import delete_test_note + + with pytest.raises(MCPUserError, match="not a valid UUID"): + delete_test_note("garbage") + + # -- list_test_types ---------------------------------------------------------- @@ -545,3 +732,584 @@ def test_list_test_types_filter_description(mock_tt, db_session_mock): assert "scope: table" in result assert "dimension: Completeness" in result + + +# -- create_test -------------------------------------------------------------- + + +def _make_suite(suite_id=None, table_groups_id=None): + suite = MagicMock() + suite.id = suite_id or uuid4() + suite.test_suite = "demo_suite" + suite.project_code = "demo" + suite.table_groups_id = table_groups_id or uuid4() + return suite + + +def _make_test_type( + code="Alpha_Trunc", + short_name="Alpha Truncation", + scope="column", + param_columns=None, + default_parm_columns="threshold_value", + 
default_parm_required=None, + default_severity="Fail", +): + tt = MagicMock() + tt.test_type = code + tt.test_name_short = short_name + tt.test_scope = scope + tt.param_columns = param_columns if param_columns is not None else {"threshold_value"} + tt.default_parm_columns = default_parm_columns + tt.default_parm_required = default_parm_required + tt.default_severity = default_severity + return tt + + +def _make_table_group(schema="public"): + tg = MagicMock() + tg.id = uuid4() + tg.table_group_schema = schema + return tg + + +def _make_td_summary(table_name="orders", column_name="email", severity="Warning"): + """Mock TestDefinitionSummary as returned by TestDefinition.get_for_project().""" + summary = MagicMock() + summary.id = uuid4() + summary.display_name = "Alpha Truncation" + summary.test_type = "Alpha_Trunc" + summary.test_name_short = "Alpha Truncation" + summary.table_name = table_name + summary.column_name = column_name + summary.schema_name = "demo" + summary.test_scope = "column" + summary.test_suite_id = uuid4() + summary.impact_dimension = None + summary.default_impact_dimension = "Conformance" + summary.dq_dimension = "Validity" + summary.test_active = True + summary.severity = severity + summary.default_severity = "Fail" + summary.lock_refresh = False + summary.export_to_observability = True + summary.flagged = False + summary.last_auto_gen_date = None + summary.last_manual_update = None + summary.default_parm_columns = "threshold_value" + summary.param_columns = {"threshold_value"} + summary.param_fields = [("threshold_value", "Maximum String Length at Baseline", "")] + summary.threshold_value = "64" + summary.custom_query = None + summary.match_schema_name = None + summary.match_table_name = None + summary.match_column_names = None + summary.match_subset_condition = None + summary.match_groupby_names = None + summary.match_having_condition = None + return summary + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") 
+@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_happy_path( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, mock_notes, db_session_mock, +): + suite = _make_suite() + mock_resolve_suite.return_value = suite + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + saved = MagicMock() + saved.id = uuid4() + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + mock_td.return_value = saved + mock_td.get_for_project.return_value = _make_td_summary() + mock_notes.get_notes.return_value = [] + + from testgen.mcp.tools.test_definitions import create_test + + result = create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + fields={"column_name": "email", "threshold_value": "64", "severity": "Warning"}, + ) + + # New shared body: entity-first heading + "Created in suite" lead-in + assert "Created" in result + assert "Alpha Truncation on `email` in `orders`" in result + # Parameters table uses the test type's prompt, not a hardcoded label + assert "Maximum String Length at Baseline" in result + assert "64" in result + assert "Warning" in result + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_column_scope_requires_column_name( + 
mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, db_session_mock, +): + """Column-scoped types: missing column_name → validate() raises before save.""" + from testgen.common.models.test_definition import InvalidTestDefinitionFields + + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + saved.validate.side_effect = InvalidTestDefinitionFields( + {"column_name": "required for test type `Alpha_Trunc`"} + ) + mock_td.return_value = saved + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + fields={"threshold_value": "64"}, + ) + assert "column_name" in str(exc_info.value) + assert "rejected" in str(exc_info.value).lower() + saved.save.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_unknown_field_rejected_by_whitelist( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, db_session_mock, +): + """Unknown field in ``fields`` (e.g. 
custom_query on Alpha_Trunc) is rejected by editable_fields whitelist.""" + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + mock_td.return_value = saved + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + fields={"column_name": "email", "threshold_value": "64", "custom_query": "SELECT 1"}, + ) + assert "custom_query" in str(exc_info.value) + assert "not editable" in str(exc_info.value) + saved.save.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_fields_dict_supports_test_type_params( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, mock_notes, db_session_mock, +): + """``fields`` accepts any param in editable_fields — e.g. 
window_days for a trend test.""" + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Some_Trend" + mock_tt_model.get.return_value = _make_test_type( + code="Some_Trend", + param_columns={"threshold_value", "window_days"}, + default_parm_columns="threshold_value,window_days", + ) + mock_tg.get.return_value = _make_table_group() + + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "window_days", "column_name", + } + mock_td.return_value = saved + mock_td.get_for_project.return_value = _make_td_summary() + mock_notes.get_notes.return_value = [] + + from testgen.mcp.tools.test_definitions import create_test + + create_test( + test_suite_id=str(uuid4()), + test_type="Some Trend", + table_name="orders", + fields={"column_name": "amount", "threshold_value": "10", "window_days": "7"}, + ) + + # Both common and type-specific fields applied via setattr + assert saved.threshold_value == "10" + assert saved.window_days == "7" + saved.validate.assert_called_once() + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_severity_invalid( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, db_session_mock, +): + """severity outside the StrEnum → validate() raises.""" + from testgen.common.models.test_definition import InvalidTestDefinitionFields + + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value 
= { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + saved.validate.side_effect = InvalidTestDefinitionFields( + {"severity": "must be `Fail` or `Warning` (got `critical`)"} + ) + mock_td.return_value = saved + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + fields={"column_name": "email", "threshold_value": "64", "severity": "critical"}, + ) + assert "severity" in str(exc_info.value) + saved.save.assert_not_called() + + +# -- update_test -------------------------------------------------------------- + + +def _make_td_orm(test_type="Alpha_Trunc", threshold_value="64", severity="Warning"): + td = MagicMock() + td.id = uuid4() + td.test_type = test_type + td.threshold_value = threshold_value + td.severity = severity + td.test_active = True + td.lock_refresh = False + td.flagged = False + # Mirror TestDefinition.editable_fields(tt) for an Alpha_Trunc-shaped test type + td.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", + } + return td + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_happy_path(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + result = update_test(str(td.id), fields={"threshold_value": "80"}) + + assert "updated" in result.lower() + assert "threshold_value" in result + assert "80" in result + assert td.threshold_value == "80" + td.save.assert_called_once() + + +@patch("testgen.mcp.tools.test_definitions.TestType") 
+@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_empty_fields_rejected(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + with pytest.raises(MCPUserError): + update_test(str(td.id), fields={}) + td.save.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_unknown_field_rejected_no_partial(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + with pytest.raises(MCPUserError) as exc_info: + # threshold_value is valid, table_name is not — must reject ALL + update_test(str(td.id), fields={"threshold_value": "80", "table_name": "new"}) + assert "table_name" in str(exc_info.value) + # td.threshold_value should NOT have been mutated + assert td.threshold_value == "64" + td.save.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_multi_field(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + result = update_test( + str(td.id), + fields={"threshold_value": "80", "severity": "Fail", "test_active": False}, + ) + assert "3 field" in result + td.save.assert_called_once() + + +# -- validate_custom_test ----------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") 
+@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_would_pass_when_no_rows( + mock_resolve_suite, mock_conn, mock_tg, mock_validate, db_session_mock, +): + + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "snowflake" + conn.sql_flavor = "snowflake" + mock_conn.get_by_table_group.return_value = conn + mock_tg.get.return_value = _make_table_group() + mock_validate.return_value = CustomQueryResult(row_count=0, preview_rows=[]) + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT 1 WHERE 1=0") + + assert "ran successfully" in result.lower() + assert "would pass" in result.lower() + assert "0 rows matching the failure criteria" in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_would_fail_shows_preview_with_view_pii( + mock_resolve_suite, mock_conn, mock_tg, mock_validate, mock_compute, db_session_mock, +): + from testgen.mcp.permissions import ProjectPermissions + + # Grant view_pii on "demo" so values are visible in the preview. 
+ perms = MagicMock(spec=ProjectPermissions) + perms.allowed_codes = ["demo"] + perms.codes_allowed_to.return_value = ["demo"] + perms.has_access.side_effect = lambda code: code == "demo" + mock_compute.return_value = perms + + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "snowflake" + conn.sql_flavor = "snowflake" + mock_conn.get_by_table_group.return_value = conn + mock_tg.get.return_value = _make_table_group() + + row = MagicMock() + row.keys.return_value = ["order_id", "amount"] + row.__getitem__.side_effect = lambda k: {"order_id": "ORD-123", "amount": "-45.99"}[k] + mock_validate.return_value = CustomQueryResult(row_count=3, preview_rows=[row]) + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT * FROM orders WHERE amount < 0") + + assert "would fail" in result.lower() + assert "3 row(s) matching the failure criteria" in result + assert "order_id" in result + assert "ORD-123" in result + assert "[redacted]" not in result + + +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_redacts_when_no_view_pii( + mock_resolve_suite, mock_conn, mock_tg, mock_validate, db_session_mock, +): + + # Default fixture user has role_a with edit but not view_pii. 
+ mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "snowflake" + conn.sql_flavor = "snowflake" + mock_conn.get_by_table_group.return_value = conn + mock_tg.get.return_value = _make_table_group() + + row = MagicMock() + row.keys.return_value = ["order_id", "customer_email"] + row.__getitem__.side_effect = lambda k: {"order_id": "ORD-123", "customer_email": "jane@example.com"}[k] + mock_validate.return_value = CustomQueryResult(row_count=1, preview_rows=[row]) + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT * FROM orders") + + # Column names always visible + assert "order_id" in result + assert "customer_email" in result + # Values redacted; PII footer mentions permissions (no `view_pii` jargon) + assert "[redacted]" in result + assert "jane@example.com" not in result + assert "ORD-123" not in result + assert "permissions to view PII" in result + assert "view_pii" not in result + + +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_sql_error_surfaced( + mock_resolve_suite, mock_conn, mock_tg, mock_validate, db_session_mock, +): + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "postgresql" + conn.sql_flavor = "postgresql" + mock_conn.get_by_table_group.return_value = conn + mock_tg.get.return_value = _make_table_group() + mock_validate.side_effect = Exception('syntax error at or near "FROMM"') + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT * FROMM orders") + + assert "did not execute" in result.lower() + assert "syntax 
error" in result + + +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_missing_connection(mock_resolve_suite, mock_conn, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + mock_conn.get_by_table_group.return_value = None + + from testgen.mcp.tools.test_definitions import validate_custom_test + + with pytest.raises(MCPUserError, match="No connection"): + validate_custom_test(str(uuid4()), "SELECT 1") + + +# -- bulk_update_tests -------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_disable_no_filter(mock_resolve_suite, mock_session, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + result_mock = MagicMock() + result_mock.rowcount = 3 + mock_session.return_value.execute.return_value = result_mock + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + result = bulk_update_tests(test_suite_id=str(uuid4()), action="disable") + + assert "Disabled" in result + assert "3 test" in result + assert "no filter" in result + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_enable_with_table_filter( + mock_resolve_suite, mock_resolve_tt, mock_session, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + result_mock = MagicMock() + result_mock.rowcount = 1 + mock_session.return_value.execute.return_value = result_mock + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + result = bulk_update_tests( + test_suite_id=str(uuid4()), action="enable", table_name="legacy_orders" + ) + + assert "Enabled" in result + assert "legacy_orders" in result + 
mock_resolve_tt.assert_not_called() # not called when test_type filter absent + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_invalid_action(mock_resolve_suite, mock_session, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + with pytest.raises(MCPUserError, match="`action`"): + bulk_update_tests(test_suite_id=str(uuid4()), action="toggle") + + # `action` is expected to be validated before the test suite is resolved, + # so an invalid action must be rejected without resolve_test_suite ever being called. + mock_resolve_suite.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_no_match(mock_resolve_suite, mock_session, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + result_mock = MagicMock() + result_mock.rowcount = 0 + mock_session.return_value.execute.return_value = result_mock + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + result = bulk_update_tests(test_suite_id=str(uuid4()), action="disable", table_name="nonexistent") + + assert "No tests matched" in result + assert "nonexistent" in result diff --git a/tests/unit/mcp/test_tools_test_results.py b/tests/unit/mcp/test_tools_test_results.py index cadcb86c..7b3d08e4 100644 --- a/tests/unit/mcp/test_tools_test_results.py +++ b/tests/unit/mcp/test_tools_test_results.py @@ -4,6 +4,7 @@ import pytest +from testgen.common.enums import JobStatus from testgen.common.models.test_result import TestResultStatus from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions @@ -487,7 +488,7 @@ def test_get_failure_summary_passes_project_codes(
@patch("testgen.mcp.tools.test_results.TestType") @patch("testgen.mcp.tools.test_results.TestResult") -def test_get_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock): +def test_list_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock): def_id = str(uuid4()) r1 = MagicMock() r1.test_type = "Unique_Pct" @@ -512,9 +513,9 @@ def test_get_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock tt.test_name_short = "Unique Percent" mock_tt_cls.select_where.return_value = [tt] - from testgen.mcp.tools.test_results import get_test_result_history + from testgen.mcp.tools.test_results import list_test_result_history - result = get_test_result_history(def_id) + result = list_test_result_history(def_id) assert "Unique Percent" in result assert "Unique_Pct" not in result @@ -526,26 +527,26 @@ def test_get_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock @patch("testgen.mcp.tools.test_results.TestResult") -def test_get_test_result_history_empty(mock_result, db_session_mock): +def test_list_test_result_history_empty(mock_result, db_session_mock): mock_result.select_history.return_value = [] - from testgen.mcp.tools.test_results import get_test_result_history + from testgen.mcp.tools.test_results import list_test_result_history - result = get_test_result_history(str(uuid4())) + result = list_test_result_history(str(uuid4())) assert "No historical results" in result -def test_get_test_result_history_invalid_uuid(db_session_mock): - from testgen.mcp.tools.test_results import get_test_result_history +def test_list_test_result_history_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_results import list_test_result_history with pytest.raises(MCPUserError, match="not a valid UUID"): - get_test_result_history("bad-uuid") + list_test_result_history("bad-uuid") @patch("testgen.mcp.tools.test_results.TestResult") @patch("testgen.mcp.permissions._compute_project_permissions") -def 
test_get_test_result_history_passes_project_codes( +def test_list_test_result_history_passes_project_codes( mock_compute, mock_result, db_session_mock, ): mock_compute.return_value = ProjectPermissions( @@ -555,9 +556,9 @@ def test_get_test_result_history_passes_project_codes( ) mock_result.select_history.return_value = [] - from testgen.mcp.tools.test_results import get_test_result_history + from testgen.mcp.tools.test_results import list_test_result_history - get_test_result_history(str(uuid4())) + list_test_result_history(str(uuid4())) call_kwargs = mock_result.select_history.call_args.kwargs assert call_kwargs["project_codes"] == ["proj_a"] @@ -866,33 +867,54 @@ def test_get_failure_trend_exclude_today_shifts_end_date(mock_compute, mock_fail # ---------------------------------------------------------------------- -# get_test_run_diff +# compare_test_runs # ---------------------------------------------------------------------- -def _mock_diff_row(status_a, status_b, **overrides): +def _mock_diff_row(status_baseline, status_target, **overrides): row = MagicMock() row.test_definition_id = uuid4() row.test_type = "Pattern_Match" row.test_name_short = "Pattern Match" row.table_name = "orders" row.column_names = "customer_id" - row.status_a = status_a - row.status_b = status_b - row.measure_a = "5" - row.measure_b = "12" - row.threshold_a = "0" - row.threshold_b = "0" + row.status_baseline = status_baseline + row.status_target = status_target + row.measure_baseline = "5" + row.measure_target = "12" + row.threshold_baseline = "0" + row.threshold_target = "0" for k, v in overrides.items(): setattr(row, k, v) return row +def _mock_run(suite_id, je_id=None): + run = MagicMock(id=uuid4(), test_suite_id=suite_id) + run.job_execution_id = je_id or uuid4() + return run + + +def _je(status=JobStatus.COMPLETED): + """Build a JobExecution mock for ``session.get(JobExecution, ...)`` returns.""" + je = MagicMock() + je.status = status + return je + + +def 
_patch_test_results_session(jes): + """Patch ``get_current_session`` in test_results so ``session.get(JobExecution, ...)`` + returns the given JEs in order (one per ``_require_completed`` call).""" + session = MagicMock() + session.get.side_effect = jes + return patch("testgen.mcp.tools.test_results.get_current_session", return_value=session) + + @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestResult") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_happy_path( +def test_compare_test_runs_happy_path( mock_compute, mock_test_run_cls, mock_result, mock_test_suite_cls, db_session_mock, ): mock_compute.return_value = ProjectPermissions( @@ -901,18 +923,21 @@ def test_get_test_run_diff_happy_path( username="test_user", ) suite_id = uuid4() - run_a = MagicMock(id=uuid4(), test_suite_id=suite_id) - run_b = MagicMock(id=uuid4(), test_suite_id=suite_id) - mock_test_run_cls.get_by_id_or_job.side_effect = [run_a, run_b] - mock_test_suite_cls.id = MagicMock() # support .in_(...) on attribute mock - mock_test_suite_cls.select_where.return_value = [MagicMock(id=suite_id, project_code="proj_a", is_monitor=False)] + baseline_run = _mock_run(suite_id) + target_run = _mock_run(suite_id) + # Tool resolves target first, then baseline. 
+ mock_test_run_cls.get_by_id_or_job.side_effect = [target_run, baseline_run] + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") diff = MagicMock() - diff.total_a = 100 - diff.total_b = 100 + diff.total_baseline = 100 + diff.total_target = 100 diff.regressions = [ _mock_diff_row( - TestResultStatus.Passed, TestResultStatus.Failed, threshold_a="1", threshold_b="3", + TestResultStatus.Passed, + TestResultStatus.Failed, + threshold_baseline="1", + threshold_target="3", ) ] diff.improvements = [] @@ -921,69 +946,154 @@ def test_get_test_run_diff_happy_path( diff.removed_tests = [] mock_result.diff_with_details.return_value = diff - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - out = get_test_run_diff(str(uuid4()), str(uuid4())) + with _patch_test_results_session([_je(), _je()]): + out = compare_test_runs(str(uuid4()), str(uuid4())) - assert "Test Run Diff" in out + assert "Test Run Comparison" in out assert "Regressions" in out assert "Pattern Match" in out assert "Passed → Failed" in out - assert "Threshold A" in out and "Threshold B" in out + assert "Threshold Baseline" in out and "Threshold Target" in out assert "| 1 | 3 |" in out # threshold columns populated when thresholds changed + # diff_with_details called with (baseline_run.id, target_run.id) in that order. 
+ mock_result.diff_with_details.assert_called_once_with(baseline_run.id, target_run.id) @patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestResult") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_run_not_found( +def test_compare_test_runs_single_arg_resolves_previous( + mock_compute, mock_test_run_cls, mock_result, mock_test_suite_cls, db_session_mock, +): + """Only target supplied — baseline is resolved via target_run.get_previous().""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + baseline_run = _mock_run(suite_id) + target_run.get_previous.return_value = baseline_run + mock_test_run_cls.get_by_id_or_job.return_value = target_run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") + + diff = MagicMock( + total_baseline=10, total_target=10, + regressions=[], improvements=[], persistent_failures=[], new_tests=[], removed_tests=[], + ) + mock_result.diff_with_details.return_value = diff + + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je()]): + out = compare_test_runs(str(uuid4())) + + target_run.get_previous.assert_called_once_with() + mock_result.diff_with_details.assert_called_once_with(baseline_run.id, target_run.id) + # Rendered Baseline cell shows the resolved JE ID, not an input string. 
+ assert str(baseline_run.job_execution_id) in out + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_single_arg_no_previous_raises( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): - """One run missing, other accessible — unified error without leaking which side failed.""" + """Target is the oldest run — get_previous() returns None — clear user-facing error.""" mock_compute.return_value = ProjectPermissions( memberships={"proj_a": "role_a"}, permission="view", username="test_user", ) suite_id = uuid4() - mock_test_run_cls.get_by_id_or_job.side_effect = [None, MagicMock(id=uuid4(), test_suite_id=suite_id)] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [MagicMock(id=suite_id, project_code="proj_a", is_monitor=False)] + target_run = _mock_run(suite_id) + target_run.get_previous.return_value = None + mock_test_run_cls.get_by_id_or_job.return_value = target_run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je()]), pytest.raises(MCPUserError, match="no earlier completed test run"): + compare_test_runs(str(uuid4())) + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_single_arg_inaccessible_target( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Inaccessible target — error raised before get_previous() is consulted.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id 
= uuid4() + target_run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = target_run + # Monitor suite or inaccessible project — get_regular returns None either way. + mock_test_suite_cls.get_regular.return_value = None + + from testgen.mcp.tools.test_results import compare_test_runs with pytest.raises(MCPResourceNotAccessible, match="Test run .* not found or not accessible"): - get_test_run_diff(str(uuid4()), str(uuid4())) + compare_test_runs(str(uuid4())) + target_run.get_previous.assert_not_called() @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_rejects_inaccessible_project( +def test_compare_test_runs_run_not_found( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): - """Runs in an inaccessible project produce the same unified message, not a separate one.""" + """Target not found — unified not-found-or-inaccessible error.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + mock_test_run_cls.get_by_id_or_job.return_value = None + + from testgen.mcp.tools.test_results import compare_test_runs + + with pytest.raises(MCPResourceNotAccessible, match="Test run .* not found or not accessible"): + compare_test_runs(str(uuid4()), str(uuid4())) + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_rejects_inaccessible_project( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Runs in an inaccessible project produce the unified message.""" mock_compute.return_value = ProjectPermissions( memberships={"proj_a": "role_a"}, permission="view", username="test_user", ) suite_id = uuid4() - run = MagicMock(id=uuid4(), 
test_suite_id=suite_id) - mock_test_run_cls.get_by_id_or_job.side_effect = [run, run] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [MagicMock(id=suite_id, project_code="proj_forbidden", is_monitor=False)] + run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_forbidden") - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - with pytest.raises(MCPUserError, match="not found or not accessible"): - get_test_run_diff(str(uuid4()), str(uuid4())) + with pytest.raises(MCPResourceNotAccessible, match="not found or not accessible"): + compare_test_runs(str(uuid4()), str(uuid4())) @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_rejects_different_suites( +def test_compare_test_runs_rejects_different_suites( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): """Both runs accessible but in different suites → suite-mismatch error.""" @@ -992,51 +1102,97 @@ def test_get_test_run_diff_rejects_different_suites( permission="view", username="test_user", ) - suite_id_a = uuid4() - suite_id_b = uuid4() - run_a = MagicMock(id=uuid4(), test_suite_id=suite_id_a) - run_b = MagicMock(id=uuid4(), test_suite_id=suite_id_b) - mock_test_run_cls.get_by_id_or_job.side_effect = [run_a, run_b] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [ - MagicMock(id=suite_id_a, project_code="proj_a", is_monitor=False), - MagicMock(id=suite_id_b, project_code="proj_a", is_monitor=False), + suite_id_target = uuid4() + suite_id_baseline = uuid4() + target_run = _mock_run(suite_id_target) + baseline_run = _mock_run(suite_id_baseline) + 
mock_test_run_cls.get_by_id_or_job.side_effect = [target_run, baseline_run] + mock_test_suite_cls.get_regular.side_effect = [ + _mock_test_suite(suite_id=suite_id_target, project_code="proj_a"), + _mock_test_suite(suite_id=suite_id_baseline, project_code="proj_a"), ] - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - with pytest.raises(MCPUserError, match="must belong to the same test suite"): - get_test_run_diff(str(uuid4()), str(uuid4())) + with _patch_test_results_session([_je()]), pytest.raises(MCPUserError, match="must belong to the same test suite"): + compare_test_runs(str(uuid4()), str(uuid4())) -def test_get_test_run_diff_invalid_uuid(db_session_mock): - from testgen.mcp.tools.test_results import get_test_run_diff +def test_compare_test_runs_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_results import compare_test_runs with pytest.raises(MCPUserError, match="not a valid UUID"): - get_test_run_diff("bad-uuid", str(uuid4())) + compare_test_runs("bad-uuid", str(uuid4())) @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_rejects_monitor_suite( +def test_compare_test_runs_rejects_monitor_suite( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): - """Monitor suites are hidden from this tool, same as inaccessible projects — unified message.""" + """Monitor suites are hidden — TestSuite.get_regular returns None — unified message.""" mock_compute.return_value = ProjectPermissions( memberships={"proj_a": "role_a"}, permission="view", username="test_user", ) suite_id = uuid4() - run = MagicMock(id=uuid4(), test_suite_id=suite_id) - mock_test_run_cls.get_by_id_or_job.side_effect = [run, run] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [ - MagicMock(id=suite_id, 
project_code="proj_a", is_monitor=True) - ] + run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = run + mock_test_suite_cls.get_regular.return_value = None - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - with pytest.raises(MCPUserError, match="not found or not accessible"): - get_test_run_diff(str(uuid4()), str(uuid4())) + with pytest.raises(MCPResourceNotAccessible, match="not found or not accessible"): + compare_test_runs(str(uuid4()), str(uuid4())) + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_rejects_target_not_completed( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Target run still Running — comparison rejected before any diff work.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = target_run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") + + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je(status=JobStatus.RUNNING)]), \ + pytest.raises(MCPUserError, match=r"Target run is in `Running` state"): + compare_test_runs(str(uuid4())) + target_run.get_previous.assert_not_called() + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_rejects_baseline_not_completed( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Two-arg path: target completes the check but baseline is in Error state.""" + 
mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + baseline_run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.side_effect = [target_run, baseline_run] + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") + + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je(), _je(status=JobStatus.ERROR)]), \ + pytest.raises(MCPUserError, match=r"Baseline run is in `Error` state"): + compare_test_runs(str(uuid4()), str(uuid4())) diff --git a/tests/unit/mcp/test_tools_test_runs.py b/tests/unit/mcp/test_tools_test_runs.py index c914dd25..7a71380e 100644 --- a/tests/unit/mcp/test_tools_test_runs.py +++ b/tests/unit/mcp/test_tools_test_runs.py @@ -1,188 +1,305 @@ +from datetime import UTC, datetime from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from testgen.mcp.exceptions import MCPPermissionDenied +from testgen.common.enums import JobStatus +from testgen.mcp.exceptions import MCPPermissionDenied, MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions +_CREATED = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC) +_STARTED = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC) +_COMPLETED = datetime(2024, 1, 15, 10, 5, 0, tzinfo=UTC) + def _make_run_summary(**overrides): defaults = { "test_run_id": uuid4(), "job_execution_id": uuid4(), - "test_suite": "Quality Suite", "project_name": "Demo", - "table_groups_name": "core_tables", "status": "completed", + "test_suite": "Quality Suite", "project_name": "Demo", "project_code": "demo", + "table_groups_name": "core_tables", "status": JobStatus.COMPLETED, "status_label": "Completed", - "created_at": "2024-01-15T10:00:00", - "started_at": "2024-01-15T10:00:00", "completed_at": "2024-01-15T10:05:00", + "created_at": _CREATED, 
"started_at": _STARTED, "completed_at": _COMPLETED, "test_ct": 50, "passed_ct": 45, "failed_ct": 3, "warning_ct": 2, "error_ct": 0, "log_ct": 0, "dismissed_ct": 0, "dq_score_testing": 92.5, + "error_message": None, } defaults.update(overrides) return MagicMock(**defaults) +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_default_limit(mock_suite, mock_run, db_session_mock): - """Default limit=1 returns one run per suite.""" - runs = [_make_run_summary(test_run_id=uuid4()) for _ in range(7)] +def test_list_test_runs_default(mock_suite, mock_run, mock_next, db_session_mock): + runs = [_make_run_summary() for _ in range(3)] mock_run.select_summary.return_value = (runs, len(runs)) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + result = list_test_runs(project_code="demo") - # All 7 runs have test_suite="Quality Suite", so only 1 should appear - assert "1 run(s)" in result + mock_run.select_summary.assert_called_once_with( + project_code="demo", + table_group_id=None, + test_suite_id=None, + statuses=None, + page=1, + page_size=10, + ) + assert "Test runs" in result + assert "demo" in result assert "Quality Suite" in result assert "92.5" in result - mock_run.select_summary.assert_called_once_with(project_code="demo", test_suite_id=None, page_size=1000) - - -@patch("testgen.mcp.tools.test_runs.TestRun") -@patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_custom_limit(mock_suite, mock_run, db_session_mock): - """Custom limit returns up to N runs per suite.""" - runs = [_make_run_summary() for _ in range(3)] - mock_run.select_summary.return_value = (runs, len(runs)) - - from testgen.mcp.tools.test_runs import get_recent_test_runs - - result = get_recent_test_runs("demo", limit=10) - - 
assert "3 run(s)" in result +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_per_suite_grouping(mock_suite, mock_run, db_session_mock): - """With multiple suites, returns limit runs per suite.""" - runs = [ - _make_run_summary(test_suite="Suite A", test_run_id=uuid4()), - _make_run_summary(test_suite="Suite A", test_run_id=uuid4()), - _make_run_summary(test_suite="Suite B", test_run_id=uuid4()), - _make_run_summary(test_suite="Suite B", test_run_id=uuid4()), - ] - mock_run.select_summary.return_value = (runs, len(runs)) +def test_list_test_runs_with_status_filter(mock_suite, mock_run, mock_next, db_session_mock): + mock_run.select_summary.return_value = ([], 0) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + list_test_runs(project_code="demo", status="Pending") - # limit=1 (default), so 1 per suite = 2 total - assert "2 run(s)" in result - assert "Suite A" in result - assert "Suite B" in result + call_kwargs = mock_run.select_summary.call_args.kwargs + assert call_kwargs["statuses"] == [JobStatus.PENDING, JobStatus.CLAIMED] +@patch("testgen.mcp.tools.test_runs.JobExecution") +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_with_suite_name(mock_suite, mock_run, db_session_mock): +def test_list_test_runs_with_suite_name(mock_suite, mock_run, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] suite_id = uuid4() - suite_minimal = MagicMock() - suite_minimal.id = suite_id + suite_minimal = MagicMock(id=suite_id) mock_suite.select_minimal_where.return_value = [suite_minimal] mock_run.select_summary.return_value = 
([_make_run_summary(test_suite="My Suite")], 1) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo", test_suite="My Suite") + result = list_test_runs(project_code="demo", test_suite="My Suite") - mock_run.select_summary.assert_called_once_with(project_code="demo", test_suite_id=str(suite_id), page_size=1000) + call_kwargs = mock_run.select_summary.call_args.kwargs + assert call_kwargs["test_suite_id"] == str(suite_id) assert "My Suite" in result +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_suite_not_found(mock_suite, mock_run, db_session_mock): +def test_list_test_runs_suite_not_found(mock_suite, mock_run, mock_next, db_session_mock): mock_suite.select_minimal_where.return_value = [] - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo", test_suite="Nonexistent") - - assert "not found" in result + with pytest.raises(MCPResourceNotAccessible): + list_test_runs(project_code="demo", test_suite="Nonexistent") mock_run.select_summary.assert_not_called() +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_no_runs(mock_suite, mock_run, db_session_mock): +def test_list_test_runs_empty(mock_suite, mock_run, mock_next, db_session_mock): mock_run.select_summary.return_value = ([], 0) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + result = list_test_runs(project_code="demo") - assert "No completed test runs" in result + assert "No test runs" in result 
+@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_shows_failure_counts(mock_suite, mock_run, db_session_mock): - mock_run.select_summary.return_value = ([_make_run_summary(failed_ct=5, warning_ct=2)], 1) +def test_list_test_runs_includes_pending_run(mock_suite, mock_run, mock_next, db_session_mock): + pending = _make_run_summary( + status=JobStatus.PENDING, status_label="Pending", + started_at=None, completed_at=None, + test_ct=None, passed_ct=None, failed_ct=None, warning_ct=None, error_ct=None, + log_ct=None, dismissed_ct=None, dq_score_testing=None, + ) + mock_run.select_summary.return_value = ([pending], 1) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + result = list_test_runs(project_code="demo") - assert "5 failed" in result - assert "2 warnings" in result + assert "Pending" in result + assert "In progress" in result +@patch("testgen.mcp.tools.test_runs.JobExecution") +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value="2026-06-01T02:00:00") @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_outputs_job_execution_id(mock_suite, mock_run, db_session_mock): - """Output should contain job_execution_id, not test_run_id.""" - job_exec_id = uuid4() - run = _make_run_summary(job_execution_id=job_exec_id) - mock_run.select_summary.return_value = ([run], 1) +def test_list_test_runs_shows_next_scheduled(mock_suite, mock_run, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + suite_id = uuid4() + mock_suite.select_minimal_where.return_value = [MagicMock(id=suite_id)] + mock_run.select_summary.return_value = ([], 0) + + from testgen.mcp.tools.test_runs import list_test_runs + + 
result = list_test_runs(project_code="demo", test_suite="Quality") + + assert "Next scheduled run" in result + + +@patch("testgen.mcp.tools.test_runs.JobExecution") +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.test_runs.TestRun") +@patch("testgen.mcp.tools.test_runs.TestSuite") +def test_list_test_runs_renders_pending_section( + mock_suite, mock_run, mock_next, mock_je, db_session_mock, +): + """When scoped by suite, pending JEs are surfaced in a separate section.""" + suite_id = uuid4() + mock_suite.select_minimal_where.return_value = [MagicMock(id=suite_id)] + mock_run.select_summary.return_value = ([], 0) + pending_je = MagicMock( + id=uuid4(), status=JobStatus.PENDING, + created_at=_CREATED, started_at=None, completed_at=None, + ) + mock_je.select_active_by_kwargs.return_value = [pending_je] + + from testgen.mcp.tools.test_runs import list_test_runs + + result = list_test_runs(project_code="demo", test_suite="Quality") - from testgen.mcp.tools.test_runs import get_recent_test_runs + assert "Pending (1)" in result + assert "In progress" in result + mock_je.select_active_by_kwargs.assert_called_once() - result = get_recent_test_runs("demo") - assert str(job_exec_id) in result - assert "job_execution_id" in result +def test_list_test_runs_invalid_status(db_session_mock): + from testgen.mcp.tools.test_runs import list_test_runs + with pytest.raises(MCPUserError, match="Invalid status"): + list_test_runs(project_code="demo", status="Bogus") -def test_get_recent_test_runs_empty_project_code(db_session_mock): - from testgen.mcp.tools.test_runs import get_recent_test_runs - result = get_recent_test_runs("") +def test_list_test_runs_requires_project_or_table_group(db_session_mock): + from testgen.mcp.tools.test_runs import list_test_runs - assert "Missing required parameter" in result - assert "project_code" in result + with pytest.raises(MCPUserError, match="Provide either"): + list_test_runs() 
@patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_recent_test_runs_raises_not_found_for_inaccessible_project( - mock_compute, db_session_mock, -): +def test_list_test_runs_raises_not_found_for_inaccessible_project(mock_compute, db_session_mock): mock_compute.return_value = ProjectPermissions( memberships={"other_project": "role_a"}, permission="view", username="test_user", ) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - with pytest.raises(MCPPermissionDenied, match="No completed test runs found in project `secret_project`"): - get_recent_test_runs("secret_project") + with pytest.raises(MCPPermissionDenied): + list_test_runs(project_code="secret_project") -@patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_recent_test_runs_raises_denial_for_insufficient_permission( - mock_compute, db_session_mock, -): - mock_compute.return_value = ProjectPermissions( - memberships={"other_project": "role_a", "secret_project": "role_c"}, - permission="view", - username="test_user", +# ---------------------------------------------------------------------- +# get_test_run +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_returns_detail(mock_run, db_session_mock): + summary = _make_run_summary(project_code="demo") + mock_run.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch( + "testgen.mcp.permissions.PluginHook" + ) as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + result = get_test_run(str(summary.job_execution_id)) + + assert "Quality 
Suite" in result + assert "Completed" in result + assert "92.5" in result + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_pending_no_results(mock_run, db_session_mock): + summary = _make_run_summary( + project_code="demo", + status=JobStatus.PENDING, status_label="Pending", + started_at=None, completed_at=None, + test_ct=None, passed_ct=None, failed_ct=None, warning_ct=None, error_ct=None, + log_ct=None, dismissed_ct=None, dq_score_testing=None, ) + mock_run.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + result = get_test_run(str(summary.job_execution_id)) + + assert "Pending" in result + assert "In progress" in result + assert "Results" not in result + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_not_found(mock_run, db_session_mock): + mock_run.select_summary.return_value = ([], 0) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + with pytest.raises(MCPResourceNotAccessible): + get_test_run(str(uuid4())) + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_inaccessible_project(mock_run, db_session_mock): + summary = _make_run_summary(project_code="secret") + mock_run.select_summary.return_value = ([summary], 
1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + with pytest.raises(MCPResourceNotAccessible): + get_test_run(str(summary.job_execution_id)) + - from testgen.mcp.tools.test_runs import get_recent_test_runs +def test_get_test_run_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_runs import get_test_run - with pytest.raises(MCPPermissionDenied, match="necessary permission"): - get_recent_test_runs("secret_project") + with pytest.raises(MCPUserError, match="not a valid UUID"): + get_test_run("not-a-uuid") diff --git a/tests/unit/mcp/test_transport_security.py b/tests/unit/mcp/test_transport_security.py new file mode 100644 index 00000000..22b101ec --- /dev/null +++ b/tests/unit/mcp/test_transport_security.py @@ -0,0 +1,86 @@ +"""Tests for testgen.mcp.server._build_transport_security — DNS rebinding allowlist builder.""" + +from unittest.mock import patch + +from testgen.mcp.server import _build_transport_security + + +def _build_with(base_url: str, extras: list[str] | None = None): + with ( + patch("testgen.mcp.server.settings.BASE_URL", base_url), + patch("testgen.mcp.server.settings.MCP_EXTRA_ALLOWED_HOSTS", extras or []), + ): + return _build_transport_security() + + +def test_loopback_and_base_url_always_present(): + """With no extras, the allowlist is BASE_URL hosts + loopback variants.""" + settings = _build_with("http://tg.example.com:8530") + + assert settings.enable_dns_rebinding_protection is True + assert "tg.example.com:8530" in settings.allowed_hosts + assert "tg.example.com:*" in settings.allowed_hosts + assert "127.0.0.1:*" in settings.allowed_hosts + assert 
"localhost:*" in settings.allowed_hosts + assert "[::1]:*" in settings.allowed_hosts + + assert "http://tg.example.com:8530" in settings.allowed_origins + # Loopback origins covered for both schemes + assert "http://localhost:*" in settings.allowed_origins + assert "https://localhost:*" in settings.allowed_origins + + +def test_extra_host_without_port_gets_wildcard(): + """An extras entry without `:` gets `:*` automatically appended.""" + settings = _build_with("http://localhost:8530", extras=["tg.example.com"]) + + assert "tg.example.com:*" in settings.allowed_hosts + assert "tg.example.com" not in settings.allowed_hosts # bare entry should NOT be present + assert "http://tg.example.com:*" in settings.allowed_origins + assert "https://tg.example.com:*" in settings.allowed_origins + + +def test_extra_host_with_explicit_port_preserved_literally(): + """An extras entry with an explicit port is kept as-is, no wildcard appended.""" + settings = _build_with("http://localhost:8530", extras=["tg.example.com:8080"]) + + assert "tg.example.com:8080" in settings.allowed_hosts + assert "tg.example.com:8080:*" not in settings.allowed_hosts # no double-port + + assert "http://tg.example.com:8080" in settings.allowed_origins + assert "https://tg.example.com:8080" in settings.allowed_origins + + +def test_extra_host_with_explicit_wildcard_preserved(): + """An extras entry with `:*` is kept as-is.""" + settings = _build_with("http://localhost:8530", extras=["tg.example.com:*"]) + + assert "tg.example.com:*" in settings.allowed_hosts + assert "http://tg.example.com:*" in settings.allowed_origins + + +def test_mixed_extras(): + """Multiple extras with different shapes are all handled correctly.""" + settings = _build_with( + "http://localhost:8530", + extras=["foo.com", "bar.io:9000", "baz.net:*"], + ) + + assert "foo.com:*" in settings.allowed_hosts + assert "bar.io:9000" in settings.allowed_hosts + assert "baz.net:*" in settings.allowed_hosts + + +def 
test_https_base_url_origin_uses_https_scheme(): + """Origin scheme tracks BASE_URL's scheme.""" + settings = _build_with("https://tg.example.com") + + assert "https://tg.example.com" in settings.allowed_origins + + +def test_results_are_sorted_lists(): + """allowed_hosts and allowed_origins are deterministic (sorted) for stable diffs.""" + settings = _build_with("http://localhost:8530", extras=["zeta.com", "alpha.com"]) + + assert settings.allowed_hosts == sorted(settings.allowed_hosts) + assert settings.allowed_origins == sorted(settings.allowed_origins) diff --git a/tests/unit/scheduler/test_scheduler_cli.py b/tests/unit/scheduler/test_scheduler_cli.py index ce2acec3..ed8f1556 100644 --- a/tests/unit/scheduler/test_scheduler_cli.py +++ b/tests/unit/scheduler/test_scheduler_cli.py @@ -7,6 +7,8 @@ import pytest +from testgen.commands.job_registry import JobConfig +from testgen.common.enums import JobKey, JobSource from testgen.common.models.job_execution import JobExecution from testgen.common.models.scheduler import JobSchedule from testgen.scheduler.base import DelayedPolicy @@ -42,19 +44,18 @@ def popen_mock(popen_proc_mock): @pytest.fixture def db_jobs(scheduler_instance): with ( - patch("testgen.scheduler.cli_scheduler.JobSchedule.select_where") as mock, + patch("testgen.scheduler.cli_scheduler.JobSchedule.select_runnable") as mock, ): yield mock @pytest.fixture def job_data(): - with patch.dict("testgen.commands.job_registry.JOB_DISPATCH", {"test-job": Mock()}): + with patch.dict("testgen.commands.job_registry.JOB_DISPATCH", {"test-job": JobConfig(handler=Mock())}): yield { "cron_expr": "*/5 9-17 * * *", "cron_tz": "UTC", "key": "test-job", - "args": ["a"], "kwargs": {"b": "c"}, } @@ -76,7 +77,7 @@ def test_get_jobs(scheduler_instance, db_jobs, job_sched): assert len(jobs) == 1 assert isinstance(jobs[0], CliJob) - for attr in ("cron_expr", "cron_tz", "key", "args", "kwargs"): + for attr in ("cron_expr", "cron_tz", "key", "kwargs"): assert getattr(jobs[0], 
attr) == getattr(job_sched, attr), f"Attribute '{attr}' does not match" @@ -96,6 +97,32 @@ def test_job_start(scheduler_instance, cli_job): mock_session.commit.assert_called_once() +def test_job_start_tags_source_from_job_config(scheduler_instance, job_data): + """Scheduled executions inherit `JobConfig.scheduler_source` as their + `JobExecution.source`. Retention cleanup (registered with + `scheduler_source="system"`) gets `source="system"` so MCP / REST / + api/deps filters auto-hide it from user-facing surfaces.""" + from testgen.scheduler.cli_scheduler import CliJob + + system_job_data = {**job_data, "key": JobKey.run_data_cleanup} + system_cli_job = CliJob(**system_job_data, delayed_policy=DelayedPolicy.SKIP) + + mock_session = MagicMock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=False) + with ( + patch.dict( + "testgen.commands.job_registry.JOB_DISPATCH", + {JobKey.run_data_cleanup: JobConfig(handler=Mock(), scheduler_source=JobSource.system)}, + ), + patch("testgen.common.models.Session", return_value=mock_session), + ): + scheduler_instance.start_job(system_cli_job, datetime.now(UTC)) + + added = mock_session.add.call_args[0][0] + assert added.source == "system" + + @pytest.mark.parametrize("proc_exit_code", [0, 1]) def test_proc_wrapper_status(proc_exit_code, scheduler_instance): mock_session = MagicMock() diff --git a/tests/unit/scheduler/test_scheduler_poll.py b/tests/unit/scheduler/test_scheduler_poll.py index 20b3bfeb..39082942 100644 --- a/tests/unit/scheduler/test_scheduler_poll.py +++ b/tests/unit/scheduler/test_scheduler_poll.py @@ -4,8 +4,9 @@ import pytest -from testgen.commands.job_registry import JOB_DISPATCH -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.commands.job_registry import JOB_DISPATCH, JobConfig +from testgen.common.enums import JobKey, JobStatus +from testgen.common.models.job_execution import JobExecution from 
testgen.scheduler.cli_scheduler import CliScheduler pytestmark = pytest.mark.unit @@ -35,8 +36,7 @@ def scheduler_instance(): def job_exec(): return JobExecution( id=uuid4(), - job_key="run-tests", - args=[], + job_key=JobKey.run_tests, kwargs={"test_suite_id": "suite-123"}, source="scheduler", status="claimed", @@ -56,7 +56,7 @@ def test_dispatch_spawns_process(scheduler_instance, job_exec, mock_session): proc_mock = MagicMock() with ( - patch.dict(JOB_DISPATCH, {"run-tests": Mock()}, clear=False), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock())}, clear=False), patch(f"{SCHEDULER_MODULE}.subprocess.Popen", return_value=proc_mock) as popen_mock, patch(f"{SCHEDULER_MODULE}.threading.Thread") as thread_mock, ): @@ -73,7 +73,6 @@ def test_dispatch_unknown_job_key(scheduler_instance, mock_session): job_exec = JobExecution( id=uuid4(), job_key="nonexistent", - args=[], kwargs={}, source="ui", status="claimed", @@ -219,8 +218,7 @@ def test_poll_loop_routes_cancel_requested(scheduler_instance, mock_session): """Cancel_requested rows are routed to _handle_cancellation, not _dispatch.""" cancel_job = JobExecution( id=uuid4(), - job_key="run-tests", - args=[], + job_key=JobKey.run_tests, kwargs={}, source="ui", status=JobStatus.CANCEL_REQUESTED, @@ -256,8 +254,7 @@ def test_start_job_submits_execution(scheduler_instance, mock_session): cron_expr="*/5 * * * *", cron_tz="UTC", delayed_policy=DelayedPolicy.SKIP, - key="run-profile", - args=[], + key=JobKey.run_profile, kwargs={"table_group_id": "tg-123"}, job_schedule_id=schedule_id, ) @@ -266,7 +263,7 @@ def test_start_job_submits_execution(scheduler_instance, mock_session): mock_session.add.assert_called_once() added = mock_session.add.call_args[0][0] - assert added.job_key == "run-profile" + assert added.job_key == JobKey.run_profile assert added.kwargs == {"table_group_id": "tg-123"} assert added.source == "scheduler" assert added.job_schedule_id == schedule_id diff --git 
a/tests/unit/server/test_middleware.py b/tests/unit/server/test_middleware.py new file mode 100644 index 00000000..2225deb2 --- /dev/null +++ b/tests/unit/server/test_middleware.py @@ -0,0 +1,320 @@ +"""Tests for testgen.server.middleware — pure-ASGI body cap and security headers.""" + +# ASGI test stubs (receive/send/inner-app) must be async per protocol but don't +# await anything in these tests. RUF029 is a false positive for that pattern. +# ruff: noqa: RUF029 + +import asyncio +import json + +from testgen.server.middleware import BodySizeLimitMiddleware, SecurityHeadersMiddleware + + +def _http_scope(method: str = "POST", headers: list[tuple[bytes, bytes]] | None = None) -> dict: + return {"type": "http", "method": method, "headers": headers or []} + + +# -------------------------- BodySizeLimitMiddleware -------------------------- + + +def test_body_cap_content_length_over_limit_rejects_immediately(): + """Content-Length > max_bytes → 413 sent without invoking the inner app.""" + inner_called = False + + async def inner(scope, receive, send): + nonlocal inner_called + inner_called = True + + mw = BodySizeLimitMiddleware(inner, max_bytes=1024) + scope = _http_scope(headers=[(b"content-length", b"2048")]) + + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + assert not inner_called + assert sent[0]["type"] == "http.response.start" + assert sent[0]["status"] == 413 + assert json.loads(sent[1]["body"]) == {"detail": "Request body too large"} + + +def test_body_cap_content_length_under_limit_passes_through(): + """Content-Length under the limit → inner app runs normally.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + received_by_inner.append(await receive()) + await send({"type": "http.response.start", "status": 200, "headers": []}) + + mw = BodySizeLimitMiddleware(inner, max_bytes=1024) + scope = 
_http_scope(headers=[(b"content-length", b"100")]) + + queued = [{"type": "http.request", "body": b"x" * 100, "more_body": False}] + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + assert received_by_inner[0]["body"] == b"x" * 100 + assert sent[0]["status"] == 200 + + +def test_body_cap_streaming_disconnects_when_exceeded(): + """Without Content-Length, accumulating body chunks past the limit returns disconnect.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + # Drain three chunks: third one pushes past the limit + for _ in range(3): + received_by_inner.append(await receive()) + + mw = BodySizeLimitMiddleware(inner, max_bytes=150) + scope = _http_scope(headers=[]) + + queued = [ + {"type": "http.request", "body": b"x" * 100, "more_body": True}, + {"type": "http.request", "body": b"y" * 100, "more_body": True}, + {"type": "http.request", "body": b"z" * 100, "more_body": False}, + ] + + async def send(msg): + pass + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + # First chunk passes (100 bytes < 150). Second chunk pushes total to 200, exceeds, returns disconnect. + assert received_by_inner[0]["body"] == b"x" * 100 + assert received_by_inner[1]["type"] == "http.disconnect" + + +def test_body_cap_latch_holds_across_repeated_receives(): + """Regression: once exceeded, every subsequent receive() returns disconnect. + + Without the latch, an inner app that drains receive() multiple times after + seeing http.disconnect could read more body bytes from the underlying socket, + bypassing the cap. 
+ """ + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + # Drain 5 times, well past the disconnect + for _ in range(5): + received_by_inner.append(await receive()) + + mw = BodySizeLimitMiddleware(inner, max_bytes=50) + scope = _http_scope(headers=[]) + + queued = [ + {"type": "http.request", "body": b"x" * 100, "more_body": True}, # exceeds immediately + {"type": "http.request", "body": b"y" * 100, "more_body": True}, # would exceed again if reached + {"type": "http.request", "body": b"z" * 100, "more_body": False}, + ] + + async def send(msg): + pass + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + # First call: real chunk (100 bytes), exceeds → returns disconnect + assert received_by_inner[0]["type"] == "http.disconnect" + # Subsequent calls: latch keeps returning disconnect, never forwards real chunks + for msg in received_by_inner[1:]: + assert msg["type"] == "http.disconnect" + + +def test_body_cap_get_request_bypasses(): + """GET requests skip the cap — no body to inspect.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + received_by_inner.append("called") + + mw = BodySizeLimitMiddleware(inner, max_bytes=100) + scope = _http_scope(method="GET", headers=[(b"content-length", b"99999")]) + + async def send(msg): + pass + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + assert received_by_inner == ["called"] # inner ran despite huge Content-Length + + +def test_body_cap_non_http_scope_passes_through(): + """Lifespan/websocket scopes bypass entirely.""" + inner_called = False + + async def inner(scope, receive, send): + nonlocal inner_called + inner_called = True + + mw = BodySizeLimitMiddleware(inner, max_bytes=10) + scope = {"type": "lifespan"} + + async def send(msg): + pass + + async def receive(): + return {"type": "lifespan.shutdown"} + + 
asyncio.run(mw(scope, receive, send)) + + assert inner_called + + +def test_body_cap_malformed_content_length_falls_through_to_streaming(): + """Non-numeric Content-Length doesn't crash; streaming guard still applies.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + received_by_inner.append(await receive()) + + mw = BodySizeLimitMiddleware(inner, max_bytes=50) + scope = _http_scope(headers=[(b"content-length", b"not-a-number")]) + + queued = [{"type": "http.request", "body": b"x" * 100, "more_body": False}] + + async def send(msg): + pass + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + # Streaming guard catches the oversized body + assert received_by_inner[0]["type"] == "http.disconnect" + + +# -------------------------- SecurityHeadersMiddleware -------------------------- + + +def test_security_headers_added_to_response_start(): + """All configured headers are injected on http.response.start.""" + async def inner(scope, receive, send): + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + + mw = SecurityHeadersMiddleware( + inner, + hsts="max-age=63072000", + csp="frame-ancestors 'none'", + referrer="no-referrer", + nosniff=True, + ) + scope = _http_scope(method="GET") + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + headers = dict(sent[0]["headers"]) + assert headers[b"strict-transport-security"] == b"max-age=63072000" + assert headers[b"content-security-policy"] == b"frame-ancestors 'none'" + assert headers[b"referrer-policy"] == b"no-referrer" + assert headers[b"x-content-type-options"] == b"nosniff" + + +def test_security_headers_preserve_handler_set_value(): + """If the handler already sets CSP, the middleware does not override it. 
+ + Case-insensitive: handler-set 'Content-Security-Policy' wins over middleware's lowercase form. + """ + async def inner(scope, receive, send): + await send({ + "type": "http.response.start", + "status": 200, + "headers": [(b"Content-Security-Policy", b"default-src 'self'")], + }) + + mw = SecurityHeadersMiddleware( + inner, + hsts=None, + csp="frame-ancestors 'none'", + referrer="no-referrer", + nosniff=True, + ) + scope = _http_scope(method="GET") + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + csp_values = [v for k, v in sent[0]["headers"] if k.lower() == b"content-security-policy"] + assert csp_values == [b"default-src 'self'"] + + +def test_security_headers_skip_hsts_when_none(): + """hsts=None → no HSTS header emitted (the API_TLS_ENABLED=False default path).""" + async def inner(scope, receive, send): + await send({"type": "http.response.start", "status": 200, "headers": []}) + + mw = SecurityHeadersMiddleware( + inner, hsts=None, csp="frame-ancestors 'none'", referrer="no-referrer", nosniff=True, + ) + scope = _http_scope(method="GET") + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + header_names = {k.lower() for k, _ in sent[0]["headers"]} + assert b"strict-transport-security" not in header_names + + +def test_security_headers_non_http_scope_passes_through(): + """Lifespan and other non-http scopes are unmodified.""" + inner_called = False + + async def inner(scope, receive, send): + nonlocal inner_called + inner_called = True + + mw = SecurityHeadersMiddleware( + inner, hsts=None, csp="frame-ancestors 'none'", referrer="no-referrer", nosniff=True, + ) + + async def send(msg): + pass + + async def receive(): + return {"type": "lifespan.shutdown"} + + asyncio.run(mw({"type": "lifespan"}, receive, send)) + 
+ assert inner_called diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index aea93451..eb444028 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -4,10 +4,12 @@ from enum import Enum from uuid import UUID +import pandas as pd import pytest from testgen.utils import ( chunk_queries, + dataframe_to_json_records, friendly_score, friendly_score_impact, get_exception_message, @@ -124,6 +126,16 @@ def test_make_json_safe_datetime(): assert make_json_safe(dt) == int(dt.timestamp()) +def test_make_json_safe_nat(): + assert make_json_safe(pd.NaT) is None + + +@pytest.mark.parametrize("dt", [datetime(1, 1, 1), datetime(9999, 12, 31)]) +def test_make_json_safe_out_of_nanosecond_range_datetime(dt): + # Datetimes outside pandas' nanosecond Timestamp range (1677..2262) must still serialize. + assert make_json_safe(dt) == int(dt.replace(tzinfo=UTC).timestamp()) + + def test_make_json_safe_decimal(): assert make_json_safe(Decimal("3.14")) == 3.14 @@ -152,6 +164,26 @@ def test_make_json_safe_passthrough(): assert make_json_safe(None) is None +# --- dataframe_to_json_records --- + +def test_dataframe_to_json_records_empty(): + assert dataframe_to_json_records(pd.DataFrame()) == [] + + +def test_dataframe_to_json_records_handles_out_of_range_dates_and_nulls(): + # Rows mixing out-of-nanosecond-range datetimes with NaT/NaN must serialize without overflow. 
+ df = pd.DataFrame([ + {"id": "1", "min_date": datetime(1, 1, 1), "max_date": datetime(9999, 12, 31), "frac": 1.5}, + {"id": "2", "min_date": datetime(2020, 6, 1), "max_date": pd.NaT, "frac": None}, + ]) + records = dataframe_to_json_records(df) + + assert records[0]["min_date"] == int(datetime(1, 1, 1, tzinfo=UTC).timestamp()) + assert records[0]["max_date"] == int(datetime(9999, 12, 31, tzinfo=UTC).timestamp()) + assert records[1]["max_date"] is None + assert records[1]["frac"] is None + + # --- chunk_queries --- def test_chunk_queries_fits_in_one(): diff --git a/tests/unit/ui/services/__init__.py b/tests/unit/ui/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/ui/services/test_query_cache.py b/tests/unit/ui/services/test_query_cache.py new file mode 100644 index 00000000..aefdcb60 --- /dev/null +++ b/tests/unit/ui/services/test_query_cache.py @@ -0,0 +1,60 @@ +"""Wiring tests for query_cache.py wrappers. + +Verifies that each cached UI wrapper exists, is callable, and exposes ``.clear()`` +for targeted cache invalidation. Does NOT exercise Streamlit cache logic itself. +""" + +from __future__ import annotations + +import pytest + +from testgen.ui.services import query_cache + +# Wrappers that replace cached calls to model methods after TG-1091. +# Names match the per-(entity, method) convention documented in the spec. 
+EXPECTED_WRAPPERS = [ + # Connection + "get_connection", + "select_connections_where", + "get_connection_minimal", + "select_connections_minimal_where", + # User + "get_user", + "select_users_where", + # TableGroup + "get_table_group", + "get_table_group_minimal", + "select_table_groups_minimal_where", + # TestSuite + "get_test_suite", + "get_test_suite_minimal", + "select_test_suites_minimal_where", + # TestRun + "get_test_run_minimal", + "select_test_runs_where", + # ProfilingRun + "get_profiling_run_minimal", + "select_profiling_runs_where", + "select_profiling_runs_minimal_where", + # TestDefinition + "get_test_definition", + "select_test_definitions_where", + "select_test_definitions_minimal_where", + "select_test_definitions_page", + # Project + "get_project", + "select_projects_where", + # ProjectMembership + "get_project_membership", + "select_project_memberships_where", +] + + +@pytest.mark.parametrize("name", EXPECTED_WRAPPERS) +def test_wrapper_exists_and_is_cached(name: str) -> None: + wrapper = getattr(query_cache, name, None) + assert wrapper is not None, f"Missing wrapper: query_cache.{name}" + assert callable(wrapper), f"Wrapper is not callable: query_cache.{name}" + assert hasattr(wrapper, "clear"), ( + f"Wrapper missing .clear() (cache decorator dropped?): query_cache.{name}" + ) diff --git a/tests/unit/ui/test_project_settings.py b/tests/unit/ui/test_project_settings.py index e38f4488..d5acb6b2 100644 --- a/tests/unit/ui/test_project_settings.py +++ b/tests/unit/ui/test_project_settings.py @@ -2,6 +2,7 @@ import pytest +from testgen.common.enums import JobKey from testgen.ui.views.project_settings import ProjectSettingsPage pytestmark = pytest.mark.unit @@ -19,11 +20,13 @@ def mock_session(): yield session -def _make_page(use_dq_score_weights=True): +def _make_page(use_dq_score_weights=True, data_retention_enabled=False, data_retention_days=None): page = ProjectSettingsPage.__new__(ProjectSettingsPage) page.project = MagicMock() 
page.project.use_dq_score_weights = use_dq_score_weights page.project.project_name = "My Project" + page.project.data_retention_enabled = data_retention_enabled + page.project.data_retention_days = data_retention_days return page @@ -34,9 +37,9 @@ def test_update_project_submits_recalculate_job_when_weights_toggled_on(mock_ses page.update_project("proj", {"name": "My Project", "use_dq_score_weights": True}) mock_je.submit.assert_called_once_with( - job_key="recalculate-project-scores", + job_key=JobKey.recalculate_project_scores, kwargs={"project_code": "proj"}, - source="user", + source="ui", project_code="proj", ) @@ -48,9 +51,9 @@ def test_update_project_submits_recalculate_job_when_weights_toggled_off(mock_se page.update_project("proj", {"name": "My Project", "use_dq_score_weights": False}) mock_je.submit.assert_called_once_with( - job_key="recalculate-project-scores", + job_key=JobKey.recalculate_project_scores, kwargs={"project_code": "proj"}, - source="user", + source="ui", project_code="proj", ) @@ -81,8 +84,85 @@ def test_update_project_raises_on_duplicate_name(mock_session): ] with ( - patch(f"{MODULE}.Project") as mock_project_cls, + patch(f"{MODULE}.select_projects_where") as mock_select, pytest.raises(ValueError, match="Other Project"), ): - mock_project_cls.select_where.return_value = [MagicMock(project_name="Other Project")] + mock_select.return_value = [MagicMock(project_name="Other Project")] page.update_project("proj", {"name": "Other Project", "use_dq_score_weights": True}) + + +# ─── Data retention ────────────────────────────────────────────────── + + +def test_update_project_upserts_schedule_when_retention_enabled(mock_session): + page = _make_page(data_retention_enabled=False) + payload = { + "name": "My Project", + "use_dq_score_weights": True, + "data_retention_enabled": True, + "data_retention_days": 90, + "retention_cron_expr": "0 2 * * *", + "retention_cron_tz": "America/New_York", + } + + with ( + patch(f"{MODULE}.JobExecution"), + 
patch(f"{MODULE}.JobSchedule") as mock_schedule, + ): + page.update_project("proj", payload) + + mock_schedule.upsert_for_retention.assert_called_once_with( + project_code="proj", + retention_days=90, + cron_expr="0 2 * * *", + cron_tz="America/New_York", + ) + mock_schedule.delete_for_retention.assert_not_called() + assert page.project.data_retention_enabled is True + assert page.project.data_retention_days == 90 + + +def test_update_project_deletes_schedule_when_retention_disabled(mock_session): + """No-op cleanup contract: disabling retention removes the schedule so the + cleanup job never fires for this project.""" + page = _make_page(data_retention_enabled=True, data_retention_days=180) + payload = { + "name": "My Project", + "use_dq_score_weights": True, + "data_retention_enabled": False, + } + + with ( + patch(f"{MODULE}.JobExecution"), + patch(f"{MODULE}.JobSchedule") as mock_schedule, + ): + page.update_project("proj", payload) + + mock_schedule.delete_for_retention.assert_called_once_with("proj") + mock_schedule.upsert_for_retention.assert_not_called() + assert page.project.data_retention_enabled is False + # When disabled the days column is nulled out (matches the migration's nullable column). 
+ assert page.project.data_retention_days is None + + +def test_update_project_uses_default_days_when_missing(mock_session): + """Enabling retention without an explicit days value falls back to the page's + DEFAULT_RETENTION_DAYS constant (180) so the schedule is still well-formed.""" + page = _make_page(data_retention_enabled=False) + payload = { + "name": "My Project", + "use_dq_score_weights": True, + "data_retention_enabled": True, + # data_retention_days omitted + "retention_cron_expr": "0 1 * * *", + "retention_cron_tz": "UTC", + } + + with ( + patch(f"{MODULE}.JobExecution"), + patch(f"{MODULE}.JobSchedule") as mock_schedule, + ): + page.update_project("proj", payload) + + kwargs = mock_schedule.upsert_for_retention.call_args.kwargs + assert kwargs["retention_days"] == 180